269 lines
8.5 KiB
Python
269 lines
8.5 KiB
Python
import sys
|
|
import urllib.parse
|
|
from contextlib import AsyncExitStack
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import aiohttp
|
|
import nacl.hash
|
|
from aiohttp import web
|
|
|
|
from ptvp35 import DbInterface, DbManager, KVJson
|
|
|
|
from .cacheddictionary import CachedDictionary
|
|
from .token_client import token_client
|
|
|
|
if sys.version_info < (3, 11):
|
|
from typing_extensions import Self
|
|
else:
|
|
from typing import Self
|
|
|
|
|
|
class Context:
|
|
def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None:
|
|
self.db = db
|
|
self.token_clients = token_clients
|
|
|
|
@classmethod
|
|
def from_app(cls, app: web.Application) -> Self:
|
|
return cls(app["db"], app["tc"])
|
|
|
|
@classmethod
|
|
def from_request(cls, request: web.Request) -> Self:
|
|
return cls.from_app(request.app)
|
|
|
|
async def session_data(self, session: str) -> dict:
|
|
if session is None:
|
|
return {}
|
|
data = self.db.get(session, {})
|
|
if not isinstance(data, dict):
|
|
return {}
|
|
return data
|
|
|
|
def config(self, key: str, default: Any = None) -> Any:
|
|
return self.db.get(("config", key), default)
|
|
|
|
async def set_config(self, key: str, value: Any) -> None:
|
|
await self.db.set(("config", key), value)
|
|
|
|
def app_user_id(self) -> int:
|
|
match self.config("user-id"):
|
|
case int() as result:
|
|
return result
|
|
case _:
|
|
raise RuntimeError("user-id not set")
|
|
|
|
def secret(self) -> str:
|
|
match self.config("secret"):
|
|
case str() as result:
|
|
return result
|
|
case _:
|
|
raise RuntimeError("secret not set")
|
|
|
|
def redirect(self) -> str:
|
|
match self.config("redirect"):
|
|
case str() as result:
|
|
return result
|
|
case _:
|
|
raise RuntimeError("redirect not set")
|
|
|
|
def ready(self) -> bool:
|
|
match self.config("ready", False):
|
|
case bool() as result:
|
|
return result
|
|
case _:
|
|
raise RuntimeError("inconsistent DB state")
|
|
|
|
async def code_token(self, code: str) -> dict:
|
|
if not self.ready():
|
|
raise RuntimeError("not ready")
|
|
client_id = self.app_user_id()
|
|
data = {
|
|
"client_id": str(client_id),
|
|
"client_secret": self.secret(),
|
|
"grant_type": "authorization_code",
|
|
"code": code,
|
|
"redirect_uri": self.redirect(),
|
|
}
|
|
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.post("https://discord.com/api/oauth2/token", data=data, headers=headers) as response:
|
|
return await response.json()
|
|
|
|
async def auth(self, session: str, code: str) -> None:
|
|
data = await self.session_data(session)
|
|
data["code"] = code
|
|
data["token"] = await self.code_token(code)
|
|
await self.db.set(session, data)
|
|
|
|
def auth_link(self) -> str:
|
|
if not self.ready():
|
|
return ""
|
|
client_id = self.app_user_id()
|
|
return (
|
|
f"https://discord.com/api/oauth2/authorize?client_id={client_id}"
|
|
f"&redirect_uri={urllib.parse.quote(self.redirect())}&response_type=code&scope=identify"
|
|
)
|
|
|
|
async def token_client(self, access_token: str) -> dict | None:
|
|
return self.token_clients.entry(access_token, None).get_nowait()
|
|
|
|
async def session_client(self, data: dict) -> dict | None:
|
|
match data:
|
|
case {"token": {"access_token": str() as access_token}}:
|
|
pass
|
|
case _:
|
|
return None
|
|
return await self.token_client(access_token)
|
|
|
|
@classmethod
|
|
def client_user(cls, sclient: dict) -> dict | None:
|
|
return sclient.get("user")
|
|
|
|
@classmethod
|
|
def user_username_full(cls, user: dict) -> str | None:
|
|
match user:
|
|
case {"username": str() as username, "discriminator": str() as discriminator}:
|
|
return f"{username}#{discriminator}"
|
|
case _:
|
|
return None
|
|
|
|
@classmethod
|
|
def user_id(cls, user: dict) -> str | int | None:
|
|
return user.get("id")
|
|
|
|
@classmethod
|
|
def user_avatar(cls, user: dict) -> str | None:
|
|
return user.get("avatar")
|
|
|
|
@classmethod
|
|
def user_avatar_url(cls, user: dict) -> str | None:
|
|
cid = cls.user_id(user)
|
|
if cid is None:
|
|
return None
|
|
avatar = cls.user_avatar(user)
|
|
if avatar is None:
|
|
return None
|
|
return f"https://cdn.discordapp.com/avatars/{cid}/{avatar}.png"
|
|
|
|
@classmethod
|
|
def user_status(cls, user: dict) -> dict:
|
|
return {"avatar": cls.user_avatar_url(user), "id": cls.user_id(user), "username": cls.user_username_full(user)}
|
|
|
|
@classmethod
|
|
def client_status(cls, sclient: dict) -> dict:
|
|
user = cls.client_user(sclient)
|
|
return {
|
|
"expires": sclient.get("expires"),
|
|
"user": (None if user is None else cls.user_status(user)),
|
|
}
|
|
|
|
async def session_status(self, session: str) -> dict:
|
|
data = await self.session_data(session)
|
|
sclient = await self.session_client(data)
|
|
return {
|
|
"code_set": data.get("code") is not None,
|
|
"token_set": data.get("token") is not None,
|
|
"client": (None if sclient is None else self.client_status(sclient)),
|
|
}
|
|
|
|
@classmethod
|
|
async def user_id_of(cls, request: web.Request) -> int | None:
|
|
session = request.query.get("session")
|
|
if session is None:
|
|
return None
|
|
context = cls.from_request(request)
|
|
data = await context.session_data(session)
|
|
sclient = await context.session_client(data)
|
|
if sclient is None:
|
|
return None
|
|
user = context.client_user(sclient)
|
|
if user is None:
|
|
return None
|
|
user_id = context.user_id(user)
|
|
if user_id is None:
|
|
return None
|
|
return int(user_id)
|
|
|
|
|
|
def bytes_hash(b: bytes) -> str:
    """Return the SHA-256 digest of *b*, decoded to ``str``.

    ``nacl.hash.sha256`` returns the digest as encoded bytes; PyNaCl's default
    encoder for it is hex.
    """
    digest = nacl.hash.sha256(b)
    return digest.decode()
|
|
|
|
|
|
async def on_cleanup(app: web.Application) -> None:
    """Remove the shared state from *app* and close the exit stack stored on it.

    The keys are deleted while the stack is still open. The stack's callbacks run
    when the ``async with`` block exits.
    """
    stack = app["es"]
    async with stack:
        for key in ("db", "tc", "es"):
            del app[key]
|
|
|
|
|
|
# Route table collecting the handlers below; get_app registers it via app.add_routes.
routes = web.RouteTableDef()
|
|
|
|
|
|
@routes.get("/sessiondata/")
async def sessiondata(request: web.Request) -> web.StreamResponse:
    """Return the stored data for the ``session`` query parameter as JSON.

    Responds 400 when ``session`` is missing.
    """
    # NOTE(review): the stored data includes whatever auth() saved (code and
    # token); confirm this endpoint is not publicly reachable.
    session = request.query.get("session")
    if session is not None:
        context = Context.from_request(request)
        payload = await context.session_data(session)
        return web.json_response(payload)
    raise web.HTTPBadRequest
|
|
|
|
|
|
@routes.get("/auth/")
async def auth(request: web.Request) -> web.StreamResponse:
    """OAuth2 redirect target: verify ``state``, then exchange and store the code.

    Expects ``session``, ``state`` and ``code`` query parameters. ``state`` must
    equal ``bytes_hash`` of the UTF-8 encoded session. Responds 200 on success,
    and 400 when a parameter is missing or the state does not match.
    """
    query = request.query
    session, state, code = query.get("session"), query.get("state"), query.get("code")
    if not (isinstance(session, str) and isinstance(state, str) and isinstance(code, str)):
        raise web.HTTPBadRequest
    if state != bytes_hash(session.encode()):
        raise web.HTTPBadRequest
    await Context.from_request(request).auth(session, code)
    raise web.HTTPOk
|
|
|
|
|
|
@routes.post("/config/")
|
|
async def config(request: web.Request) -> web.StreamResponse:
|
|
match await request.json():
|
|
case {"key": str() as key, "value": _ as value}:
|
|
context = Context.from_request(request)
|
|
await context.set_config(key, value)
|
|
raise web.HTTPOk
|
|
case _:
|
|
raise web.HTTPBadRequest
|
|
|
|
|
|
@routes.get("/authlink/")
async def authlink(request: web.Request) -> web.StreamResponse:
    """Return the Discord authorization URL as plain text (empty when not ready)."""
    link = Context.from_request(request).auth_link()
    return web.Response(text=link)
|
|
|
|
|
|
@routes.get("/state/")
async def get_state(request: web.Request) -> web.Response:
    """Return the OAuth2 ``state`` for the ``session`` query parameter.

    The state is ``bytes_hash`` of the UTF-8 encoded session id, returned as a
    JSON string. Responds 400 when ``session`` is missing.
    """
    session = request.query.get("session")
    if session is None:
        # Without this guard, str(None) would hash the literal session id "None"
        # and return a valid-looking state for a request that named no session.
        raise web.HTTPBadRequest
    return web.json_response(data=bytes_hash(session.encode()))
|
|
|
|
|
|
@routes.get("/status/")
async def status(request: web.Request) -> web.Response:
    """Return the auth status of the ``session`` query parameter as JSON."""
    # Pass a missing session through as None rather than str()-ing it.
    # Context.session_data treats None as "no data", whereas str(None) would read
    # a stored session literally named "None" (which /auth/ accepts as a session).
    session = request.query.get("session")
    context = Context.from_request(request)
    return web.json_response(data=await context.session_status(session))
|
|
|
|
|
|
@routes.get("/userid/")
async def user_id_of(request: web.Request) -> web.Response:
    """Return the user id bound to the ``session`` query parameter, or JSON null."""
    resolved = await Context.user_id_of(request)
    return web.json_response(data=resolved)
|
|
|
|
|
|
async def get_app() -> web.Application:
    """Create the aiohttp application with its session database, token cache and routes.

    The database manager is entered on an ``AsyncExitStack``. If setup fails, the
    stack unwinds and closes it. Otherwise ``pop_all`` moves the registered exit
    callbacks into ``app["es"]``, which ``on_cleanup`` closes at shutdown.
    """
    application = web.Application()
    application.on_cleanup.append(on_cleanup)
    async with AsyncExitStack() as stack:
        db = await stack.enter_async_context(DbManager(Path("/data/session.db"), kvfactory=KVJson()))
        application["db"] = db
        application["tc"] = CachedDictionary(token_client)
        application["es"] = stack.pop_all()
    application.add_routes(routes)
    return application
|