From 2945b39e0a237418c5665fcc42aa442edac00781 Mon Sep 17 00:00:00 2001 From: timofey Date: Wed, 1 Mar 2023 19:01:35 +0000 Subject: [PATCH] refactoring --- app/cacheddictionary.py | 41 +++++++++++++++++++++++++ app/main.py | 68 ++++++----------------------------------- app/token_client.py | 10 ++++++ 3 files changed, 61 insertions(+), 58 deletions(-) create mode 100644 app/cacheddictionary.py create mode 100644 app/token_client.py diff --git a/app/cacheddictionary.py b/app/cacheddictionary.py new file mode 100644 index 0000000..a646243 --- /dev/null +++ b/app/cacheddictionary.py @@ -0,0 +1,41 @@ +__all__ = ("CachedEntry", "CachedDictionary") + +import asyncio +import functools +from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar + +T = TypeVar("T") +TKey = TypeVar("TKey", bound=Hashable) + + +class CachedEntry(Generic[T]): + def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None: + self.__value: T = value + self.__getter = getter + self.__task: asyncio.Future[T] = asyncio.Future() + self.__task.set_result(value) + + async def _set(self) -> T: + self.__value = await self.__getter() + return self.__value + + def get_nowait(self) -> T: + if self.__task.done(): + self.__task = asyncio.create_task(self._set()) + return self.__value + + async def get(self) -> T: + if self.__task.done(): + self.__task = asyncio.create_task(self._set()) + return await self.__task + + +class CachedDictionary(Generic[TKey, T]): + def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None: + self.__factory = factory + self.__entries: dict[TKey, CachedEntry[T]] = {} + + def entry(self, key: TKey, default: T) -> CachedEntry[T]: + if key not in self.__entries: + self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key)) + return self.__entries[key] diff --git a/app/main.py b/app/main.py index 194e295..e63257a 100644 --- a/app/main.py +++ b/app/main.py @@ -1,16 +1,17 @@ -import asyncio -import functools import sys import urllib.parse from contextlib import AsyncExitStack from pathlib import Path -from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar +from typing import Any import aiohttp import nacl.hash from aiohttp import web -from ptvp35 import DbFactory, DbInterface, KVJson +from ptvp35 import DbInterface, DbManager, KVJson + +from .cacheddictionary import CachedDictionary +from .token_client import token_client if sys.version_info < (3, 11): from typing_extensions import Self @@ -18,50 +19,6 @@ else: from typing import Self -T = TypeVar("T") -TKey = TypeVar("TKey", bound=Hashable) - - -class CachedEntry(Generic[T]): - def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None: - self.__value: T = value - self.__getter = getter - self.__task: asyncio.Future[T] = asyncio.Future() - self.__task.set_result(value) - - async def _set(self) -> T: - self.__value = await self.__getter() - return self.__value - - def get_nowait(self) -> T: - if self.__task.done(): - self.__task = asyncio.create_task(self._set()) - return self.__value - - async def get(self) -> T: - if self.__task.done(): - self.__task = asyncio.create_task(self._set()) - return await self.__task - - -async def _token_client(access_token: str) -> dict | None: - headers = {"Authorization": f"Bearer {access_token}"} - async with aiohttp.ClientSession() as session: - async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response: - return await response.json() - - -class CachedDictionary(Generic[TKey, T]): - def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None: - self.__factory = factory - self.__entries: dict[TKey, CachedEntry[T]] = {} - - def entry(self, key: TKey, default: T) -> CachedEntry[T]: - if key not in self.__entries: - self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key)) - return self.__entries[key] - - class Context: def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None: self.db = db @@ -93,10 +50,8 @@ class Context: match self.config("user-id"): case int() as result: return result - case None: - raise RuntimeError("user-id not set") case _: - raise RuntimeError("inconsistent DB state") + raise RuntimeError("user-id not set") def secret(self) -> str: match self.config("secret"): @@ -234,16 +189,13 @@ def bytes_hash(b: bytes) -> str: async def on_cleanup(app: web.Application) -> None: async with app["es"]: del app["db"] + del app["tc"] + del app["es"] routes = web.RouteTableDef() -@routes.get("/") -async def home(_request: web.Request) -> web.StreamResponse: - return web.Response(text="sessionservice") - - @routes.get("/sessiondata/") async def sessiondata(request: web.Request) -> web.StreamResponse: session = request.query.get("session") @@ -305,8 +257,8 @@ async def get_app() -> web.Application: app = web.Application() app.on_cleanup.append(on_cleanup) async with AsyncExitStack() as es: - app["db"] = await es.enter_async_context(DbFactory(Path("/data/session.db"), kvfactory=KVJson())) - app["tc"] = CachedDictionary(_token_client) + app["db"] = await es.enter_async_context(DbManager(Path("/data/session.db"), kvfactory=KVJson())) + app["tc"] = CachedDictionary(token_client) app["es"] = es.pop_all() app.add_routes(routes) return app diff --git a/app/token_client.py b/app/token_client.py new file mode 100644 index 0000000..a15a458 --- /dev/null +++ b/app/token_client.py @@ -0,0 +1,10 @@ +__all__ = ("token_client",) + +import aiohttp + + +async def token_client(access_token: str) -> dict | None: + headers = {"Authorization": f"Bearer {access_token}"} + async with aiohttp.ClientSession() as session: + async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response: + return await response.json()