refactoring
This commit is contained in:
parent
079820c413
commit
2945b39e0a
41
app/cacheddictionary.py
Normal file
41
app/cacheddictionary.py
Normal file
@ -0,0 +1,41 @@
|
||||
__all__ = ("CachedEntry", "CachedDictionary")
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
TKey = TypeVar("TKey", bound=Hashable)
|
||||
|
||||
|
||||
class CachedEntry(Generic[T]):
|
||||
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
|
||||
self.__value: T = value
|
||||
self.__getter = getter
|
||||
self.__task: asyncio.Future[T] = asyncio.Future()
|
||||
self.__task.set_result(value)
|
||||
|
||||
async def _set(self) -> T:
|
||||
self.__value = await self.__getter()
|
||||
return self.__value
|
||||
|
||||
def get_nowait(self) -> T:
|
||||
if self.__task.done():
|
||||
self.__task = asyncio.create_task(self._set())
|
||||
return self.__value
|
||||
|
||||
async def get(self) -> T:
|
||||
if self.__task.done():
|
||||
self.__task = asyncio.create_task(self._set())
|
||||
return await self.__task
|
||||
|
||||
|
||||
class CachedDictionary(Generic[TKey, T]):
|
||||
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
|
||||
self.__factory = factory
|
||||
self.__entries: dict[TKey, CachedEntry[T]] = {}
|
||||
|
||||
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
|
||||
if key not in self.__entries:
|
||||
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
|
||||
return self.__entries[key]
|
68
app/main.py
68
app/main.py
@ -1,16 +1,17 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import sys
|
||||
import urllib.parse
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import nacl.hash
|
||||
from aiohttp import web
|
||||
|
||||
from ptvp35 import DbFactory, DbInterface, KVJson
|
||||
from ptvp35 import DbInterface, DbManager, KVJson
|
||||
|
||||
from .cacheddictionary import CachedDictionary
|
||||
from .token_client import token_client
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import Self
|
||||
@ -18,50 +19,6 @@ else:
|
||||
from typing import Self
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
TKey = TypeVar("TKey", bound=Hashable)
|
||||
|
||||
|
||||
class CachedEntry(Generic[T]):
|
||||
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
|
||||
self.__value: T = value
|
||||
self.__getter = getter
|
||||
self.__task: asyncio.Future[T] = asyncio.Future()
|
||||
self.__task.set_result(value)
|
||||
|
||||
async def _set(self) -> T:
|
||||
self.__value = await self.__getter()
|
||||
return self.__value
|
||||
|
||||
def get_nowait(self) -> T:
|
||||
if self.__task.done():
|
||||
self.__task = asyncio.create_task(self._set())
|
||||
return self.__value
|
||||
|
||||
async def get(self) -> T:
|
||||
if self.__task.done():
|
||||
self.__task = asyncio.create_task(self._set())
|
||||
return await self.__task
|
||||
|
||||
|
||||
async def _token_client(access_token: str) -> dict | None:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
|
||||
return await response.json()
|
||||
|
||||
|
||||
class CachedDictionary(Generic[TKey, T]):
|
||||
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
|
||||
self.__factory = factory
|
||||
self.__entries: dict[TKey, CachedEntry[T]] = {}
|
||||
|
||||
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
|
||||
if key not in self.__entries:
|
||||
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
|
||||
return self.__entries[key]
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None:
|
||||
self.db = db
|
||||
@ -93,10 +50,8 @@ class Context:
|
||||
match self.config("user-id"):
|
||||
case int() as result:
|
||||
return result
|
||||
case None:
|
||||
raise RuntimeError("user-id not set")
|
||||
case _:
|
||||
raise RuntimeError("inconsistent DB state")
|
||||
raise RuntimeError("user-id not set")
|
||||
|
||||
def secret(self) -> str:
|
||||
match self.config("secret"):
|
||||
@ -234,16 +189,13 @@ def bytes_hash(b: bytes) -> str:
|
||||
async def on_cleanup(app: web.Application) -> None:
|
||||
async with app["es"]:
|
||||
del app["db"]
|
||||
del app["tc"]
|
||||
del app["es"]
|
||||
|
||||
|
||||
routes = web.RouteTableDef()
|
||||
|
||||
|
||||
@routes.get("/")
|
||||
async def home(_request: web.Request) -> web.StreamResponse:
|
||||
return web.Response(text="sessionservice")
|
||||
|
||||
|
||||
@routes.get("/sessiondata/")
|
||||
async def sessiondata(request: web.Request) -> web.StreamResponse:
|
||||
session = request.query.get("session")
|
||||
@ -305,8 +257,8 @@ async def get_app() -> web.Application:
|
||||
app = web.Application()
|
||||
app.on_cleanup.append(on_cleanup)
|
||||
async with AsyncExitStack() as es:
|
||||
app["db"] = await es.enter_async_context(DbFactory(Path("/data/session.db"), kvfactory=KVJson()))
|
||||
app["tc"] = CachedDictionary(_token_client)
|
||||
app["db"] = await es.enter_async_context(DbManager(Path("/data/session.db"), kvfactory=KVJson()))
|
||||
app["tc"] = CachedDictionary(token_client)
|
||||
app["es"] = es.pop_all()
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
10
app/token_client.py
Normal file
10
app/token_client.py
Normal file
@ -0,0 +1,10 @@
|
||||
__all__ = ("token_client",)
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
async def token_client(access_token: str) -> dict | None:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
|
||||
return await response.json()
|
Loading…
Reference in New Issue
Block a user