import sys import urllib.parse from contextlib import AsyncExitStack from pathlib import Path from typing import Any import aiohttp import nacl.hash from aiohttp import web from ptvp35 import DbInterface, DbManager, KVJson from .cacheddictionary import CachedDictionary from .token_client import token_client if sys.version_info < (3, 11): from typing_extensions import Self else: from typing import Self class Context: def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None: self.db = db self.token_clients = token_clients @classmethod def from_app(cls, app: web.Application) -> Self: return cls(app["db"], app["tc"]) @classmethod def from_request(cls, request: web.Request) -> Self: return cls.from_app(request.app) async def session_data(self, session: str) -> dict: if session is None: return {} data = self.db.get(session, {}) if not isinstance(data, dict): return {} return data def config(self, key: str, default: Any = None) -> Any: return self.db.get(("config", key), default) async def set_config(self, key: str, value: Any) -> None: await self.db.set(("config", key), value) def app_user_id(self) -> int: match self.config("user-id"): case int() as result: return result case _: raise RuntimeError("user-id not set") def secret(self) -> str: match self.config("secret"): case str() as result: return result case _: raise RuntimeError("secret not set") def redirect(self) -> str: match self.config("redirect"): case str() as result: return result case _: raise RuntimeError("redirect not set") def ready(self) -> bool: match self.config("ready", False): case bool() as result: return result case _: raise RuntimeError("inconsistent DB state") async def code_token(self, code: str) -> dict: if not self.ready(): raise RuntimeError("not ready") client_id = self.app_user_id() data = { "client_id": str(client_id), "client_secret": self.secret(), "grant_type": "authorization_code", "code": code, "redirect_uri": self.redirect(), } headers = {"Content-Type": "application/x-www-form-urlencoded"} async with aiohttp.ClientSession() as session: async with session.post("https://discord.com/api/oauth2/token", data=data, headers=headers) as response: return await response.json() async def auth(self, session: str, code: str) -> None: data = await self.session_data(session) data["code"] = code data["token"] = await self.code_token(code) await self.db.set(session, data) def auth_link(self) -> str: if not self.ready(): return "" client_id = self.app_user_id() return ( f"https://discord.com/api/oauth2/authorize?client_id={client_id}" f"&redirect_uri={urllib.parse.quote(self.redirect())}&response_type=code&scope=identify" ) async def token_client(self, access_token: str) -> dict | None: return self.token_clients.entry(access_token, None).get_nowait() async def session_client(self, data: dict) -> dict | None: match data: case {"token": {"access_token": str() as access_token}}: pass case _: return None return await self.token_client(access_token) @classmethod def client_user(cls, sclient: dict) -> dict | None: return sclient.get("user") @classmethod def user_username_full(cls, user: dict) -> str | None: match user: case {"username": str() as username, "discriminator": str() as discriminator}: return f"{username}#{discriminator}" case _: return None @classmethod def user_id(cls, user: dict) -> str | int | None: return user.get("id") @classmethod def user_avatar(cls, user: dict) -> str | None: return user.get("avatar") @classmethod def user_avatar_url(cls, user: dict) -> str | None: cid = cls.user_id(user) if cid is None: return None avatar = cls.user_avatar(user) if avatar is None: return None return f"https://cdn.discordapp.com/avatars/{cid}/{avatar}.png" @classmethod def user_status(cls, user: dict) -> dict: return {"avatar": cls.user_avatar_url(user), "id": cls.user_id(user), "username": cls.user_username_full(user)} @classmethod def client_status(cls, sclient: dict) -> dict: user = cls.client_user(sclient) return { "expires": sclient.get("expires"), "user": (None if user is None else cls.user_status(user)), } async def session_status(self, session: str) -> dict: data = await self.session_data(session) sclient = await self.session_client(data) return { "code_set": data.get("code") is not None, "token_set": data.get("token") is not None, "client": (None if sclient is None else self.client_status(sclient)), } @classmethod async def user_id_of(cls, request: web.Request) -> int | None: session = request.query.get("session") if session is None: return None context = cls.from_request(request) data = await context.session_data(session) sclient = await context.session_client(data) if sclient is None: return None user = context.client_user(sclient) if user is None: return None user_id = context.user_id(user) if user_id is None: return None return int(user_id) def bytes_hash(b: bytes) -> str: return nacl.hash.sha256(b).decode() async def on_cleanup(app: web.Application) -> None: async with app["es"]: del app["db"] del app["tc"] del app["es"] routes = web.RouteTableDef() @routes.get("/sessiondata/") async def sessiondata(request: web.Request) -> web.StreamResponse: session = request.query.get("session") if session is None: raise web.HTTPBadRequest return web.json_response(await Context.from_request(request).session_data(session)) @routes.get("/auth/") async def auth(request: web.Request) -> web.StreamResponse: session = request.query.get("session") state = request.query.get("state") code = request.query.get("code") match session, state, code: case str() as session, str() as state, str() as code: if bytes_hash(session.encode()) != state: raise web.HTTPBadRequest context = Context.from_request(request) await context.auth(session, code) raise web.HTTPOk case _: raise web.HTTPBadRequest @routes.post("/config/") async def config(request: web.Request) -> web.StreamResponse: match await request.json(): case {"key": str() as key, "value": _ as value}: context = Context.from_request(request) await context.set_config(key, value) raise web.HTTPOk case _: raise web.HTTPBadRequest @routes.get("/authlink/") async def authlink(request: web.Request) -> web.StreamResponse: return web.Response(text=Context.from_request(request).auth_link()) @routes.get("/state/") async def get_state(request: web.Request) -> web.Response: session = str(request.query.get("session")) return web.json_response(data=f"{bytes_hash(session.encode())}") @routes.get("/status/") async def status(request: web.Request) -> web.Response: session = str(request.query.get("session")) return web.json_response(data=await Context.from_request(request).session_status(session)) @routes.get("/userid/") async def user_id_of(request: web.Request) -> web.Response: return web.json_response(data=await Context.user_id_of(request)) async def get_app() -> web.Application: app = web.Application() app.on_cleanup.append(on_cleanup) async with AsyncExitStack() as es: app["db"] = await es.enter_async_context(DbManager(Path("/data/session.db"), kvfactory=KVJson())) app["tc"] = CachedDictionary(token_client) app["es"] = es.pop_all() app.add_routes(routes) return app