diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..9863271 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.11 + +WORKDIR /code/ + +COPY requirements.txt requirements.txt + +RUN pip install --no-cache-dir --upgrade -r requirements.txt + +COPY app app + +RUN python3 -m app.main + +CMD ["python3", "-m", "app"] diff --git a/app/__main__.py b/app/__main__.py new file mode 100644 index 0000000..469d871 --- /dev/null +++ b/app/__main__.py @@ -0,0 +1,7 @@ +from aiohttp import web + +from .main import get_app + +web.run_app( + get_app(), host='0.0.0.0', port=80 +) diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..194e295 --- /dev/null +++ b/app/main.py @@ -0,0 +1,312 @@ +import asyncio +import functools +import sys +import urllib.parse +from contextlib import AsyncExitStack +from pathlib import Path +from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar + +import aiohttp +import nacl.hash +from aiohttp import web + +from ptvp35 import DbFactory, DbInterface, KVJson + +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self + + +T = TypeVar("T") +TKey = TypeVar("TKey", bound=Hashable) + + +class CachedEntry(Generic[T]): + def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None: + self.__value: T = value + self.__getter = getter + self.__task: asyncio.Future[T] = asyncio.Future() + self.__task.set_result(value) + + async def _set(self) -> T: + self.__value = await self.__getter() + return self.__value + + def get_nowait(self) -> T: + if self.__task.done(): + self.__task = asyncio.create_task(self._set()) + return self.__value + + async def get(self) -> T: + if self.__task.done(): + self.__task = asyncio.create_task(self._set()) + return await self.__task + + +async def _token_client(access_token: str) -> dict | None: + headers = {"Authorization": f"Bearer {access_token}"} + async with aiohttp.ClientSession() as session: + async with session.get("https://discord.com/api/oauth2/@me", headers=headers) as response: + return await response.json() + + +class CachedDictionary(Generic[TKey, T]): + def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None: + self.__factory = factory + self.__entries: dict[TKey, CachedEntry[T]] = {} + + def entry(self, key: TKey, default: T) -> CachedEntry[T]: + if key not in self.__entries: + self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key)) + return self.__entries[key] + + +class Context: + def __init__(self, db: DbInterface, token_clients: CachedDictionary[str, dict | None]) -> None: + self.db = db + self.token_clients = token_clients + + @classmethod + def from_app(cls, app: web.Application) -> Self: + return cls(app["db"], app["tc"]) + + @classmethod + def from_request(cls, request: web.Request) -> Self: + return cls.from_app(request.app) + + async def session_data(self, session: str) -> dict: + if session is None: + return {} + data = self.db.get(session, {}) + if not isinstance(data, dict): + return {} + return data + + def config(self, key: str, default: Any = None) -> Any: + return self.db.get(("config", key), default) + + async def set_config(self, key: str, value: Any) -> None: + await self.db.set(("config", key), value) + + def app_user_id(self) -> int: + match self.config("user-id"): + case int() as result: + return result + case None: + raise RuntimeError("user-id not set") + case _: + raise RuntimeError("inconsistent DB state") + + def secret(self) -> str: + match self.config("secret"): + case str() as result: + return result + case _: + raise RuntimeError("secret not set") + + def redirect(self) -> str: + return "https://music.parrrate.ru/auth/" + + def ready(self) -> bool: + match self.config("ready", False): + case bool() as result: + return result + case _: + raise RuntimeError("inconsistent DB state") + + async def code_token(self, code: str) -> dict: + if not self.ready(): + raise RuntimeError("not ready") + client_id = self.app_user_id() + data = { + "client_id": str(client_id), + "client_secret": self.secret(), + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect(), + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + async with aiohttp.ClientSession() as session: + async with session.post("https://discord.com/api/oauth2/token", data=data, headers=headers) as response: + return await response.json() + + async def auth(self, session: str, code: str) -> None: + data = await self.session_data(session) + data["code"] = code + data["token"] = await self.code_token(code) + await self.db.set(session, data) + + def auth_link(self) -> str: + if not self.ready(): + return "" + client_id = self.app_user_id() + return ( + f"https://discord.com/api/oauth2/authorize?client_id={client_id}" + f"&redirect_uri={urllib.parse.quote(self.redirect())}&response_type=code&scope=identify" + ) + + async def token_client(self, access_token: str) -> dict | None: + return self.token_clients.entry(access_token, None).get_nowait() + + async def session_client(self, data: dict) -> dict | None: + match data: + case {"token": {"access_token": str() as access_token}}: + pass + case _: + return None + return await self.token_client(access_token) + + @classmethod + def client_user(cls, sclient: dict) -> dict | None: + return sclient.get("user") + + @classmethod + def user_username_full(cls, user: dict) -> str | None: + match user: + case {"username": str() as username, "discriminator": str() as discriminator}: + return f"{username}#{discriminator}" + case _: + return None + + @classmethod + def user_id(cls, user: dict) -> str | int | None: + return user.get("id") + + @classmethod + def user_avatar(cls, user: dict) -> str | None: + return user.get("avatar") + + @classmethod + def user_avatar_url(cls, user: dict) -> str | None: + cid = cls.user_id(user) + if cid is None: + return None + avatar = cls.user_avatar(user) + if avatar is None: + return None + return f"https://cdn.discordapp.com/avatars/{cid}/{avatar}.png" + + @classmethod + def user_status(cls, user: dict) -> dict: + return {"avatar": cls.user_avatar_url(user), "id": cls.user_id(user), "username": cls.user_username_full(user)} + + @classmethod + def client_status(cls, sclient: dict) -> dict: + user = cls.client_user(sclient) + return { + "expires": sclient.get("expires"), + "user": (None if user is None else cls.user_status(user)), + } + + async def session_status(self, session: str) -> dict: + data = await self.session_data(session) + sclient = await self.session_client(data) + return { + "code_set": data.get("code") is not None, + "token_set": data.get("token") is not None, + "client": (None if sclient is None else self.client_status(sclient)), + } + + @classmethod + async def user_id_of(cls, request: web.Request) -> int | None: + session = request.query.get("session") + if session is None: + return None + context = cls.from_request(request) + data = await context.session_data(session) + sclient = await context.session_client(data) + if sclient is None: + return None + user = context.client_user(sclient) + if user is None: + return None + user_id = context.user_id(user) + if user_id is None: + return None + return int(user_id) + + +def bytes_hash(b: bytes) -> str: + return nacl.hash.sha256(b).decode() + + +async def on_cleanup(app: web.Application) -> None: + async with app["es"]: + del app["db"] + + +routes = web.RouteTableDef() + + +@routes.get("/") +async def home(_request: web.Request) -> web.StreamResponse: + return web.Response(text="sessionservice") + + +@routes.get("/sessiondata/") +async def sessiondata(request: web.Request) -> web.StreamResponse: + session = request.query.get("session") + if session is None: + raise web.HTTPBadRequest + return web.json_response(await Context.from_request(request).session_data(session)) + + +@routes.get("/auth/") +async def auth(request: web.Request) -> web.StreamResponse: + session = request.query.get("session") + state = request.query.get("state") + code = request.query.get("code") + match session, state, code: + case str() as session, str() as state, str() as code: + if bytes_hash(session.encode()) != state: + raise web.HTTPBadRequest + context = Context.from_request(request) + await context.auth(session, code) + raise web.HTTPOk + case _: + raise web.HTTPBadRequest + + +@routes.post("/config/") +async def config(request: web.Request) -> web.StreamResponse: + match await request.json(): + case {"key": str() as key, "value": _ as value}: + context = Context.from_request(request) + await context.set_config(key, value) + raise web.HTTPOk + case _: + raise web.HTTPBadRequest + + +@routes.get("/authlink/") +async def authlink(request: web.Request) -> web.StreamResponse: + return web.Response(text=Context.from_request(request).auth_link()) + + +@routes.get("/state/") +async def get_state(request: web.Request) -> web.Response: + session = str(request.query.get("session")) + return web.json_response(data=f"{bytes_hash(session.encode())}") + + +@routes.get("/status/") +async def status(request: web.Request) -> web.Response: + session = str(request.query.get("session")) + return web.json_response(data=await Context.from_request(request).session_status(session)) + + +@routes.get("/userid/") +async def user_id_of(request: web.Request) -> web.Response: + return web.json_response(data=await Context.user_id_of(request)) + + +async def get_app() -> web.Application: + app = web.Application() + app.on_cleanup.append(on_cleanup) + async with AsyncExitStack() as es: + app["db"] = await es.enter_async_context(DbFactory(Path("/data/session.db"), kvfactory=KVJson())) + app["tc"] = CachedDictionary(_token_client) + app["es"] = es.pop_all() + app.add_routes(routes) + return app diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3d3457f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +PyNaCl==1.5.0 +aiohttp==3.8.4 +ptvp35 @ git+https://gitea.parrrate.ru/PTV/ptvp35.git@fffff4973e0fbc7fdf3425e6486ce6378dccf821