refactoring

This commit is contained in:
AF 2023-03-01 19:01:35 +00:00
parent 079820c413
commit 2945b39e0a
3 changed files with 61 additions and 58 deletions

41
app/cacheddictionary.py Normal file
View File

@ -0,0 +1,41 @@
__all__ = ("CachedEntry", "CachedDictionary")
import asyncio
import functools
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
T = TypeVar("T")
TKey = TypeVar("TKey", bound=Hashable)
class CachedEntry(Generic[T]):
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
self.__value: T = value
self.__getter = getter
self.__task: asyncio.Future[T] = asyncio.Future()
self.__task.set_result(value)
async def _set(self) -> T:
self.__value = await self.__getter()
return self.__value
def get_nowait(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return self.__value
async def get(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return await self.__task
class CachedDictionary(Generic[TKey, T]):
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
self.__factory = factory
self.__entries: dict[TKey, CachedEntry[T]] = {}
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
if key not in self.__entries:
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
return self.__entries[key]

View File

@ -1,16 +1,17 @@
import asyncio
import functools
import sys import sys
import urllib.parse import urllib.parse
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar from typing import Any
import aiohttp import aiohttp
import nacl.hash import nacl.hash
from aiohttp import web from aiohttp import web
from ptvp35 import DbFactory, DbInterface, KVJson from ptvp35 import DbInterface, DbManager, KVJson
from .cacheddictionary import CachedDictionary
from .token_client import token_client
if sys.version_info < (3, 11): if sys.version_info < (3, 11):
from typing_extensions import Self from typing_extensions import Self
@ -18,50 +19,6 @@ else:
from typing import Self from typing import Self
T = TypeVar("T")
TKey = TypeVar("TKey", bound=Hashable)
class CachedEntry(Generic[T]):
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
self.__value: T = value
self.__getter = getter
self.__task: asyncio.Future[T] = asyncio.Future()
self.__task.set_result(value)
async def _set(self) -> T:
self.__value = await self.__getter()
return self.__value
def get_nowait(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return self.__value
async def get(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return await self.__task
async def _token_client(access_token: str) -> dict | None:
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
return await response.json()
class CachedDictionary(Generic[TKey, T]):
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
self.__factory = factory
self.__entries: dict[TKey, CachedEntry[T]] = {}
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
if key not in self.__entries:
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
return self.__entries[key]
class Context: class Context:
def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None: def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None:
self.db = db self.db = db
@ -93,10 +50,8 @@ class Context:
match self.config("user-id"): match self.config("user-id"):
case int() as result: case int() as result:
return result return result
case None:
raise RuntimeError("user-id not set")
case _: case _:
raise RuntimeError("inconsistent DB state") raise RuntimeError("user-id not set")
def secret(self) -> str: def secret(self) -> str:
match self.config("secret"): match self.config("secret"):
@ -234,16 +189,13 @@ def bytes_hash(b: bytes) -> str:
async def on_cleanup(app: web.Application) -> None: async def on_cleanup(app: web.Application) -> None:
async with app["es"]: async with app["es"]:
del app["db"] del app["db"]
del app["tc"]
del app["es"]
routes = web.RouteTableDef() routes = web.RouteTableDef()
@routes.get("/")
async def home(_request: web.Request) -> web.StreamResponse:
return web.Response(text="sessionservice")
@routes.get("/sessiondata/") @routes.get("/sessiondata/")
async def sessiondata(request: web.Request) -> web.StreamResponse: async def sessiondata(request: web.Request) -> web.StreamResponse:
session = request.query.get("session") session = request.query.get("session")
@ -305,8 +257,8 @@ async def get_app() -> web.Application:
app = web.Application() app = web.Application()
app.on_cleanup.append(on_cleanup) app.on_cleanup.append(on_cleanup)
async with AsyncExitStack() as es: async with AsyncExitStack() as es:
app["db"] = await es.enter_async_context(DbFactory(Path("/data/session.db"), kvfactory=KVJson())) app["db"] = await es.enter_async_context(DbManager(Path("/data/session.db"), kvfactory=KVJson()))
app["tc"] = CachedDictionary(_token_client) app["tc"] = CachedDictionary(token_client)
app["es"] = es.pop_all() app["es"] = es.pop_all()
app.add_routes(routes) app.add_routes(routes)
return app return app

10
app/token_client.py Normal file
View File

@ -0,0 +1,10 @@
__all__ = ("token_client",)
import aiohttp
async def token_client(access_token: str) -> dict | None:
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
return await response.json()