pub mod context;

use std::{error::Error, fmt::Display};

use crate::{
    flow::binary::*,
    rcore::*,
    rstd::{
        atomic::{au64::*, *},
        inlining::*,
        nullable::*,
        point::*,
    },
};

/// Error produced while parsing a [`Tree`] or a [`Node`].
///
/// `E` is the parse error type of the key factory (`F::ParseError` in the factory impls).
#[derive(Debug)]
pub enum TreeParseError<E> {
    /// The tree height could not be parsed (displayed as "failed to parse tree height").
    HeightParse(IntParseError),
    /// A node reference could not be parsed (displayed as "failed to parse node reference").
    Point(PointParseError),
    /// The key factory rejected the key (displayed as "failed to parse node key").
    Key(E),
    /// The height was read but is inconsistent with the node: a null node with a nonzero
    /// height, or a present node with height zero (checked by `Tree::validate_height`).
    HeightValue(HeightError),
}

impl<E> From<IntParseError> for TreeParseError<E> {
    fn from(value: IntParseError) -> Self {
        Self::HeightParse(value)
    }
}

impl<E> From<PointParseError> for TreeParseError<E> {
    fn from(value: PointParseError) -> Self {
        Self::Point(value)
    }
}

impl<E> From<HeightError> for TreeParseError<E> {
    fn from(value: HeightError) -> Self {
        Self::HeightValue(value)
    }
}

impl<E: Display> Display for TreeParseError<E> {
    /// Renders `"<what failed>: <underlying error>"`, where the prefix names the part of the
    /// tree that could not be parsed or validated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (context, cause): (&str, &dyn Display) = match self {
            Self::HeightParse(cause) => ("failed to parse tree height", cause),
            Self::Point(cause) => ("failed to parse node reference", cause),
            Self::Key(cause) => ("failed to parse node key", cause),
            Self::HeightValue(cause) => ("invalid height", cause),
        };
        write!(f, "{context}: {cause}")
    }
}

impl<E: Error> Error for TreeParseError<E> {}

/// A tree node: two subtrees and a key of type `A`.
///
/// Serialized and parsed in the field order `l`, `r`, `key`.
pub struct Node<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> {
    // Presumably the left subtree (name-based; ordering invariants are not enforced here).
    l: Tree<'a, Ctx, A>,
    // Presumably the right subtree.
    r: Tree<'a, Ctx, A>,
    key: A,
}

/// A possibly-empty tree: a nullable reference to a [`Node`] plus a stored height.
///
/// Parsing (`TreeFactory`) reads the node reference before the height. A parsed tree must
/// satisfy `validate_height`: a null node has height 0, a present node a nonzero height.
pub struct Tree<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> {
    node: Nullable<'a, Ctx, Node<'a, Ctx, A>>,
    height: u64,
}

impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Tree<'a, Ctx, A> {
    /// Checks that the stored height agrees with whether a node is present.
    ///
    /// # Errors
    ///
    /// - [`HeightError::LeafHeight`] (carrying the height) if the node is null but the
    ///   height is not 0.
    /// - [`HeightError::NodeHeight`] if a node is present but the height is 0.
    fn validate_height(&self) -> Result<(), HeightError> {
        let is_leaf = matches!(self.node, Nullable::Null(_));
        match (is_leaf, self.height) {
            (true, 0) => Ok(()),
            (true, height) => Err(HeightError::LeafHeight(height)),
            (false, 0) => Err(HeightError::NodeHeight),
            (false, _) => Ok(()),
        }
    }
}

/// Factory for [`Node`]s; wraps the factory `F` used to parse each node's key.
#[derive(Clone)]
pub struct NodeFactory<F>(F);

/// Factory for [`Tree`]s; wraps a nullable factory of [`NodeFactory`] for the node reference.
#[derive(Clone)]
pub struct TreeFactory<F>(NullableFactory<NodeFactory<F>>);

impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Serializable for Node<'a, Ctx, A> {
    /// Writes the left subtree, then the right subtree, then the key. This is the order in
    /// which `NodeFactory::deserialize` reads them back.
    fn serialize(&self, serializer: &mut dyn Serializer) {
        let Self { l, r, key } = self;
        for subtree in [l, r] {
            subtree.serialize(serializer);
        }
        key.serialize(serializer);
    }
}

impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Serializable for Tree<'a, Ctx, A> {
    /// Writes the tree as its nullable node reference followed by its `u64` height.
    ///
    /// The order must mirror `TreeFactory::deserialize` and `TreeFactory::ideserialize`,
    /// which read the node first and the height last. `TreeFactory::extend` and
    /// `extension_error` also treat trailing bytes as belonging to the height. Writing the
    /// height first would make serialized trees fail to round-trip.
    fn serialize(&self, serializer: &mut dyn Serializer) {
        self.node.serialize(serializer);
        self.height.serialize(serializer);
    }
}

impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Mentionable<'a, Ctx> for Node<'a, Ctx, A> {
    type Fctr = NodeFactory<A::Fctr>;

    /// Builds a [`NodeFactory`] around the key's own factory.
    fn factory(&self) -> Self::Fctr {
        let key_factory = self.key.factory();
        NodeFactory(key_factory)
    }

    /// Reports the points of the left subtree, the right subtree and the key, in that order.
    fn points_typed(&self, points: &mut impl TakesPoints<'a, Ctx>) {
        let Self { l, r, key } = self;
        l.points_typed(points);
        r.points_typed(points);
        key.points_typed(points);
    }
}

impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx>> Mentionable<'a, Ctx> for Tree<'a, Ctx, A> {
    type Fctr = TreeFactory<A::Fctr>;

    /// Builds a [`TreeFactory`] around the node reference's nullable factory.
    fn factory(&self) -> Self::Fctr {
        let node_factory = self.node.factory();
        TreeFactory(node_factory)
    }

    /// Reports only the points of the node reference; the height contributes none.
    fn points_typed(&self, points: &mut impl TakesPoints<'a, Ctx>) {
        let Self { node, .. } = self;
        node.points_typed(points);
    }
}

impl<'a, Ctx: Context<'a>, F: Factory<'a, Ctx>> Factory<'a, Ctx> for NodeFactory<F> {
    type Mtbl = Node<'a, Ctx, F::Mtbl>;

    type ParseError = TreeParseError<F::ParseError>;

    /// Parses a node as its left subtree, then its right subtree, then its key.
    ///
    /// Both subtrees are read inline with a [`TreeFactory`] built around a clone of this
    /// factory. The key is parsed from the remaining input, and its errors are reported
    /// as [`TreeParseError::Key`].
    fn deserialize(&self, inctx: impl InCtx<'a, Ctx>) -> ParseResult<'a, Ctx, Self> {
        let subtree_factory = TreeFactory(NullableFactory::new(self.clone()));
        let (l, rest) = subtree_factory.ideserialize(inctx)?;
        let (r, rest) = subtree_factory.ideserialize(rest)?;
        match self.0.deserialize(rest) {
            Ok(key) => Ok(Node { l, r, key }),
            Err(key_error) => Err(TreeParseError::Key(key_error)),
        }
    }

    /// Extends only the key (the last field) with `tail`; the subtrees are kept as-is.
    fn extend(&self, mut mentionable: Self::Mtbl, tail: &[u8]) -> ParseResult<'a, Ctx, Self> {
        match self.0.extend(mentionable.key, tail) {
            Ok(extended_key) => {
                mentionable.key = extended_key;
                Ok(mentionable)
            }
            Err(key_error) => Err(TreeParseError::Key(key_error)),
        }
    }
}

impl<'a, Ctx: Context<'a>, F: Factory<'a, Ctx>> Factory<'a, Ctx> for TreeFactory<F> {
    type Mtbl = Tree<'a, Ctx, F::Mtbl>;

    type ParseError = TreeParseError<F::ParseError>;

    /// Parses a tree as an inline nullable node reference followed by its `u64` height.
    /// Heights inconsistent with the node are rejected via `validate_height`.
    fn deserialize(&self, inctx: impl InCtx<'a, Ctx>) -> ParseResult<'a, Ctx, Self> {
        let (node, rest) = self.0.ideserialize(inctx)?;
        let tree = Tree {
            node,
            height: u64::a_deserialize(rest)?,
        };
        match tree.validate_height() {
            Ok(()) => Ok(tree),
            Err(height_error) => Err(height_error.into()),
        }
    }

    /// Extends the height (the trailing field) with `tail`, then re-validates it.
    fn extend(&self, mut mentionable: Self::Mtbl, tail: &[u8]) -> ParseResult<'a, Ctx, Self> {
        let extended_height = u64::a_extend(mentionable.height, tail)?;
        mentionable.height = extended_height;
        match mentionable.validate_height() {
            Ok(()) => Ok(mentionable),
            Err(height_error) => Err(height_error.into()),
        }
    }
}

impl<'a, Ctx: Context<'a>, F: Factory<'a, Ctx>> InlineableFactory<'a, Ctx> for TreeFactory<F> {
    /// Reports unexpected trailing bytes using the `u64` extension error, because the height
    /// is the last field of an inline tree.
    fn extension_error(&self, tail: &[u8]) -> Self::ParseError {
        let height_error = u64::a_extension_error(tail);
        height_error.into()
    }

    /// Parses an inline tree (node reference, then `u64` height) and hands back the
    /// remaining input. Heights inconsistent with the node are rejected.
    fn ideserialize<I: InCtx<'a, Ctx>>(&self, inctx: I) -> IParseResult<'a, Ctx, Self, I> {
        let (node, after_node) = self.0.ideserialize(inctx)?;
        let (height, rest) = u64::a_ideserialize(after_node)?;
        let tree = Tree { node, height };
        match tree.validate_height() {
            Ok(()) => Ok((tree, rest)),
            Err(height_error) => Err(height_error.into()),
        }
    }
}

impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx> + Clone> Clone for Node<'a, Ctx, A> {
    /// Clones the left subtree, the right subtree and the key, in that order.
    fn clone(&self) -> Self {
        let Self { l, r, key } = self;
        let l = l.clone();
        let r = r.clone();
        let key = key.clone();
        Self { l, r, key }
    }
}

impl<'a, Ctx: Context<'a>, A: Mentionable<'a, Ctx> + Clone> Clone for Tree<'a, Ctx, A> {
    /// Copies the height and clones the node reference.
    fn clone(&self) -> Self {
        let Self { node, height } = self;
        Self {
            height: *height,
            node: node.clone(),
        }
    }
}