use super::*; trait IntoTree<'a, Trees: BinaryTrees<'a>> { fn into_tree(self, trees: &Trees) -> BTWrap<'a, Trees, Trees::Tree>; } struct T(T); impl<'a, Trees: BinaryTrees<'a>> IntoTree<'a, Trees> for T { fn into_tree(self, _trees: &Trees) -> BTWrap<'a, Trees, Trees::Tree> { Trees::pure(self.0) } } impl<'a, Trees: BinaryTreesAvl<'a>, L: IntoTree<'a, Trees>, R: IntoTree<'a, Trees>> IntoTree<'a, Trees> for (L, Trees::Key, R) { fn into_tree(self, trees: &Trees) -> BTWrap<'a, Trees, Trees::Tree> { let trees = trees.clone(); Trees::T::bind2( self.0.into_tree(&trees), self.2.into_tree(&trees), move |tl, tr| trees.join_key_balanced_tree(tl, self.1, tr), ) } } trait BinaryTreesAvlExt<'a>: BinaryTreesAvl<'a> { fn make_node( self, itl: impl IntoTree<'a, Self>, key: Self::Key, itr: impl IntoTree<'a, Self>, ) -> BTWrap<'a, Self, Self::Node> { Self::T::bind2(itl.into_tree(&self), itr.into_tree(&self), |tl, tr| { self.join_key_balanced(tl, key, tr) }) } } impl<'a, Trees: BinaryTreesAvl<'a>> BinaryTreesAvlExt<'a> for Trees {} pub trait BinaryTreesAvl<'a>: BinaryTreesHeight<'a> + BinaryTreesTreeOf<'a> + BinaryTreesTryJoin<'a> { fn assume_node(&self, tree: &Self::Tree) -> BTWrap<'a, Self, Self::Node> { match self.refer(tree) { Some(reference) => self.resolve(&reference), None => self.height_error(HeightError::LeafHeight(self.height(tree))), } } fn assume_bind( self, tree: &Self::Tree, f: impl 'a + FnOnce(Self, Self::Tree, Self::Key, Self::Tree) -> BTWrap<'a, Self, T>, ) -> BTWrap<'a, Self, T> { Self::bind(self.assume_node(tree), move |node| { let (tl, tr, key) = self.split(&node); f(self, tl, key, tr) }) } fn join_key_balanced_tree( &self, tl: Self::Tree, key: Self::Key, tr: Self::Tree, ) -> BTWrap<'a, Self, Self::Tree> { self.clone() .tree_bind(self.clone().join_key_balanced(tl, key, tr)) } fn join_key_balanced( self, tl: Self::Tree, key: Self::Key, tr: Self::Tree, ) -> BTWrap<'a, Self, Self::Node> { let (hl, hr) = (self.height(&tl), self.height(&tr)); match (hl.saturating_sub(hr), hr.saturating_sub(hl)) { (0, 0) | (0, 1) | (1, 0) => self.try_join(tl, key, tr), (0, _) => self.assume_bind(&tr, |ctx, trl, kr, trr| { let (rlh, rrh) = (ctx.height(&trl), ctx.height(&trr)); if rlh > rrh { ctx.assume_bind(&trl, |ctx, trll, krl, trlr| { ctx.make_node((T(tl), key, T(trll)), krl, (T(trlr), kr, T(trr))) }) } else { ctx.make_node((T(tl), key, T(trl)), kr, T(trr)) } }), (_, 0) => self.assume_bind(&tl, |ctx, tll, kl, tlr| { let (hll, hlr) = (ctx.height(&tll), ctx.height(&tlr)); if hll < hlr { ctx.assume_bind(&tlr, |ctx, tlrl, klr, tlrr| { ctx.make_node((T(tll), kl, T(tlrl)), klr, (T(tlrr), key, T(tr))) }) } else { ctx.make_node(T(tll), kl, (T(tlr), key, T(tr))) } }), (_, _) => unreachable!(), } } } impl<'a, BT: BinaryTreesHeight<'a> + BinaryTreesTreeOf<'a> + BinaryTreesTryJoin<'a>> BinaryTreesAvl<'a> for BT { }