327 lines
11 KiB
Python
327 lines
11 KiB
Python
import asyncio
|
|
import functools
|
|
import os
|
|
import urllib.parse
|
|
from contextlib import AsyncExitStack
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
|
|
|
|
import aiohttp
|
|
import discord
|
|
from aiohttp import web
|
|
from v6d3music.api import Api
|
|
from v6d3music.config import auth_redirect, myroot
|
|
from v6d3music.core.mainservice import MainService
|
|
from v6d3music.utils.bytes_hash import bytes_hash
|
|
|
|
from ptvp35 import *
|
|
from v6d0auth.appfactory import AppFactory
|
|
from v6d0auth.run_app import start_app
|
|
from v6d1tokens.client import request_token
|
|
|
|
|
|
# Generic value type stored in cache entries.
T = TypeVar("T")
# Generic key type for CachedDictionary; bounded by Hashable because keys index a dict.
TKey = TypeVar("TKey", bound=Hashable)
|
|
|
|
|
|
class CachedEntry(Generic[T]):
    """A single cached value that is refreshed by an async getter.

    The entry holds the most recent successfully fetched value (initially the
    ``value`` passed to the constructor). Every read starts a new refresh via
    ``getter`` unless a refresh is already in flight.
    """

    def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
        """
        :param value: value returned by ``get_nowait`` until a refresh succeeds.
        :param getter: zero-argument coroutine function producing a fresh value.
        """
        self.__value: T = value
        self.__getter = getter
        # An already-resolved future stands in for "no refresh in flight", so
        # the first read immediately schedules a real refresh.
        self.__task: asyncio.Future[T] = asyncio.Future()
        self.__task.set_result(value)

    async def _set(self) -> T:
        """Fetch a fresh value from the getter, store it, and return it."""
        self.__value = await self.__getter()
        return self.__value

    def _refresh(self) -> None:
        """Start a new refresh task unless the current one is still running."""
        previous = self.__task
        if not previous.done():
            return
        # Retrieve the finished refresh's outcome before dropping it. Otherwise
        # a failed refresh that nobody awaited makes asyncio log
        # "Task exception was never retrieved". The last good value stays
        # cached either way.
        if not previous.cancelled():
            previous.exception()
        self.__task = asyncio.create_task(self._set())

    def get_nowait(self) -> T:
        """Return the currently cached value and kick off a background refresh.

        Must be called while an event loop is running (uses ``asyncio.create_task``).
        """
        self._refresh()
        return self.__value

    async def get(self) -> T:
        """Return a freshly fetched value, joining an in-flight refresh if any.

        Propagates whatever exception ``getter`` raises for the awaited refresh.
        """
        self._refresh()
        return await self.__task
|
|
|
|
|
|
class CachedDictionary(Generic[TKey, T]):
    """Lazily created ``CachedEntry`` objects, one per key.

    Each entry's getter is ``factory`` bound to that entry's key.
    """

    def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
        """
        :param factory: coroutine function mapping a key to a fresh value.
        """
        self.__factory = factory
        self.__entries: dict[TKey, CachedEntry[T]] = {}

    def entry(self, key: TKey, default: T) -> CachedEntry[T]:
        """Return the entry for ``key``, creating it on first use.

        ``default`` only matters when the entry is created: it becomes the
        entry's initial value. Later calls with the same key ignore it.
        """
        existing = self.__entries.get(key)
        if existing is not None:
            return existing
        created = CachedEntry(default, functools.partial(self.__factory, key))
        self.__entries[key] = created
        return created
|
|
|
|
|
|
class MusicAppFactory(AppFactory):
|
|
htmlroot = Path(__file__).parent / 'html'
|
|
|
|
def __init__(
|
|
self,
|
|
secret: str,
|
|
client: discord.Client,
|
|
api: Api,
|
|
db: DbConnection
|
|
):
|
|
self.secret = secret
|
|
self.redirect = auth_redirect
|
|
self.loop = asyncio.get_running_loop()
|
|
self.client = client
|
|
self._api = api
|
|
self.db = db
|
|
self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary(
|
|
self._token_client
|
|
)
|
|
|
|
def auth_link(self) -> str:
|
|
if self.client.user is None:
|
|
return ''
|
|
else:
|
|
return f'https://discord.com/api/oauth2/authorize?client_id={self.client.user.id}' \
|
|
f'&redirect_uri={urllib.parse.quote(self.redirect)}&response_type=code&scope=identify'
|
|
|
|
def _path(self, file: str):
|
|
return self.htmlroot / file
|
|
|
|
def _file(self, file: str):
|
|
with open(self.htmlroot / file) as f:
|
|
return f.read()
|
|
|
|
async def code_token(self, code: str) -> dict:
|
|
assert self.client.user is not None
|
|
data = {
|
|
'client_id': str(self.client.user.id),
|
|
'client_secret': self.secret,
|
|
'grant_type': 'authorization_code',
|
|
'code': code,
|
|
'redirect_uri': self.redirect
|
|
}
|
|
headers = {
|
|
'Content-Type': 'application/x-www-form-urlencoded'
|
|
}
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.post('https://discord.com/api/oauth2/token', data=data, headers=headers) as response:
|
|
return await response.json()
|
|
|
|
async def _token_client(self, access_token: str) -> dict | None:
|
|
headers = {
|
|
'Authorization': f'Bearer {access_token}'
|
|
}
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get('https://discord.com/api/oauth2/@me', headers=headers) as response:
|
|
return await response.json()
|
|
|
|
async def token_client(self, access_token: str) -> dict | None:
|
|
return self._token_clients.entry(access_token, None).get_nowait()
|
|
|
|
async def session_client(self, data: dict) -> dict | None:
|
|
match data:
|
|
case {'token': {'access_token': str() as access_token}}:
|
|
pass
|
|
case _:
|
|
return None
|
|
return await self.token_client(access_token)
|
|
|
|
@classmethod
|
|
def client_status(cls, sclient: dict) -> dict:
|
|
user = cls.client_user(sclient)
|
|
return {
|
|
'expires': sclient.get('expires'),
|
|
'user': (None if user is None else cls.user_status(user)),
|
|
}
|
|
|
|
@classmethod
|
|
def user_status(cls, user: dict) -> dict:
|
|
return {
|
|
'avatar': cls.user_avatar_url(user),
|
|
'id': cls.user_id(user),
|
|
'username': cls.user_username_full(user)
|
|
}
|
|
|
|
@classmethod
|
|
def user_username_full(cls, user: dict) -> str | None:
|
|
match user:
|
|
case {'username': str() as username, 'discriminator': str() as discriminator}:
|
|
return f'{username}#{discriminator}'
|
|
case _:
|
|
return None
|
|
|
|
@classmethod
|
|
def client_user(cls, sclient: dict) -> dict | None:
|
|
return sclient.get('user')
|
|
|
|
@classmethod
|
|
def user_id(cls, user: dict) -> str | int | None:
|
|
return user.get('id')
|
|
|
|
@classmethod
|
|
def user_avatar(cls, user: dict) -> str | None:
|
|
return user.get('avatar')
|
|
|
|
@classmethod
|
|
def user_avatar_url(cls, user: dict) -> str | None:
|
|
cid = cls.user_id(user)
|
|
if cid is None:
|
|
return None
|
|
avatar = cls.user_avatar(user)
|
|
if avatar is None:
|
|
return None
|
|
return f'https://cdn.discordapp.com/avatars/{cid}/{avatar}.png'
|
|
|
|
async def session_status(self, session: str) -> dict:
|
|
data = self.session_data(session)
|
|
sclient = await self.session_client(data)
|
|
return {
|
|
'code_set': data.get('code') is not None,
|
|
'token_set': data.get('token') is not None,
|
|
'client': (None if sclient is None else self.client_status(sclient))
|
|
}
|
|
|
|
async def session_queue(self, session: str):
|
|
data = self.session_data(session)
|
|
sclient = await self.session_client(data)
|
|
if sclient is None:
|
|
return None
|
|
user = self.client_user(sclient)
|
|
if user is None:
|
|
return None
|
|
cid = self.user_id(user)
|
|
return cid
|
|
|
|
def session_data(self, session: str | None) -> dict:
|
|
if session is None:
|
|
return {}
|
|
data = self.db.get(session, {})
|
|
if not isinstance(data, dict):
|
|
return {}
|
|
return data
|
|
|
|
def define_routes(self, routes: web.RouteTableDef) -> None:
|
|
@routes.get('/')
|
|
async def home(_request: web.Request) -> web.StreamResponse:
|
|
return web.FileResponse(self._path('home.html'))
|
|
|
|
@routes.get('/login/')
|
|
async def login(_request: web.Request) -> web.StreamResponse:
|
|
return web.FileResponse(self._path('login.html'))
|
|
|
|
@routes.get('/authlink/')
|
|
async def authlink(_request: web.Request) -> web.StreamResponse:
|
|
return web.Response(text=self.auth_link())
|
|
|
|
@routes.get('/auth/')
|
|
async def auth(request: web.Request) -> web.StreamResponse:
|
|
if 'session' in request.query:
|
|
response = web.HTTPFound('/')
|
|
session = str(request.query.get('session'))
|
|
s_state = str(request.query.get('state'))
|
|
code = str(request.query.get('code'))
|
|
if bytes_hash(session.encode()) != s_state:
|
|
raise web.HTTPBadRequest
|
|
data = self.session_data(session)
|
|
data['code'] = code
|
|
data['token'] = await self.code_token(code)
|
|
await self.db.set(session, data)
|
|
return response
|
|
else:
|
|
return web.FileResponse(self._path('auth.html'))
|
|
|
|
@routes.get('/state/')
|
|
async def get_state(request: web.Request) -> web.Response:
|
|
session = str(request.query.get('session'))
|
|
return web.json_response(
|
|
data=f'{bytes_hash(session.encode())}'
|
|
)
|
|
|
|
@routes.get('/status/')
|
|
async def status(request: web.Request) -> web.Response:
|
|
session = str(request.query.get('session'))
|
|
return web.json_response(
|
|
data=await self.session_status(session)
|
|
)
|
|
|
|
@routes.get('/queue/')
|
|
async def api_queue(request: web.Request) -> web.Response:
|
|
session = str(request.query.get('session'))
|
|
return web.json_response(
|
|
data=await self.session_queue(session)
|
|
)
|
|
|
|
@routes.get('/main.js')
|
|
async def mainjs(_request: web.Request) -> web.StreamResponse:
|
|
return web.FileResponse(self._path('main.js'))
|
|
|
|
@routes.get('/main.css')
|
|
async def maincss(_request: web.Request) -> web.StreamResponse:
|
|
return web.FileResponse(self._path('main.css'))
|
|
|
|
@routes.post('/api/')
|
|
async def api(request: web.Request) -> web.Response:
|
|
session = request.query.get('session')
|
|
data = self.session_data(session)
|
|
sclient = await self.session_client(data)
|
|
if sclient is None:
|
|
raise web.HTTPUnauthorized
|
|
user = self.client_user(sclient)
|
|
if user is None:
|
|
raise web.HTTPUnauthorized
|
|
user_id = self.user_id(user)
|
|
if user_id is None:
|
|
raise web.HTTPUnauthorized
|
|
user_id = int(user_id)
|
|
jsr = await request.json()
|
|
assert isinstance(jsr, dict)
|
|
try:
|
|
return web.json_response(await self._api.api(jsr, user_id))
|
|
except Api.MisusedApi as e:
|
|
return web.json_response(e.json(), status=404)
|
|
|
|
|
|
class AppContext:
    """Async context manager that opens the session DB and runs the web app.

    On enter, the session DB is opened and app startup is scheduled as a
    background task. On exit, the app is shut down (if it started) or its
    startup is cancelled, and the session DB is closed.
    """

    def __init__(self, mainservice: MainService) -> None:
        self.mainservice = mainservice

    async def start(self) -> tuple[web.Application, asyncio.Task[None]] | None:
        """Build the web app and start serving it in a background task.

        :returns: the app and the task running ``start_app``, or None if
            building the factory raised ``aiohttp.ClientConnectorError``
            (presumably from ``request_token`` when no token is available).
        """
        try:
            factory = MusicAppFactory(
                await request_token('music-client', 'token'),
                self.mainservice.client,
                Api(
                    self.mainservice,
                    # Environment variables whose names start with 'roles' are forwarded to Api.
                    {key: value for key, value in os.environ.items() if key.startswith('roles')},
                ),
                self.__db
            )
        except aiohttp.ClientConnectorError:
            print('no web app (likely due to no token)')
            return None
        app = factory.app()
        task = asyncio.create_task(start_app(app))
        return app, task

    async def __aenter__(self) -> 'AppContext':
        """Open the session DB and schedule ``start`` in the background."""
        async with AsyncExitStack() as es:
            self.__db = await es.enter_async_context(DbFactory(myroot / 'session.db', kvfactory=KVJson()))
            # Keep the DB open beyond this `async with`; it is closed in __aexit__.
            self.__es = es.pop_all()
        self.__task: asyncio.Task[
            tuple[web.Application, asyncio.Task[None]] | None
        ] = asyncio.create_task(self.start())
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the web app (or abort its startup), then close the session DB."""
        async with self.__es:
            if self.__task.done():
                result = await self.__task
                if result is not None:
                    app, task = result
                    try:
                        await task
                    finally:
                        # Shut the app down even if the server task failed.
                        await app.shutdown()
                        await app.cleanup()
            else:
                self.__task.cancel()
                # Let the cancellation finish before the DB is closed.
                # asyncio.wait does not re-raise the task's CancelledError,
                # so a cancellation of this coroutine itself still propagates.
                await asyncio.wait({self.__task})
        del self.__es
|