diff --git a/src/rstd/collections/tree.rs b/src/rstd/collections/tree.rs index bf74889..61c1826 100644 --- a/src/rstd/collections/tree.rs +++ b/src/rstd/collections/tree.rs @@ -68,6 +68,19 @@ pub struct Tree<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> { height: u64, } +impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Tree<'a, Ctx, A> { + fn validate_height(&self) -> Result<(), HeightError> { + if let Nullable::Null(_) = self.node { + if self.height != 0 { + Err(HeightError::LeafHeight(self.height))? + } + } else if self.height == 0 { + Err(HeightError::NodeHeight)? + } + Ok(()) + } +} + #[derive(Clone)] pub struct NodeFactory(F); @@ -162,7 +175,9 @@ impl<'a, Ctx: Context<'a>, F: Factory<'a, Ctx>> Factory<'a, Ctx> for TreeFactory ) -> ParseResult<'a, Ctx, Self> { let node = self.0.deserialize(deserializer, resolver, addresses)?; let height = u64::a_deserialize(deserializer)?; - Ok(Tree { node, height }) + let tree = Tree { node, height }; + tree.validate_height()?; + Ok(tree) } fn extend( @@ -171,6 +186,7 @@ impl<'a, Ctx: Context<'a>, F: Factory<'a, Ctx>> Factory<'a, Ctx> for TreeFactory tail: &[u8], ) -> Result { mentionable.height = u64::a_extend(mentionable.height, tail)?; + mentionable.validate_height()?; Ok(mentionable) } }