import asyncio
import random
from contextlib import ExitStack
from typing import Any, Callable, Coroutine

from nacl.signing import SigningKey, VerifyKey

from plot import *
from rainbowadn.collection.trees.binary import *
from rainbowadn.core import *
from rainbowadn.flow13 import *
from rainbowadn.flow13 import FlowCoin
from rainbowadn.instrument import *
from rainbowadn.testing.resolvers import *
from rainbowadn.v13 import *
from trace_common import *


def get_instrumentations() -> list[Instrumentation]:
    """Build a fresh list of ``Concurrency`` instrumentations for one traced run.

    One instrumentation is created per ``(class, method name)`` pair below, in
    this exact order: ``DelayedResolver.sleep``, ``ActiveBinaryTree.add``,
    ``ActiveBinaryTree.contains`` and ``FlowStandard.verify_subset``.
    """
    targets = (
        (DelayedResolver, 'sleep'),
        (ActiveBinaryTree, 'add'),
        (ActiveBinaryTree, 'contains'),
        (FlowStandard, 'verify_subset'),
    )
    return [Concurrency(owner, method) for owner, method in targets]


# Pool of coin hash points produced by generated transactions; _generate_transaction
# pops entries from here to use as inputs and adds each transaction's minted outputs back.
minted: set[HashPoint[FlowCoin]] = set()
# Signing key of every generated subject, keyed by its public verify key, so that
# coins owned by a subject can later be spent (signed) by that subject.
reverse: dict[VerifyKey, SigningKey] = dict()


def _generate_subject() -> Subject:
    """Create a new subject backed by a freshly generated signing key.

    The signing key is remembered in ``reverse`` under its verify key so that
    coins owned by the returned subject can be spent later.
    """
    private = SigningKey.generate()
    public = private.verify_key
    reverse[public] = private
    return Subject(public)


async def _generate_transaction(
        subjects_min: int,
        subjects_max: int,
):
    """Generate one ``FlowTransaction`` against the global ``minted`` pool.

    Up to ``randint(subjects_min, subjects_max)`` coins (inclusive bounds) are
    taken from ``minted`` as inputs, stopping early if the pool runs dry. Each
    input is signed with its owner's key looked up in ``reverse``. A second
    ``randint(subjects_min, subjects_max)`` draw decides how many outputs to
    create. Each output is ``FlowCoinData.of(<new subject>, 0)``; the ``0`` is
    presumably the coin value (TODO confirm). Whatever the transaction's minted
    reducer yields is added back to ``minted``.

    Returns the created transaction.
    """
    spent: list[FlowCoin] = []
    signers: list[SigningKey] = []
    wanted = random.randint(subjects_min, subjects_max)
    # Same as iterating `wanted` times and breaking once the pool is empty.
    while minted and len(spent) < wanted:
        coin = await minted.pop().resolve()
        spent.append(coin)
        owner = await coin.owner()
        signers.append(reverse[owner.verify_key])
    receivers = random.randint(subjects_min, subjects_max)
    outputs = [FlowCoinData.of(_generate_subject(), 0) for _ in range(receivers)]
    transaction = await FlowTransaction.make(spent, outputs, signers)
    reducer = await transaction.minted_reducer()
    minted.update(await reducer.reduce(FlowIterate([])))
    return transaction


async def _generate(
        blocks: int,
        subjects_min: int,
        subjects_max: int,
        transactions_min: int,
        transactions_max: int,
) -> BankBlock:
    """Build a bank by appending ``blocks`` generated cheques to an empty bank.

    Each cheque holds ``randint(transactions_min, transactions_max)`` transactions
    (inclusive bounds), each produced by ``_generate_transaction(subjects_min,
    subjects_max)``. Prints ``generated`` when done.
    """
    bank: BankBlock = BankBlock.empty()
    for _ in range(blocks):
        count = random.randint(transactions_min, transactions_max)
        transactions = [
            await _generate_transaction(subjects_min, subjects_max)
            for _ in range(count)
        ]
        cheque = await FlowCheque.make(transactions)
        bank = await bank.add(cheque)
    print('generated')
    return bank


async def _migrate(bank: BankBlock, params) -> BankBlock:
    """Verify ``bank``, then rebuild it on top of the resolver from ``get_dr``.

    ``params['delay']`` and ``params['caching']`` are passed to ``get_dr``, whose
    ``migrate_resolved`` re-points ``bank.reference``. Presumably this adds
    artificial delay and optional caching to resolution (TODO confirm in
    ``trace_common``). Prints ``migrated`` and returns the new ``BankBlock``.
    """
    assert_true(await bank.verify())
    resolver = get_dr(params['delay'], params['caching'])
    reference = await resolver.migrate_resolved(bank.reference)
    migrated = BankBlock(reference)
    print('migrated')
    return migrated


async def _instrument(process: Callable[[], Coroutine[Any, Any, None]]) -> list[Instrumentation]:
    """Await ``process()`` while every instrumentation from ``get_instrumentations`` is entered.

    All instrumentations are entered on one ``ExitStack`` and are therefore
    exited together once ``process`` finishes. A ``Terminated`` raised by
    ``process`` ends the run without propagating. Afterwards the remaining
    ``Instrumentation.deinstrumentation`` state is printed, which is expected to
    be empty. Returns the instrumentations so their collected data can be
    serialized.
    """
    instrumentations: list[Instrumentation] = get_instrumentations()
    with ExitStack() as stack:
        for instrumentation in instrumentations:
            instrumentation.enter(stack)
        try:
            await process()
        except Terminated:
            # Terminated is treated as a normal way for the process to stop early.
            pass
    print('deinstrumentation (should be empty):', Instrumentation.deinstrumentation)
    print('instrumented')
    return instrumentations


async def _process(bank: BankBlock) -> None:
    """Verify ``bank`` under a ``Measure(bank, 'verify')`` context and print its timings.

    Every entry of the measurement's ``log`` is printed with three decimals,
    after the word ``measured``.
    """
    with Measure(bank, 'verify') as measure:
        verified = await bank.verify()
        assert_true(verified)
    timings = [f'{t:.3f}' for t in measure.log]
    print('measured', *timings)


async def _report(bank: BankBlock):
    """Write the ``ClassReport`` of ``bank``'s reference to ``trace/latest-report.txt``.

    The report is walked and formatted before the file is opened. A failure
    during the (potentially long) async walk therefore leaves any previous
    report intact instead of truncating it to an empty file. It also means the
    file handle is not held open across awaits. The ``trace`` directory is
    created if it is missing, and the file is written as UTF-8 rather than in
    the locale's default encoding.

    Prints ``reported`` after the file has been written.
    """
    import os

    report = ClassReport()
    await report.walk(bank.reference)
    formatted = report.format()
    os.makedirs('trace', exist_ok=True)
    with open('trace/latest-report.txt', 'w', encoding='utf-8') as file:
        file.write(formatted)
    print('reported')


async def _trace(params):
    """Generate a bank, report on it, verify it, migrate it, then verify it under instrumentation.

    ``params`` must provide ``'blocks'`` and the ``'subjects'`` / ``'transactions'``
    ``(min, max)`` pairs, which are unpacked into ``_generate``. It must also
    provide ``'delay'`` and ``'caching'``, which ``_migrate`` consumes.
    Generation and the first verification run after ``set_gather_linear()``. The
    instrumented verification of the migrated bank runs after
    ``set_gather_asyncio()``, inside ``DeintrumentationSize`` / ``Counter``
    contexts whose counter is printed as JSON.

    Returns the instrumentations collected by ``_instrument``.
    """
    set_gather_linear()
    generation_args = (
        params['blocks'],
        *params['subjects'],
        *params['transactions'],
    )
    bank = await _generate(*generation_args)
    await _report(bank)
    await _process(bank)
    bank = await _migrate(bank, params)
    set_gather_asyncio()
    with DeintrumentationSize(Instrumentation, 'deinstrument'), \
            Counter(DeintrumentationSize, 'instrument') as de_ctr:
        instrumentations = await _instrument(lambda: _process(bank))
        print(jsonify(de_ctr))
    print('traced')
    return instrumentations


async def trace(params):
    """Run ``_trace(params)``, then dump, copy and plot the collected data.

    The dumped JSON is the output of ``jsonify_list`` over the instrumentations,
    with the run's ``params`` added under the ``'params'`` key.
    """
    instrumentations = await _trace(params)
    filename = get_fn()
    dump(filename, jsonify_list(instrumentations) | {'params': params})
    copy(filename)
    plot(filename)
    print('plotted')


# Parameter presets for `trace`; `preset_short` differs from `preset_long` only in block count.
preset_long = dict(blocks=64, subjects=(4, 8), transactions=(8, 16), caching=True, delay=.5)
preset_short = preset_long | dict(blocks=16)

if __name__ == '__main__':
    # Seeds the module-level `random` generator that sizes blocks and transactions.
    # NOTE(review): key material comes from SigningKey.generate(), which this seed
    # presumably does not control, so generated hashes still differ between runs.
    random.seed(659918)
    try:
        asyncio.run(trace(preset_long))
    except KeyboardInterrupt:
        print('interrupted')