v6d3music/v6d3music/app.py
2022-12-24 16:37:05 +00:00

327 lines
11 KiB
Python

import asyncio
import functools
import os
import urllib.parse
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
import aiohttp
import discord
from aiohttp import web
from v6d3music.api import Api
from v6d3music.config import auth_redirect, myroot
from v6d3music.core.mainservice import MainService
from v6d3music.utils.bytes_hash import bytes_hash
from ptvp35 import *
from v6d0auth.appfactory import AppFactory
from v6d0auth.run_app import start_app
from v6d1tokens.client import request_token
T = TypeVar('T')
# Key type for CachedDictionary; must be hashable to be used as a dict key.
TKey = TypeVar('TKey', bound=Hashable)


class CachedEntry(Generic[T]):
    """A single value that is refreshed asynchronously on every read.

    Each read starts a new call to ``getter`` unless one is already in
    flight. ``get_nowait`` returns the last stored value immediately, while
    ``get`` waits for the running refresh to finish.
    """

    def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
        """
        :param value: value returned by ``get_nowait`` until the first refresh completes.
        :param getter: zero-argument coroutine function producing a fresh value.
        """
        self.__value: T = value
        self.__getter = getter
        # An already-completed future stands for "no refresh in flight",
        # so the very first read starts a refresh.
        self.__task: asyncio.Future[T] = asyncio.Future()
        self.__task.set_result(value)

    async def _set(self) -> T:
        """Fetch a fresh value from ``getter``, store it and return it."""
        self.__value = await self.__getter()
        return self.__value

    def _refresh(self) -> asyncio.Future[T]:
        """Start a refresh unless one is already running; return the current refresh."""
        if self.__task.done():
            self.__task = asyncio.create_task(self._set())
        return self.__task

    def get_nowait(self) -> T:
        """Trigger a background refresh and return the last stored value."""
        self._refresh()
        return self.__value

    async def get(self) -> T:
        """Trigger a refresh (or join the running one) and return its result.

        The refresh task is shared by all concurrent callers, so it is awaited
        through ``asyncio.shield``. Cancelling one caller then cancels only
        that caller, not the refresh every other caller is waiting on.
        """
        return await asyncio.shield(self._refresh())
class CachedDictionary(Generic[TKey, T]):
    """Lazily created CachedEntry objects, one per key.

    Each entry refreshes itself by calling ``factory(key)``. Entries are
    never removed, so the mapping grows with every distinct key requested.
    """

    def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
        self.__factory = factory
        self.__entries: dict[TKey, CachedEntry[T]] = {}

    def entry(self, key: TKey, default: T) -> CachedEntry[T]:
        """Return the entry for ``key``, creating it on first use.

        ``default`` only seeds a newly created entry; it is ignored when the
        entry already exists.
        """
        existing = self.__entries.get(key)
        if existing is not None:
            return existing
        created = CachedEntry(default, functools.partial(self.__factory, key))
        self.__entries[key] = created
        return created
class MusicAppFactory(AppFactory):
htmlroot = Path(__file__).parent / 'html'
def __init__(
self,
secret: str,
client: discord.Client,
api: Api,
db: DbConnection
):
self.secret = secret
self.redirect = auth_redirect
self.loop = asyncio.get_running_loop()
self.client = client
self._api = api
self.db = db
self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary(
self._token_client
)
def auth_link(self) -> str:
if self.client.user is None:
return ''
else:
return f'https://discord.com/api/oauth2/authorize?client_id={self.client.user.id}' \
f'&redirect_uri={urllib.parse.quote(self.redirect)}&response_type=code&scope=identify'
def _path(self, file: str):
return self.htmlroot / file
def _file(self, file: str):
with open(self.htmlroot / file) as f:
return f.read()
async def code_token(self, code: str) -> dict:
assert self.client.user is not None
data = {
'client_id': str(self.client.user.id),
'client_secret': self.secret,
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': self.redirect
}
headers = {
'Content-Type': 'application/x-www-form-urlencoded'
}
async with aiohttp.ClientSession() as session:
async with session.post('https://discord.com/api/oauth2/token', data=data, headers=headers) as response:
return await response.json()
async def _token_client(self, access_token: str) -> dict | None:
headers = {
'Authorization': f'Bearer {access_token}'
}
async with aiohttp.ClientSession() as session:
async with session.get('https://discord.com/api/oauth2/@me', headers=headers) as response:
return await response.json()
async def token_client(self, access_token: str) -> dict | None:
return self._token_clients.entry(access_token, None).get_nowait()
async def session_client(self, data: dict) -> dict | None:
match data:
case {'token': {'access_token': str() as access_token}}:
pass
case _:
return None
return await self.token_client(access_token)
@classmethod
def client_status(cls, sclient: dict) -> dict:
user = cls.client_user(sclient)
return {
'expires': sclient.get('expires'),
'user': (None if user is None else cls.user_status(user)),
}
@classmethod
def user_status(cls, user: dict) -> dict:
return {
'avatar': cls.user_avatar_url(user),
'id': cls.user_id(user),
'username': cls.user_username_full(user)
}
@classmethod
def user_username_full(cls, user: dict) -> str | None:
match user:
case {'username': str() as username, 'discriminator': str() as discriminator}:
return f'{username}#{discriminator}'
case _:
return None
@classmethod
def client_user(cls, sclient: dict) -> dict | None:
return sclient.get('user')
@classmethod
def user_id(cls, user: dict) -> str | int | None:
return user.get('id')
@classmethod
def user_avatar(cls, user: dict) -> str | None:
return user.get('avatar')
@classmethod
def user_avatar_url(cls, user: dict) -> str | None:
cid = cls.user_id(user)
if cid is None:
return None
avatar = cls.user_avatar(user)
if avatar is None:
return None
return f'https://cdn.discordapp.com/avatars/{cid}/{avatar}.png'
async def session_status(self, session: str) -> dict:
data = self.session_data(session)
sclient = await self.session_client(data)
return {
'code_set': data.get('code') is not None,
'token_set': data.get('token') is not None,
'client': (None if sclient is None else self.client_status(sclient))
}
async def session_queue(self, session: str):
data = self.session_data(session)
sclient = await self.session_client(data)
if sclient is None:
return None
user = self.client_user(sclient)
if user is None:
return None
cid = self.user_id(user)
return cid
def session_data(self, session: str | None) -> dict:
if session is None:
return {}
data = self.db.get(session, {})
if not isinstance(data, dict):
return {}
return data
def define_routes(self, routes: web.RouteTableDef) -> None:
@routes.get('/')
async def home(_request: web.Request) -> web.StreamResponse:
return web.FileResponse(self._path('home.html'))
@routes.get('/login/')
async def login(_request: web.Request) -> web.StreamResponse:
return web.FileResponse(self._path('login.html'))
@routes.get('/authlink/')
async def authlink(_request: web.Request) -> web.StreamResponse:
return web.Response(text=self.auth_link())
@routes.get('/auth/')
async def auth(request: web.Request) -> web.StreamResponse:
if 'session' in request.query:
response = web.HTTPFound('/')
session = str(request.query.get('session'))
s_state = str(request.query.get('state'))
code = str(request.query.get('code'))
if bytes_hash(session.encode()) != s_state:
raise web.HTTPBadRequest
data = self.session_data(session)
data['code'] = code
data['token'] = await self.code_token(code)
await self.db.set(session, data)
return response
else:
return web.FileResponse(self._path('auth.html'))
@routes.get('/state/')
async def get_state(request: web.Request) -> web.Response:
session = str(request.query.get('session'))
return web.json_response(
data=f'{bytes_hash(session.encode())}'
)
@routes.get('/status/')
async def status(request: web.Request) -> web.Response:
session = str(request.query.get('session'))
return web.json_response(
data=await self.session_status(session)
)
@routes.get('/queue/')
async def api_queue(request: web.Request) -> web.Response:
session = str(request.query.get('session'))
return web.json_response(
data=await self.session_queue(session)
)
@routes.get('/main.js')
async def mainjs(_request: web.Request) -> web.StreamResponse:
return web.FileResponse(self._path('main.js'))
@routes.get('/main.css')
async def maincss(_request: web.Request) -> web.StreamResponse:
return web.FileResponse(self._path('main.css'))
@routes.post('/api/')
async def api(request: web.Request) -> web.Response:
session = request.query.get('session')
data = self.session_data(session)
sclient = await self.session_client(data)
if sclient is None:
raise web.HTTPUnauthorized
user = self.client_user(sclient)
if user is None:
raise web.HTTPUnauthorized
user_id = self.user_id(user)
if user_id is None:
raise web.HTTPUnauthorized
user_id = int(user_id)
jsr = await request.json()
assert isinstance(jsr, dict)
try:
return web.json_response(await self._api.api(jsr, user_id))
except Api.MisusedApi as e:
return web.json_response(e.json(), status=404)
class AppContext:
    """Async context manager that runs the music web app alongside the bot.

    On enter it opens the session database and starts ``start()`` in a
    background task. On exit it tears the app down (or stops an unfinished
    startup) and then closes the database.
    """

    def __init__(self, mainservice: MainService) -> None:
        self.mainservice = mainservice

    async def start(self) -> tuple[web.Application, asyncio.Task[None]] | None:
        """Build the web app and launch ``start_app`` on it in a background task.

        :returns: ``(app, task)``, or None when building the factory fails with
            ``aiohttp.ClientConnectorError`` (e.g. the token request cannot connect).
        """
        try:
            factory = MusicAppFactory(
                await request_token('music-client', 'token'),
                self.mainservice.client,
                Api(
                    self.mainservice,
                    {key: value for key, value in os.environ.items() if key.startswith('roles')},
                ),
                self.__db
            )
        except aiohttp.ClientConnectorError:
            print('no web app (likely due to no token)')
            return None
        app = factory.app()
        task = asyncio.create_task(start_app(app))
        return app, task

    async def __aenter__(self) -> 'AppContext':
        """Open the session database and schedule ``start()``."""
        async with AsyncExitStack() as es:
            self.__db = await es.enter_async_context(DbFactory(myroot / 'session.db', kvfactory=KVJson()))
            self.__es = es.pop_all()
            self.__task: asyncio.Task[
                tuple[web.Application, asyncio.Task[None]] | None
            ] = asyncio.create_task(self.start())
            return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the web app (or its unfinished startup), then close the session database."""
        async with self.__es:
            if self.__task.done():
                result = await self.__task
                if result is not None:
                    app, task = result
                    # NOTE(review): if start_app() keeps serving until cancelled,
                    # this await blocks exit — confirm start_app's contract.
                    await task
                    await app.shutdown()
                    await app.cleanup()
            else:
                self.__task.cancel()
                # cancel() only requests cancellation; wait until the startup task
                # has actually stopped before the database is closed. asyncio.wait
                # neither raises the task's CancelledError nor swallows a
                # cancellation of this __aexit__ itself.
                await asyncio.wait([self.__task])
        del self.__es