import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Public API: only plot() is exported via `from ... import *`; plottable()
# and format_params() are helpers used by plot().
__all__ = ('plot',)


def plottable(log: list[tuple[float, int]]):
    """Turn a list of per-sample rows into column arrays for plotting.

    Assuming each entry of *log* is an ``(x, y)`` pair, the result has shape
    ``(2, n)``, so ``*plottable(log)`` unpacks into ``xs, ys`` for
    ``plot``/``scatter``. An empty (or falsy) *log* yields an empty
    ``(2, 0)`` array so the unpacking still produces two sequences.
    """
    if not log:
        return np.array([[], []])
    rows = np.array(log)
    return rows.T


def format_params(params) -> str:
    """Render a JSON-like value as a compact, space-separated string.

    Dicts become ``{key=value key=value}``, lists become ``[a b c]`` (both
    recursively), and any other value is serialized with ``json.dumps``
    (so strings are quoted, ``None`` becomes ``null``, booleans lowercase).
    """
    if isinstance(params, dict):
        pairs = (f'{name}={format_params(item)}' for name, item in params.items())
        return '{' + ' '.join(pairs) + '}'
    if isinstance(params, list):
        return '[' + ' '.join(map(format_params, params)) + ']'
    return json.dumps(params)


def plot(fn: str | Path):
    """Plot the traces stored in the JSON trace file *fn* and show the figure.

    The file must contain a JSON object. An optional ``'params'`` entry is
    rendered into the title with :func:`format_params`. Each metric key
    listed below, if present, is drawn as a line (``plot``) or as scattered
    points (``scatter``), labelled with its name and entry count. Keys not
    listed here are ignored. Entries are passed through :func:`plottable`,
    so they are presumably ``[time, value]`` pairs (per the axis labels;
    confirm against the trace writer).

    Blocks in ``plt.show()`` when the backend is interactive, and always
    closes the figure afterwards, even if plotting raises.

    Side effects: sets the global ``figure.figsize`` rcParam to 16x9 and
    activates the ``dark_background`` matplotlib style.
    """
    # Load before creating a figure so a bad file doesn't leave one open.
    with open(fn, encoding='utf-8') as file:
        jsonified: dict[str, object] = json.load(file)

    plt.rcParams['figure.figsize'] = [16, 9]
    plt.style.use('dark_background')
    fig, ax = plt.subplots()
    try:
        fig.subplots_adjust(left=0.05, right=0.99, top=0.95, bottom=0.05)
        ax.set_xlabel('time (s)')
        ax.set_ylabel('concurrency (1)')

        title = str(fn)
        if (params := jsonified.pop('params', None)) is not None:
            title += f' {format_params(params)}'
        ax.set_title(title)

        def logplot(plot_function, metric: str, **kwargs):
            # Draw *metric* with *plot_function* if the trace contains it;
            # popping removes it from jsonified once handled.
            if (log := jsonified.pop(metric, None)) is not None:
                plot_function(*plottable(log), label=f'{metric} ({len(log)})', **kwargs)

        logplot(ax.plot, 'DelayedResolver:sleep:concurrency')
        logplot(ax.plot, 'ActiveBinaryTree:add:concurrency')
        logplot(ax.plot, 'ActiveBinaryTree:contains:concurrency')
        logplot(ax.plot, 'FlowStandard:verify_subset:concurrency')
        logplot(ax.plot, 'Stack:list:concurrency')
        # Event markers: tiny points drawn above the concurrency lines.
        logplot(ax.scatter, 'ActiveBinaryTree:add:entry', c='tomato', zorder=100, s=.5)
        logplot(ax.scatter, 'ActiveBinaryTree:add:exit', c='gold', zorder=99, s=.5)

        # Only add a legend when something labelled was drawn; otherwise
        # matplotlib warns "No artists with labels found to put in legend".
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        plt.show()
    finally:
        # plt.clf() after a blocking show() would create a fresh, never-shown
        # figure instead of releasing this one; close it explicitly.
        plt.close(fig)


if __name__ == '__main__':
    # Plot trace/latest.json first (if present), then every other trace file.
    # Glob order is filesystem-dependent, so sort for a stable, repeatable order.
    trace_dir = Path('trace')
    trace_dir.mkdir(exist_ok=True)  # create the directory if it is missing
    latest = trace_dir / 'latest.json'
    if latest.exists():
        plot(str(latest))
    for fp in sorted(trace_dir.glob('*.json')):
        if fp != latest:
            plot(str(fp))