use std::fmt::Display; use super::{avl::*, bound::*, *}; pub trait BinaryTreesUnbalanced<'a>: BinaryTreesHeight<'a> { fn tree_of_with_height(&self, node: Self::Node, height: u64) -> BTWrap<'a, Self, Self::Tree>; fn balancing_error(&self, error: BalancingError) -> BTWrap<'a, Self, T>; fn balancing_bind( &self, ra: Result, f: impl FnOnce(A) -> BTWrap<'a, Self, B>, ) -> BTWrap<'a, Self, B> { match ra { Ok(a) => f(a), Err(e) => self.balancing_error(e), } } fn node_heights(&self, node: &Self::Node) -> (u64, u64) { let (tl, tr, _) = self.split(node); (self.height(&tl), self.height(&tr)) } } #[derive(Clone)] pub struct BalancedTrees(BT); impl BalancedTrees { pub fn new(bt: BT) -> Self { Self(bt) } } impl<'a, BT: FunctorContext<'a>> FunctorContext<'a> for BalancedTrees { type T = BT::T; } #[derive(Debug)] pub enum BalancingError { Height(HeightError), Balance(u64, u64), HeightOverflow, HeightMismatch { children: (u64, u64), parent: u64 }, } impl Display for BalancingError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Height(height_error) => write!(f, "invalid height: {height_error}."), Self::Balance(hl, hr) => write!(f, "unbalanced node: {hl} {hr}."), Self::HeightOverflow => write!(f, "tree height overflow."), Self::HeightMismatch { children, parent } => { write!(f, "child-parent height mismatch: {children:?}, {parent}.") } } } } fn balanced(hl: u64, hr: u64) -> Result<(), BalancingError> { if hl.abs_diff(hr) > 1 { Err(BalancingError::Balance(hl, hr)) } else { Ok(()) } } fn parent_height(hl: u64, hr: u64) -> Result { balanced(hl, hr)?; std::cmp::max(hl, hr) .checked_add(1) .ok_or(BalancingError::HeightOverflow) } fn matches_height(hl: u64, hr: u64, hp: u64) -> Result<(), BalancingError> { if parent_height(hl, hr)? == hp { Ok(()) } else { Err(BalancingError::HeightMismatch { children: (hl, hr), parent: hp, }) } } impl<'a, BT: BinaryTreesUnbalanced<'a>> BinaryTrees<'a> for BalancedTrees { type Node = BT::Node; type Reference = (BT::Reference, u64); type Tree = BT::Tree; type Key = BT::Key; type Comparator = BT::Comparator; type _Tm = Self::T; fn comparator(&self) -> &Self::Comparator { self.0.comparator() } fn split(&self, node: &Self::Node) -> Split<'a, Self> { self.0.split(node) } fn resolve(&self, (reference, hp): &Self::Reference) -> BTWrap<'a, Self, Self::Node> { let hp = *hp; let ctx = self.0.clone(); Self::bind(self.0.resolve(reference), move |node| { let (hl, hr) = ctx.node_heights(&node); ctx.balancing_bind(matches_height(hl, hr, hp), |_| Self::pure(node)) }) } fn equal(&self, (rl, hl): &Self::Reference, (rr, hr): &Self::Reference) -> bool { hl == hr && self.0.equal(rl, rr) } fn refer(&self, tree: &Self::Tree) -> Option { Some((self.0.refer(tree)?, self.0.height(tree))) } } impl<'a, BT: BinaryTreesUnbalanced<'a>> BinaryTreesTreeOf<'a> for BalancedTrees { fn tree_of(&self, node: Self::Node) -> BTWrap<'a, Self, Self::Tree> { let (hl, hr) = self.0.node_heights(&node); self.0.balancing_bind(parent_height(hl, hr), |height| { self.0.tree_of_with_height(node, height) }) } } impl<'a, BT: BinaryTreesUnbalanced<'a> + BinaryTreesEmpty<'a>> BinaryTreesEmpty<'a> for BalancedTrees { fn empty(&self) -> Self::Tree { self.0.empty() } fn split_key_empty( &self, tree: Self::Tree, key: Self::Key, ) -> BTWrap<'a, Self, KeySplit<'a, Self>> { self.0.split_key_empty(tree, key) } } impl<'a, BT: BinaryTreesUnbalanced<'a>> BinaryTreesHeight<'a> for BalancedTrees { fn height(&self, tree: &Self::Tree) -> u64 { self.0.height(tree) } fn height_error(&self, error: HeightError) -> BTWrap<'a, Self, T> { self.0.height_error(error) } } impl<'a, BT: BinaryTreesUnbalanced<'a> + BinaryTreesTryJoin<'a>> BinaryTreesTryJoin<'a> for BalancedTrees { fn try_join( &self, tl: Self::Tree, key: Self::Key, tr: Self::Tree, ) -> BTWrap<'a, Self, Self::Node> { let (hl, hr) = (self.0.height(&tl), self.0.height(&tr)); self.0 .balancing_bind(balanced(hl, hr), |_| self.0.try_join(tl, key, tr)) } } impl<'a, BT: BinaryTreesUnbalanced<'a> + BinaryTreesBindable<'a>> BinaryTreesBindable<'a> for BalancedTrees { fn bounds_error(&self, error: bounds::BoundsError) -> BTWrap<'a, Self, T> { self.0.bounds_error(error) } } impl<'a, BT: BinaryTreesUnbalanced<'a> + BinaryTreesEmpty<'a> + BinaryTreesTryJoin<'a>> BinaryTreesMutable<'a> for BalancedTrees { fn join_key( self, tl: Self::Tree, key: Self::Key, tr: Self::Tree, ) -> BTWrap<'a, Self, Self::Node> { self.join_key_balanced(tl, key, tr) } }