refactoring
This commit is contained in:
parent
079820c413
commit
2945b39e0a
41
app/cacheddictionary.py
Normal file
41
app/cacheddictionary.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
__all__ = ("CachedEntry", "CachedDictionary")
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
TKey = TypeVar("TKey", bound=Hashable)
|
||||||
|
|
||||||
|
|
||||||
|
class CachedEntry(Generic[T]):
|
||||||
|
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
|
||||||
|
self.__value: T = value
|
||||||
|
self.__getter = getter
|
||||||
|
self.__task: asyncio.Future[T] = asyncio.Future()
|
||||||
|
self.__task.set_result(value)
|
||||||
|
|
||||||
|
async def _set(self) -> T:
|
||||||
|
self.__value = await self.__getter()
|
||||||
|
return self.__value
|
||||||
|
|
||||||
|
def get_nowait(self) -> T:
|
||||||
|
if self.__task.done():
|
||||||
|
self.__task = asyncio.create_task(self._set())
|
||||||
|
return self.__value
|
||||||
|
|
||||||
|
async def get(self) -> T:
|
||||||
|
if self.__task.done():
|
||||||
|
self.__task = asyncio.create_task(self._set())
|
||||||
|
return await self.__task
|
||||||
|
|
||||||
|
|
||||||
|
class CachedDictionary(Generic[TKey, T]):
|
||||||
|
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
|
||||||
|
self.__factory = factory
|
||||||
|
self.__entries: dict[TKey, CachedEntry[T]] = {}
|
||||||
|
|
||||||
|
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
|
||||||
|
if key not in self.__entries:
|
||||||
|
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
|
||||||
|
return self.__entries[key]
|
68
app/main.py
68
app/main.py
@ -1,16 +1,17 @@
|
|||||||
import asyncio
|
|
||||||
import functools
|
|
||||||
import sys
|
import sys
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import nacl.hash
|
import nacl.hash
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from ptvp35 import DbFactory, DbInterface, KVJson
|
from ptvp35 import DbInterface, DbManager, KVJson
|
||||||
|
|
||||||
|
from .cacheddictionary import CachedDictionary
|
||||||
|
from .token_client import token_client
|
||||||
|
|
||||||
if sys.version_info < (3, 11):
|
if sys.version_info < (3, 11):
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@ -18,50 +19,6 @@ else:
|
|||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
TKey = TypeVar("TKey", bound=Hashable)
|
|
||||||
|
|
||||||
|
|
||||||
class CachedEntry(Generic[T]):
|
|
||||||
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
|
|
||||||
self.__value: T = value
|
|
||||||
self.__getter = getter
|
|
||||||
self.__task: asyncio.Future[T] = asyncio.Future()
|
|
||||||
self.__task.set_result(value)
|
|
||||||
|
|
||||||
async def _set(self) -> T:
|
|
||||||
self.__value = await self.__getter()
|
|
||||||
return self.__value
|
|
||||||
|
|
||||||
def get_nowait(self) -> T:
|
|
||||||
if self.__task.done():
|
|
||||||
self.__task = asyncio.create_task(self._set())
|
|
||||||
return self.__value
|
|
||||||
|
|
||||||
async def get(self) -> T:
|
|
||||||
if self.__task.done():
|
|
||||||
self.__task = asyncio.create_task(self._set())
|
|
||||||
return await self.__task
|
|
||||||
|
|
||||||
|
|
||||||
async def _token_client(access_token: str) -> dict | None:
|
|
||||||
headers = {"Authorization": f"Bearer {access_token}"}
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
|
|
||||||
return await response.json()
|
|
||||||
|
|
||||||
|
|
||||||
class CachedDictionary(Generic[TKey, T]):
|
|
||||||
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
|
|
||||||
self.__factory = factory
|
|
||||||
self.__entries: dict[TKey, CachedEntry[T]] = {}
|
|
||||||
|
|
||||||
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
|
|
||||||
if key not in self.__entries:
|
|
||||||
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
|
|
||||||
return self.__entries[key]
|
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None:
|
def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None:
|
||||||
self.db = db
|
self.db = db
|
||||||
@ -93,10 +50,8 @@ class Context:
|
|||||||
match self.config("user-id"):
|
match self.config("user-id"):
|
||||||
case int() as result:
|
case int() as result:
|
||||||
return result
|
return result
|
||||||
case None:
|
|
||||||
raise RuntimeError("user-id not set")
|
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("inconsistent DB state")
|
raise RuntimeError("user-id not set")
|
||||||
|
|
||||||
def secret(self) -> str:
|
def secret(self) -> str:
|
||||||
match self.config("secret"):
|
match self.config("secret"):
|
||||||
@ -234,16 +189,13 @@ def bytes_hash(b: bytes) -> str:
|
|||||||
async def on_cleanup(app: web.Application) -> None:
|
async def on_cleanup(app: web.Application) -> None:
|
||||||
async with app["es"]:
|
async with app["es"]:
|
||||||
del app["db"]
|
del app["db"]
|
||||||
|
del app["tc"]
|
||||||
|
del app["es"]
|
||||||
|
|
||||||
|
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/")
|
|
||||||
async def home(_request: web.Request) -> web.StreamResponse:
|
|
||||||
return web.Response(text="sessionservice")
|
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/sessiondata/")
|
@routes.get("/sessiondata/")
|
||||||
async def sessiondata(request: web.Request) -> web.StreamResponse:
|
async def sessiondata(request: web.Request) -> web.StreamResponse:
|
||||||
session = request.query.get("session")
|
session = request.query.get("session")
|
||||||
@ -305,8 +257,8 @@ async def get_app() -> web.Application:
|
|||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.on_cleanup.append(on_cleanup)
|
app.on_cleanup.append(on_cleanup)
|
||||||
async with AsyncExitStack() as es:
|
async with AsyncExitStack() as es:
|
||||||
app["db"] = await es.enter_async_context(DbFactory(Path("/data/session.db"), kvfactory=KVJson()))
|
app["db"] = await es.enter_async_context(DbManager(Path("/data/session.db"), kvfactory=KVJson()))
|
||||||
app["tc"] = CachedDictionary(_token_client)
|
app["tc"] = CachedDictionary(token_client)
|
||||||
app["es"] = es.pop_all()
|
app["es"] = es.pop_all()
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
return app
|
return app
|
||||||
|
10
app/token_client.py
Normal file
10
app/token_client.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
__all__ = ("token_client",)
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
async def token_client(access_token: str) -> dict | None:
|
||||||
|
headers = {"Authorization": f"Bearer {access_token}"}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response:
|
||||||
|
return await response.json()
|
Loading…
Reference in New Issue
Block a user