diff --git a/src/rstd/collections.rs b/src/rstd/collections.rs index 579e1a6..73697df 100644 --- a/src/rstd/collections.rs +++ b/src/rstd/collections.rs @@ -1,5 +1,6 @@ //! Standard generic collections. +pub mod avl; pub mod pair; pub mod rbtree; pub mod stack; diff --git a/src/rstd/collections/avl.rs b/src/rstd/collections/avl.rs new file mode 100644 index 0000000..8413dc9 --- /dev/null +++ b/src/rstd/collections/avl.rs @@ -0,0 +1,165 @@ +use std::{error::Error, fmt::Display, rc::Rc}; + +use crate::rcore::*; +use crate::rstd::{ + atomic::{au64::*, *}, + nullable::*, + point::*, +}; + +#[derive(Debug)] +pub enum TreeParseError { + Int(IntParseError), + Point(PointParseError), + Key(E), + LeafHeight(u64), + Balance(u64, u64), +} + +impl From for TreeParseError { + fn from(value: IntParseError) -> Self { + Self::Int(value) + } +} + +impl From for TreeParseError { + fn from(value: PointParseError) -> Self { + Self::Point(value) + } +} + +impl Display for TreeParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Int(int_error) => { + f.write_fmt(format_args!("failed to parse AVL tree height: {int_error}")) + } + Self::Point(point_error) => f.write_fmt(format_args!( + "failed to parse AVL node reference: {point_error}" + )), + Self::Key(key_error) => { + f.write_fmt(format_args!("failed to parse AVL node key: {key_error}")) + } + Self::LeafHeight(height) => { + f.write_fmt(format_args!("invalid AVL leaf height: {height}!=0.")) + } + Self::Balance(lh, rh) => f.write_fmt(format_args!("unbalanced AVL node: {lh} {rh}.")), + } + } +} + +impl Error for TreeParseError {} + +struct AvlNode<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> { + l: AvlTree<'a, Ctx, A>, + r: AvlTree<'a, Ctx, A>, + key: Rc, +} + +struct AvlTree<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> { + node: Nullable<'a, Ctx, AvlNode<'a, Ctx, A>>, + height: u64, +} + +#[derive(Clone)] +struct AvlNodeFactory(F); + +#[derive(Clone)] +struct AvlTreeFactory(NullableFactory>); + +impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Serializable for AvlNode<'a, Ctx, A> { + fn serialize(&self, serializer: &mut dyn Serializer) { + self.l.serialize(serializer); + self.r.serialize(serializer); + self.key.serialize(serializer); + } +} + +impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Serializable for AvlTree<'a, Ctx, A> { + fn serialize(&self, serializer: &mut dyn Serializer) { + self.height.serialize(serializer); + self.node.serialize(serializer); + } +} + +impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Mentionable<'a, Ctx> for AvlNode<'a, Ctx, A> { + type Fctr = AvlNodeFactory; + + fn factory(&self) -> Self::Fctr { + AvlNodeFactory(self.key.factory()) + } + + fn points_typed(&self, points: &mut impl TakesPoints<'a, Ctx>) { + self.l.points_typed(points); + self.r.points_typed(points); + self.key.points_typed(points); + } +} + +impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Mentionable<'a, Ctx> for AvlTree<'a, Ctx, A> { + type Fctr = AvlTreeFactory; + + fn factory(&self) -> Self::Fctr { + AvlTreeFactory(self.node.factory()) + } + + fn points_typed(&self, points: &mut impl TakesPoints<'a, Ctx>) { + self.node.points_typed(points); + } +} + +impl<'a, Ctx: Context<'a>, F: Factory<'a, Ctx>> Factory<'a, Ctx> for AvlNodeFactory { + type Mtbl = AvlNode<'a, Ctx, F::Mtbl>; + + type ParseError = TreeParseError; + + fn deserialize( + &self, + deserializer: &mut dyn Deserializer, + resolver: Rc>, + addresses: &mut Addresses, + ) -> ParseResult<'a, Ctx, Self> { + let tree_factory = AvlTreeFactory(NullableFactory::new(self.clone())); + let l = tree_factory.deserialize(deserializer, resolver.clone(), addresses)?; + let r = tree_factory.deserialize(deserializer, resolver.clone(), addresses)?; + if std::cmp::max(l.height, r.height) - std::cmp::min(l.height, r.height) > 1 { + return Err(TreeParseError::Balance(l.height, r.height)); + } + let key = self + .0 + .deserialize(deserializer, resolver.clone(), addresses) + .map_err(TreeParseError::Key)? + .into(); + Ok(AvlNode { l, r, key }) + } + + fn unexpected_tail(&self, tail: &[u8]) -> Self::ParseError { + TreeParseError::Key(self.0.unexpected_tail(tail)) + } +} + +impl<'a, Ctx: Context<'a>, F: Factory<'a, Ctx>> Factory<'a, Ctx> for AvlTreeFactory { + type Mtbl = AvlTree<'a, Ctx, F::Mtbl>; + + type ParseError = TreeParseError; + + fn deserialize( + &self, + deserializer: &mut dyn Deserializer, + resolver: Rc>, + addresses: &mut Addresses, + ) -> ParseResult<'a, Ctx, Self> { + let node = self.0.deserialize(deserializer, resolver, addresses)?; + let height = u64::a_deserialize(deserializer)?; + if let Nullable::Null(_) = node { + if height != 0 { + return Err(TreeParseError::LeafHeight(height)); + } + } + Ok(AvlTree { node, height }) + } + + fn unexpected_tail(&self, tail: &[u8]) -> Self::ParseError { + u64::a_unexpected_tail(tail).into() + } +} diff --git a/src/rstd/collections/rbtree.rs b/src/rstd/collections/rbtree.rs index 1d7c4d3..f76a71e 100644 --- a/src/rstd/collections/rbtree.rs +++ b/src/rstd/collections/rbtree.rs @@ -22,28 +22,27 @@ pub enum TreeParseError { impl From for TreeParseError { fn from(value: BooleanParseError) -> Self { - TreeParseError::Boolean(value) + Self::Boolean(value) } } impl From for TreeParseError { fn from(value: PointParseError) -> Self { - TreeParseError::Point(value) + Self::Point(value) } } impl Display for TreeParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - TreeParseError::Boolean(boolean_error) => { - f.write_fmt(format_args!("failed to parse RB flag: {}", boolean_error)) + Self::Boolean(boolean_error) => { + f.write_fmt(format_args!("failed to parse RB flag: {boolean_error}")) } - TreeParseError::Point(point_error) => f.write_fmt(format_args!( - "failed to parse RB tree reference: {}", - point_error + Self::Point(point_error) => f.write_fmt(format_args!( + "failed to parse RB tree reference: {point_error}" )), - TreeParseError::Key(key_error) => { - f.write_fmt(format_args!("failed to parse RB tree key: {}", key_error)) + Self::Key(key_error) => { + f.write_fmt(format_args!("failed to parse RB tree key: {key_error}")) } } }