sessionservice/app/main.py
2023-03-31 11:28:01 +00:00

269 lines
8.5 KiB
Python

import sys
import urllib.parse
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Any
import aiohttp
import nacl.hash
from aiohttp import web
from ptvp35 import DbInterface, DbManager, KVJson
from .cacheddictionary import CachedDictionary
from .token_client import token_client
if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self
class Context:
def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None:
self.db = db
self.token_clients = token_clients
@classmethod
def from_app(cls, app: web.Application) -> Self:
return cls(app["db"], app["tc"])
@classmethod
def from_request(cls, request: web.Request) -> Self:
return cls.from_app(request.app)
async def session_data(self, session: str) -> dict:
if session is None:
return {}
data = self.db.get(session, {})
if not isinstance(data, dict):
return {}
return data
def config(self, key: str, default: Any = None) -> Any:
return self.db.get(("config", key), default)
async def set_config(self, key: str, value: Any) -> None:
await self.db.set(("config", key), value)
def app_user_id(self) -> int:
match self.config("user-id"):
case int() as result:
return result
case _:
raise RuntimeError("user-id not set")
def secret(self) -> str:
match self.config("secret"):
case str() as result:
return result
case _:
raise RuntimeError("secret not set")
def redirect(self) -> str:
match self.config("redirect"):
case str() as result:
return result
case _:
raise RuntimeError("redirect not set")
def ready(self) -> bool:
match self.config("ready", False):
case bool() as result:
return result
case _:
raise RuntimeError("inconsistent DB state")
async def code_token(self, code: str) -> dict:
if not self.ready():
raise RuntimeError("not ready")
client_id = self.app_user_id()
data = {
"client_id": str(client_id),
"client_secret": self.secret(),
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect(),
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
async with aiohttp.ClientSession() as session:
async with session.post("https://discord.com/api/oauth2/token", data=data, headers=headers) as response:
return await response.json()
async def auth(self, session: str, code: str) -> None:
data = await self.session_data(session)
data["code"] = code
data["token"] = await self.code_token(code)
await self.db.set(session, data)
def auth_link(self) -> str:
if not self.ready():
return ""
client_id = self.app_user_id()
return (
f"https://discord.com/api/oauth2/authorize?client_id={client_id}"
f"&redirect_uri={urllib.parse.quote(self.redirect())}&response_type=code&scope=identify"
)
async def token_client(self, access_token: str) -> dict | None:
return self.token_clients.entry(access_token, None).get_nowait()
async def session_client(self, data: dict) -> dict | None:
match data:
case {"token": {"access_token": str() as access_token}}:
pass
case _:
return None
return await self.token_client(access_token)
@classmethod
def client_user(cls, sclient: dict) -> dict | None:
return sclient.get("user")
@classmethod
def user_username_full(cls, user: dict) -> str | None:
match user:
case {"username": str() as username, "discriminator": str() as discriminator}:
return f"{username}#{discriminator}"
case _:
return None
@classmethod
def user_id(cls, user: dict) -> str | int | None:
return user.get("id")
@classmethod
def user_avatar(cls, user: dict) -> str | None:
return user.get("avatar")
@classmethod
def user_avatar_url(cls, user: dict) -> str | None:
cid = cls.user_id(user)
if cid is None:
return None
avatar = cls.user_avatar(user)
if avatar is None:
return None
return f"https://cdn.discordapp.com/avatars/{cid}/{avatar}.png"
@classmethod
def user_status(cls, user: dict) -> dict:
return {"avatar": cls.user_avatar_url(user), "id": cls.user_id(user), "username": cls.user_username_full(user)}
@classmethod
def client_status(cls, sclient: dict) -> dict:
user = cls.client_user(sclient)
return {
"expires": sclient.get("expires"),
"user": (None if user is None else cls.user_status(user)),
}
async def session_status(self, session: str) -> dict:
data = await self.session_data(session)
sclient = await self.session_client(data)
return {
"code_set": data.get("code") is not None,
"token_set": data.get("token") is not None,
"client": (None if sclient is None else self.client_status(sclient)),
}
@classmethod
async def user_id_of(cls, request: web.Request) -> int | None:
session = request.query.get("session")
if session is None:
return None
context = cls.from_request(request)
data = await context.session_data(session)
sclient = await context.session_client(data)
if sclient is None:
return None
user = context.client_user(sclient)
if user is None:
return None
user_id = context.user_id(user)
if user_id is None:
return None
return int(user_id)
def bytes_hash(b: bytes) -> str:
    """Return the SHA-256 digest of *b* as text.

    The digest comes from ``nacl.hash.sha256`` with its default encoder
    (hex, per the PyNaCl documentation) and is decoded to ``str``.
    """
    digest = nacl.hash.sha256(b)
    return digest.decode()
async def on_cleanup(app: web.Application) -> None:
    """Drop the application's shared resources and unwind its exit stack.

    Enters the exit stack stored under ``app["es"]``, removes the ``"db"``,
    ``"tc"`` and ``"es"`` entries, and lets the stack release whatever was
    registered on it when the block is left.
    """
    async with app["es"]:
        for key in ("db", "tc", "es"):
            del app[key]
# Route table that the handler decorators in this module register on;
# get_app() attaches it to the application via add_routes().
routes = web.RouteTableDef()
@routes.get("/sessiondata/")
async def sessiondata(request: web.Request) -> web.StreamResponse:
    """Respond with the stored data of the session given in the query string.

    Raises:
        web.HTTPBadRequest: if the ``session`` query parameter is absent.
    """
    session = request.query.get("session")
    if session is not None:
        context = Context.from_request(request)
        payload = await context.session_data(session)
        return web.json_response(payload)
    raise web.HTTPBadRequest
@routes.get("/auth/")
async def auth(request: web.Request) -> web.StreamResponse:
    """Handle the OAuth redirect: verify ``state`` and store the code's token.

    Expects ``session``, ``state`` and ``code`` query parameters, where
    ``state`` must equal ``bytes_hash`` of the encoded session.

    Raises:
        web.HTTPBadRequest: if a parameter is missing or ``state`` mismatches.
        web.HTTPOk: on success (raised, not returned).
    """
    query = request.query
    session = query.get("session")
    state = query.get("state")
    code = query.get("code")
    # All three parameters are required.
    if session is None or state is None or code is None:
        raise web.HTTPBadRequest
    expected_state = bytes_hash(session.encode())
    if state != expected_state:
        raise web.HTTPBadRequest
    await Context.from_request(request).auth(session, code)
    raise web.HTTPOk
@routes.post("/config/")
async def config(request: web.Request) -> web.StreamResponse:
match await request.json():
case {"key": str() as key, "value": _ as value}:
context = Context.from_request(request)
await context.set_config(key, value)
raise web.HTTPOk
case _:
raise web.HTTPBadRequest
@routes.get("/authlink/")
async def authlink(request: web.Request) -> web.StreamResponse:
    """Respond with the authorization link as plain text (empty while not ready)."""
    context = Context.from_request(request)
    link = context.auth_link()
    return web.Response(text=link)
@routes.get("/state/")
async def get_state(request: web.Request) -> web.Response:
    """Respond with the JSON-encoded ``state`` value for the given session.

    The value is ``bytes_hash`` of the encoded ``session`` query parameter,
    which is what the ``/auth/`` handler compares ``state`` against.

    Raises:
        web.HTTPBadRequest: if the ``session`` query parameter is absent
            (previously the literal text ``"None"`` was hashed instead).
    """
    session = request.query.get("session")
    if session is None:
        raise web.HTTPBadRequest
    return web.json_response(data=bytes_hash(session.encode()))
@routes.get("/status/")
async def status(request: web.Request) -> web.Response:
    """Respond with the JSON status of the session given in the query string.

    A missing ``session`` parameter is passed on as ``None`` rather than the
    text ``"None"``, so ``Context.session_data`` treats it as a session with
    no data instead of looking up a session literally named "None".
    """
    session = request.query.get("session")
    context = Context.from_request(request)
    return web.json_response(data=await context.session_status(session))
@routes.get("/userid/")
async def user_id_of(request: web.Request) -> web.Response:
    """Respond with the JSON-encoded user id of the request's session (``null`` if unknown)."""
    uid = await Context.user_id_of(request)
    return web.json_response(data=uid)
async def get_app() -> web.Application:
    """Create the application with its database, token cache and routes.

    The database manager is entered on a temporary exit stack; once setup has
    succeeded the stack's contents are moved into ``app["es"]`` so that
    ``on_cleanup`` can release them later.  If setup fails midway, leaving
    the ``async with`` block releases whatever was already entered.
    """
    application = web.Application()
    application.on_cleanup.append(on_cleanup)
    async with AsyncExitStack() as stack:
        manager = DbManager(Path("/data/session.db"), kvfactory=KVJson())
        application["db"] = await stack.enter_async_context(manager)
        application["tc"] = CachedDictionary(token_client)
        # pop_all() transfers ownership, so leaving this block closes nothing.
        application["es"] = stack.pop_all()
    application.add_routes(routes)
    return application