diff --git a/main.py b/main.py index 9a285be..4e8fa71 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,36 @@ -import unittest +from rainbowadn.testing.instrument import * -if __name__ == '__main__': - unittest.main('rainbowadn.testing.test_all') + +class Print(Instrumentation): + def __init__(self, target, methodname: str, msg: str): + super().__init__(target, methodname) + self.msg = msg + + def instrument(self, method, *args, **kwargs): + print(self.msg, end=' ') + return method(*args, **kwargs) + + +class C: + @classmethod + def m(cls): + print('m') + + +with Print(Instrumentation, '__exit__', 'exit'): + print1 = Print(C, 'm', '1') + print2 = Print(C, 'm', '2') + print3 = Print(C, 'm', '3') + C.m() + print1.__enter__() + C.m() + print2.__enter__() + C.m() + print3.__enter__() + C.m() + print1.__exit__(None, None, None) + C.m() + print2.__exit__(None, None, None) + C.m() + print3.__exit__(None, None, None) + C.m() diff --git a/rainbowadn/testing/instrument/instrumentation.py b/rainbowadn/testing/instrument/instrumentation.py index 4e5d301..e5f9c27 100644 --- a/rainbowadn/testing/instrument/instrumentation.py +++ b/rainbowadn/testing/instrument/instrumentation.py @@ -22,13 +22,18 @@ class Instrumentation(Generic[IType]): raise NotImplementedError def __enter__(self: IType) -> IType: - assert not hasattr(self, 'method') + self: Instrumentation + assert not hasattr(self, '_method') + assert not hasattr(self, '_wrap') method = getattr(self.target, self.methodname) assert callable(method) self._method = method @functools.wraps(method) def wrap(*args, **kwargs): + nonlocal method + while method in self.deinstrumentation: + self._method = method = self.deinstrumentation.pop(method) return self.instrument(method, *args, **kwargs) self._wrap = wrap diff --git a/trace.py b/trace.py index bbe9847..52ec884 100644 --- a/trace.py +++ b/trace.py @@ -139,7 +139,10 @@ async def _instrument(bank: BankChain) -> list[Instrumentation]: class DeintrumentationSize(Instrumentation): def instrument(self, method, *args, **kwargs): - print('deinstrumentation size', len(self.deinstrumentation)) + print( + f'deinstrumentation size @ {target_str(self.target)}:{self.methodname}', + len(self.deinstrumentation) + ) return method(*args, **kwargs) @@ -152,7 +155,10 @@ async def _trace(): ) bank = await _migrate(bank) set_gather_asyncio() - instrumentations = await _instrument(bank) + with DeintrumentationSize(Instrumentation, 'deinstrument'): + with Counter(DeintrumentationSize, 'instrument') as de_ctr: + instrumentations = await _instrument(bank) + print(jsonify(de_ctr)) print('traced') return instrumentations