import json
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np

# Only `plot` is exported; `plottable` and `format_params` are helpers used by it.
__all__ = ("plot",)


def plottable(log: list[tuple[float, int]]):
    """Convert a list of ``(x, y)`` samples into a 2xN array ``[xs, ys]``.

    The result can be unpacked straight into a plotting call, e.g.
    ``plt.plot(*plottable(log))``. An empty log yields an empty ``(2, 0)``
    array, so unpacking still produces two (empty) series.
    """
    if not log:
        return np.array([[], []])
    samples = np.array(log)
    return samples.T


def format_params(params) -> str:
    """Render a JSON-like value as a compact single-line string.

    Dicts become ``{key=value ...}`` and lists become ``[item ...]``. Both are
    space-separated and their values are formatted recursively. Dict keys are
    inserted as-is. Any other value is rendered with ``json.dumps``.
    """
    if isinstance(params, dict):
        pairs = (f"{name}={format_params(item)}" for name, item in params.items())
        return "{" + " ".join(pairs) + "}"
    if isinstance(params, list):
        items = " ".join(map(format_params, params))
        return f"[{items}]"
    return json.dumps(params)


def plot(fn: str):
    """Plot the known metrics from the JSON trace file *fn*, then save and show it.

    The file must contain a JSON object. An optional ``"params"`` entry is
    formatted with ``format_params`` and appended to the title, which starts
    with *fn*. Each recognised metric entry is converted with ``plottable``.
    These are presumably lists of ``[time, value]`` pairs; the x axis is
    labelled "time (s)".

    The ``...:concurrency`` metrics are drawn as lines. The
    ``ActiveBinaryTree:add:entry``/``exit`` events are drawn as tiny scatter
    markers on top. Metrics absent from the file are skipped, and any other
    keys are ignored. Each legend label includes the metric's sample count.

    The figure is saved to ``f"{fn}.png"``, shown with ``plt.show()``, and then
    closed.

    Side effect: mutates global pyplot state (``rcParams["figure.figsize"]``
    and the ``dark_background`` style).
    """
    plt.rcParams["figure.figsize"] = [16, 9]
    plt.style.use("dark_background")
    plt.subplots_adjust(left=0.05, right=0.99, top=0.95, bottom=0.05)
    plt.xlabel("time (s)")
    plt.ylabel("concurrency (1)")

    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
    with open(fn, encoding="utf-8") as file:
        jsonified: dict[str, Any] = json.load(file)

    title = fn
    if (params := jsonified.pop("params", None)) is not None:
        title += f" {format_params(params)}"
    plt.title(title)

    def logplot(plot_function, metric: str, **kwargs):
        # Draw `metric` with `plot_function` if present; pop() consumes the entry.
        if (log := jsonified.pop(metric, None)) is not None:
            plot_function(*plottable(log), label=f"{metric} ({len(log)})", **kwargs)

    line_metrics = (
        "DelayedResolver:sleep:concurrency",
        "ActiveBinaryTree:add:concurrency",
        "ActiveBinaryTree:contains:concurrency",
        "FlowStandard:verify_subset:concurrency",
        "Stack:list:concurrency",
    )
    for metric in line_metrics:
        logplot(plt.plot, metric)
    # High zorder keeps the entry/exit markers above the concurrency lines.
    logplot(plt.scatter, "ActiveBinaryTree:add:entry", c="tomato", zorder=100, s=0.5)
    logplot(plt.scatter, "ActiveBinaryTree:add:exit", c="gold", zorder=99, s=0.5)

    plt.legend()
    plt.savefig(f"{fn}.png")
    plt.show()
    # close() rather than clf(): if the window was closed during show(), clf()
    # would create (and leave open) a fresh empty figure. Closing releases it,
    # and the next call starts from a new figure.
    plt.close()


if __name__ == "__main__":
    latest_trace = "trace/latest.json"
    trace_dir = Path("trace")
    trace_dir.mkdir(exist_ok=True)
    # Plot the most recent trace first, then every other trace in the directory.
    if Path(latest_trace).exists():
        plot(latest_trace)
    # sorted() gives a reproducible order, since glob order depends on the
    # filesystem. It also materializes the listing before plot() writes .png
    # files into the same directory.
    for fp in sorted(trace_dir.glob("*.json")):
        if fp != Path(latest_trace):
            plot(str(fp))