refactoring

This commit is contained in:
AF 2023-03-01 19:01:35 +00:00
parent 079820c413
commit 2945b39e0a
3 changed files with 61 additions and 58 deletions

41
app/cacheddictionary.py Normal file
View File

@ -0,0 +1,41 @@
__all__ = ("CachedEntry", "CachedDictionary")
import asyncio
import functools
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
T = TypeVar("T")
TKey = TypeVar("TKey", bound=Hashable)
class CachedEntry(Generic[T]):
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
self.__value: T = value
self.__getter = getter
self.__task: asyncio.Future[T] = asyncio.Future()
self.__task.set_result(value)
async def _set(self) -> T:
self.__value = await self.__getter()
return self.__value
def get_nowait(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return self.__value
async def get(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return await self.__task
class CachedDictionary(Generic[TKey, T]):
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
self.__factory = factory
self.__entries: dict[TKey, CachedEntry[T]] = {}
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
if key not in self.__entries:
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
return self.__entries[key]

View File

@ -1,16 +1,17 @@
import asyncio
import functools
import sys
import urllib.parse
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
from typing import Any
import aiohttp
import nacl.hash
from aiohttp import web
from ptvp35 import DbFactory, DbInterface, KVJson
from ptvp35 import DbInterface, DbManager, KVJson
from .cacheddictionary import CachedDictionary
from .token_client import token_client
if sys.version_info < (3, 11):
from typing_extensions import Self
@ -18,50 +19,6 @@ else:
from typing import Self
T = TypeVar("T")
TKey = TypeVar("TKey", bound=Hashable)
class CachedEntry(Generic[T]):
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
self.__value: T = value
self.__getter = getter
self.__task: asyncio.Future[T] = asyncio.Future()
self.__task.set_result(value)
async def _set(self) -> T:
self.__value = await self.__getter()
return self.__value
def get_nowait(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return self.__value
async def get(self) -> T:
if self.__task.done():
self.__task = asyncio.create_task(self._set())
return await self.__task
async def _token_client(access_token: str) -> dict | None:
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
return await response.json()
class CachedDictionary(Generic[TKey, T]):
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
self.__factory = factory
self.__entries: dict[TKey, CachedEntry[T]] = {}
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
if key not in self.__entries:
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
return self.__entries[key]
class Context:
def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None:
self.db = db
@ -93,10 +50,8 @@ class Context:
match self.config("user-id"):
case int() as result:
return result
case None:
raise RuntimeError("user-id not set")
case _:
raise RuntimeError("inconsistent DB state")
raise RuntimeError("user-id not set")
def secret(self) -> str:
match self.config("secret"):
@ -234,16 +189,13 @@ def bytes_hash(b: bytes) -> str:
async def on_cleanup(app: web.Application) -> None:
async with app["es"]:
del app["db"]
del app["tc"]
del app["es"]
routes = web.RouteTableDef()
@routes.get("/")
async def home(_request: web.Request) -> web.StreamResponse:
return web.Response(text="sessionservice")
@routes.get("/sessiondata/")
async def sessiondata(request: web.Request) -> web.StreamResponse:
session = request.query.get("session")
@ -305,8 +257,8 @@ async def get_app() -> web.Application:
app = web.Application()
app.on_cleanup.append(on_cleanup)
async with AsyncExitStack() as es:
app["db"] = await es.enter_async_context(DbFactory(Path("/data/session.db"), kvfactory=KVJson()))
app["tc"] = CachedDictionary(_token_client)
app["db"] = await es.enter_async_context(DbManager(Path("/data/session.db"), kvfactory=KVJson()))
app["tc"] = CachedDictionary(token_client)
app["es"] = es.pop_all()
app.add_routes(routes)
return app

10
app/token_client.py Normal file
View File

@ -0,0 +1,10 @@
__all__ = ("token_client",)
import aiohttp
async def token_client(access_token: str) -> dict | None:
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
return await response.json()