From d6d3b873de4e67cf41c6ce22c99f962c66e59f91 Mon Sep 17 00:00:00 2001 From: timofey Date: Sat, 24 Dec 2022 07:32:49 +0000 Subject: [PATCH] at-of isolation --- v6d2ctx/at_of.py | 25 +++++++++++++++++++++++++ v6d2ctx/context.py | 20 ++------------------ v6d2ctx/handle_args.py | 8 +++++--- v6d2ctx/handle_command.py | 8 +++++--- v6d2ctx/handle_content.py | 7 +++++-- 5 files changed, 42 insertions(+), 26 deletions(-) create mode 100644 v6d2ctx/at_of.py diff --git a/v6d2ctx/at_of.py b/v6d2ctx/at_of.py new file mode 100644 index 0000000..c29339d --- /dev/null +++ b/v6d2ctx/at_of.py @@ -0,0 +1,25 @@ +from typing import Callable, Generic, TypeVar + +from .context import Implicit + +K = TypeVar('K') +V = TypeVar('V') + + +class AtOf(Generic[K, V]): + def __call__(self) -> tuple[Callable[[K], Callable[[V], V]], Callable[[K], V]]: + bucket: dict[K, V] = {} + + def at(key: K) -> Callable[[V], V]: + def wrap(value: V) -> V: + bucket[key] = value + return value + return wrap + + def of(key: K) -> V: + try: + return bucket[key] + except IndexError: + raise Implicit + + return at, of diff --git a/v6d2ctx/context.py b/v6d2ctx/context.py index 47d1fca..8f70d2a 100644 --- a/v6d2ctx/context.py +++ b/v6d2ctx/context.py @@ -1,7 +1,7 @@ import asyncio import time from io import StringIO -from typing import Union, Optional, Callable, Awaitable +from typing import Awaitable, Callable, Optional, Union import discord @@ -47,16 +47,7 @@ def escape(s: str): return res.getvalue() -buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {} - - -def at(bucket: str, name: str): - def wrap(f: Callable[[Context, list[str]], Awaitable[None]]): - buckets.setdefault(bucket, {})[name] = f - - return f - - return wrap +command_type = Callable[[Context, list[str]], Awaitable[None]] class Explicit(Exception): @@ -68,13 +59,6 @@ class Implicit(Exception): pass -def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]: - try: - return buckets[bucket][name] - except KeyError: - raise Implicit - - benchmarks: dict[str, dict[str, float]] = {} _t = time.perf_counter() diff --git a/v6d2ctx/handle_args.py b/v6d2ctx/handle_args.py index 8190504..0ad589c 100644 --- a/v6d2ctx/handle_args.py +++ b/v6d2ctx/handle_args.py @@ -1,17 +1,19 @@ +from typing import Callable + import discord -from v6d2ctx.context import Context, Implicit, Explicit +from v6d2ctx.context import Context, Explicit, Implicit, command_type from v6d2ctx.handle_command import handle_command -async def handle_args(message: discord.Message, args: list[str], client: discord.Client): +async def handle_args(of: Callable[[str], command_type], message: discord.Message, args: list[str], client: discord.Client): match args: case []: return case [command_name, *command_args]: ctx = Context(message, client) try: - await handle_command(ctx, command_name, command_args) + await handle_command(of, ctx, command_name, command_args) except Implicit: pass except Explicit as e: diff --git a/v6d2ctx/handle_command.py b/v6d2ctx/handle_command.py index 21263ed..9071e6b 100644 --- a/v6d2ctx/handle_command.py +++ b/v6d2ctx/handle_command.py @@ -1,5 +1,7 @@ -from v6d2ctx.context import Context, of +from typing import Awaitable, Callable + +from v6d2ctx.context import Context, command_type -async def handle_command(ctx: Context, name: str, args: list[str]) -> None: - await of('commands', name)(ctx, args) +async def handle_command(of: Callable[[str], command_type], ctx: Context, name: str, args: list[str]) -> None: + await of(name)(ctx, args) diff --git a/v6d2ctx/handle_content.py b/v6d2ctx/handle_content.py index 962e9c2..726dcf4 100644 --- a/v6d2ctx/handle_content.py +++ b/v6d2ctx/handle_content.py @@ -1,15 +1,18 @@ import shlex +from typing import Callable import discord from v6d2ctx.handle_args import handle_args +from v6d2ctx.context import command_type -async def handle_content(message: discord.Message, content: str, prefix: str, client: discord.Client): + +async def handle_content(of: Callable[[str], command_type], message: discord.Message, content: str, prefix: str, client: discord.Client): if message.author.bot: return if not content.startswith(prefix): return content = content.removeprefix(prefix) args = shlex.split(content) - await handle_args(message, args, client) + await handle_args(of, message, args, client)