diff --git a/v6d3music/app.py b/v6d3music/app.py index a76dcb4..c8bc5d9 100644 --- a/v6d3music/app.py +++ b/v6d3music/app.py @@ -1,23 +1,61 @@ import asyncio +import functools import urllib.parse from pathlib import Path -from typing import Optional +from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar import aiohttp import discord from aiohttp import web +from v6d3music.api import Api +from v6d3music.config import auth_redirect, myroot +from v6d3music.utils.bytes_hash import bytes_hash + from ptvp35 import Db, KVJson from v6d0auth.appfactory import AppFactory from v6d0auth.run_app import start_app from v6d1tokens.client import request_token -from v6d3music.config import auth_redirect, myroot -from v6d3music.utils.bytes_hash import bytes_hash -from v6d3music.api import Api - session_db = Db(myroot / 'session.db', kvfactory=KVJson()) +T = TypeVar('T') +TKey = TypeVar('TKey', bound=Hashable) + + +class CachedEntry(Generic[T]): + def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None: + self.__value: T = value + self.__getter = getter + self.__task: asyncio.Future[T] = asyncio.Future() + self.__task.set_result(value) + + async def _set(self) -> T: + self.__value = await self.__getter() + return self.__value + + def get_nowait(self) -> T: + if self.__task.done(): + self.__task = asyncio.create_task(self._set()) + return self.__value + + async def get(self) -> T: + if self.__task.done(): + self.__task = asyncio.create_task(self._set()) + return await self.__task + + +class CachedDictionary(Generic[TKey, T]): + def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None: + self.__factory = factory + self.__entries: dict[TKey, CachedEntry[T]] = {} + + def entry(self, key: TKey, default: T) -> CachedEntry[T]: + if key not in self.__entries: + self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key)) + return self.__entries[key] + + class MusicAppFactory(AppFactory): htmlroot = Path(__file__).parent / 'html' @@ -32,6 +70,9 @@ class MusicAppFactory(AppFactory): self.loop = asyncio.get_running_loop() self.client = client self._api = api + self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary( + self._token_client + ) def auth_link(self) -> str: if self.client.user is None: @@ -40,28 +81,13 @@ class MusicAppFactory(AppFactory): return f'https://discord.com/api/oauth2/authorize?client_id={self.client.user.id}' \ f'&redirect_uri={urllib.parse.quote(self.redirect)}&response_type=code&scope=identify' + def _path(self, file: str): + return self.htmlroot / file + def _file(self, file: str): with open(self.htmlroot / file) as f: return f.read() - async def file(self, file: str): - return await self.loop.run_in_executor( - None, - self._file, - file - ) - - async def html_resp(self, file: str): - text = await self.file(f'{file}.html') - text = text.replace( - '$$DISCORD_AUTH$$', - self.auth_link() - ) - return web.Response( - text=text, - content_type='text/html' - ) - async def code_token(self, code: str) -> dict: assert self.client.user is not None data = { @@ -78,13 +104,7 @@ class MusicAppFactory(AppFactory): async with session.post('https://discord.com/api/oauth2/token', data=data, headers=headers) as response: return await response.json() - @classmethod - async def session_client(cls, data: dict) -> Optional[dict]: - match data: - case {'token': {'access_token': str() as access_token}}: - pass - case _: - return None + async def _token_client(self, access_token: str) -> dict | None: headers = { 'Authorization': f'Bearer {access_token}' } @@ -92,6 +112,17 @@ class MusicAppFactory(AppFactory): async with session.get('https://discord.com/api/oauth2/@me', headers=headers) as response: return await response.json() + async def token_client(self, access_token: str) -> dict | None: + return self._token_clients.entry(access_token, None).get_nowait() + + async def session_client(self, data: dict) -> dict | None: + match data: + case {'token': {'access_token': str() as access_token}}: + pass + case _: + return None + return await self.token_client(access_token) + @classmethod def client_status(cls, sclient: dict) -> dict: user = cls.client_user(sclient) @@ -109,7 +140,7 @@ class MusicAppFactory(AppFactory): } @classmethod - def user_username_full(cls, user: dict) -> Optional[str]: + def user_username_full(cls, user: dict) -> str | None: match user: case {'username': str() as username, 'discriminator': str() as discriminator}: return f'{username}#{discriminator}' @@ -117,19 +148,19 @@ class MusicAppFactory(AppFactory): return None @classmethod - def client_user(cls, sclient: dict) -> Optional[dict]: + def client_user(cls, sclient: dict) -> dict | None: return sclient.get('user') @classmethod - def user_id(cls, user: dict) -> Optional[str | int]: + def user_id(cls, user: dict) -> str | int | None: return user.get('id') @classmethod - def user_avatar(cls, user: dict) -> Optional[str]: + def user_avatar(cls, user: dict) -> str | None: return user.get('avatar') @classmethod - def user_avatar_url(cls, user: dict) -> Optional[str]: + def user_avatar_url(cls, user: dict) -> str | None: cid = cls.user_id(user) if cid is None: return None @@ -169,15 +200,19 @@ class MusicAppFactory(AppFactory): def define_routes(self, routes: web.RouteTableDef) -> None: @routes.get('/') - async def home(_request: web.Request) -> web.Response: - return await self.html_resp('home') + async def home(_request: web.Request) -> web.StreamResponse: + return web.FileResponse(self._path('home.html')) @routes.get('/login/') - async def login(_request: web.Request) -> web.Response: - return await self.html_resp('login') + async def login(_request: web.Request) -> web.StreamResponse: + return web.FileResponse(self._path('login.html')) + + @routes.get('/authlink/') + async def authlink(_request: web.Request) -> web.StreamResponse: + return web.Response(text=self.auth_link()) @routes.get('/auth/') - async def auth(request: web.Request) -> web.Response: + async def auth(request: web.Request) -> web.StreamResponse: if 'session' in request.query: response = web.HTTPFound('/') session = str(request.query.get('session')) @@ -191,7 +226,7 @@ class MusicAppFactory(AppFactory): await session_db.set(session, data) return response else: - return await self.html_resp('auth') + return web.FileResponse(self._path('auth.html')) @routes.get('/state/') async def get_state(request: web.Request) -> web.Response: @@ -215,17 +250,12 @@ class MusicAppFactory(AppFactory): ) @routes.get('/main.js') - async def mainjs(_request: web.Request) -> web.Response: - return web.Response( - text=await self.file('main.js') - ) + async def mainjs(_request: web.Request) -> web.StreamResponse: + return web.FileResponse(self._path('main.js')) @routes.get('/main.css') - async def maincss(_request: web.Request) -> web.Response: - return web.Response( - text=await self.file('main.css'), - content_type='text/css' - ) + async def maincss(_request: web.Request) -> web.StreamResponse: + return web.FileResponse(self._path('main.css')) @routes.post('/api/') async def api(request: web.Request) -> web.Response: diff --git a/v6d3music/html/home.html b/v6d3music/html/home.html index d45a117..4fa9d5c 100644 --- a/v6d3music/html/home.html +++ b/v6d3music/html/home.html @@ -3,7 +3,11 @@
- +