sessionservice
This commit is contained in:
parent
4aa5a679c1
commit
07bae21312
2
setup.py
2
setup.py
@ -6,7 +6,7 @@ setup(
|
||||
packages=['v6d3music'],
|
||||
url='',
|
||||
license='',
|
||||
author='PARRRATE T&V',
|
||||
author='PARRRATE TNV',
|
||||
author_email='',
|
||||
description=''
|
||||
)
|
||||
|
156
v6d3music/api.py
156
v6d3music/api.py
@ -3,41 +3,41 @@ import time
|
||||
|
||||
import discord
|
||||
from typing_extensions import Self
|
||||
from v6d2ctx.integration.responsetype import *
|
||||
from v6d2ctx.integration.targets import *
|
||||
|
||||
from rainbowadn.instrument import Instrumentation
|
||||
from v6d2ctx.context import *
|
||||
from v6d2ctx.integration.responsetype import *
|
||||
from v6d2ctx.integration.targets import *
|
||||
from v6d3music.core.mainaudio import *
|
||||
from v6d3music.core.mainservice import *
|
||||
|
||||
__all__ = ('Api',)
|
||||
__all__ = ("Api",)
|
||||
|
||||
|
||||
class Api:
|
||||
class MisusedApi(Exception):
|
||||
def json(self) -> dict:
|
||||
return {'error': list(map(str, self.args)), 'errormessage': str(self)}
|
||||
return {"error": list(map(str, self.args)), "errormessage": str(self)}
|
||||
|
||||
class UnknownApi(MisusedApi):
|
||||
def json(self) -> dict:
|
||||
return super().json() | {'unknownapi': None}
|
||||
return super().json() | {"unknownapi": None}
|
||||
|
||||
class ExplicitFailure(MisusedApi):
|
||||
def __init__(self, explicit: Explicit) -> None:
|
||||
super().__init__(*explicit.args)
|
||||
|
||||
def json(self) -> dict:
|
||||
return super().json() | {'explicit': None}
|
||||
return super().json() | {"explicit": None}
|
||||
|
||||
def __init__(self, mainservice: MainService, roles: dict[str, str]) -> None:
|
||||
self.mainservice = mainservice
|
||||
self.client = mainservice.client
|
||||
self.roles = roles
|
||||
self.targets = mainservice.targets
|
||||
self.targets.register_instance(self, 'api', Async)
|
||||
self.targets.register_instrumentation('Count', lambda t, n: Count(t, n))
|
||||
self.targets.register_instrumentation('Concurrency', lambda t, n: Concurrency(t, n), Async)
|
||||
self.targets.register_instance(self, "api", Async)
|
||||
self.targets.register_instrumentation("Count", lambda t, n: Count(t, n))
|
||||
self.targets.register_instrumentation("Concurrency", lambda t, n: Concurrency(t, n), Async)
|
||||
|
||||
def user_id(self) -> int | None:
|
||||
if self.client.user is None:
|
||||
@ -46,15 +46,15 @@ class Api:
|
||||
return self.client.user.id
|
||||
|
||||
def is_operator(self, user_id: int) -> bool:
|
||||
return '(operator)' in self.roles.get(f'roles{user_id}', '')
|
||||
return "(operator)" in self.roles.get(f"roles{user_id}", "")
|
||||
|
||||
async def api(self, request: dict, user_id: int) -> ResponseType:
|
||||
response = await UserApi(ApiSession(self), request, user_id).api()
|
||||
match response, request:
|
||||
case {'time': _}, _:
|
||||
case {"time": _}, _:
|
||||
pass
|
||||
case dict() as d, {'time': _}:
|
||||
response = d | {'time': time.time()}
|
||||
case dict() as d, {"time": _}:
|
||||
response = d | {"time": time.time()}
|
||||
return response
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ class ApiSession:
|
||||
|
||||
def api(self):
|
||||
if self.__complexity <= 0:
|
||||
raise Api.MisusedApi('hit complexity limit')
|
||||
raise Api.MisusedApi("hit complexity limit")
|
||||
self.__complexity -= 1
|
||||
return self.__api
|
||||
|
||||
@ -73,7 +73,7 @@ class ApiSession:
|
||||
class UserApi:
|
||||
class UnknownMember(Api.MisusedApi):
|
||||
def json(self) -> dict:
|
||||
return super().json() | {'unknownmember': None}
|
||||
return super().json() | {"unknownmember": None}
|
||||
|
||||
def __init__(self, session: ApiSession, request: dict, user_id: int) -> None:
|
||||
self.session = session
|
||||
@ -86,21 +86,19 @@ class UserApi:
|
||||
|
||||
async def subs(self, requests: list[dict] | dict[str, dict]) -> ResponseType:
|
||||
match self.request:
|
||||
case {'idkey': str() as idkey}:
|
||||
case {"idkey": str() as idkey}:
|
||||
pass
|
||||
case _:
|
||||
idkey = 'type'
|
||||
idkey = "type"
|
||||
match self.request:
|
||||
case {'idbase': dict() as base}:
|
||||
case {"idbase": dict() as base}:
|
||||
pass
|
||||
case _:
|
||||
base = {}
|
||||
match requests:
|
||||
case list():
|
||||
return list(
|
||||
await asyncio.gather(
|
||||
*(self.sub(request, key).api() for (key, request) in enumerate(requests))
|
||||
)
|
||||
await asyncio.gather(*(self.sub(request, key).api() for (key, request) in enumerate(requests)))
|
||||
)
|
||||
case dict():
|
||||
items = list(requests.items())
|
||||
@ -109,7 +107,7 @@ class UserApi:
|
||||
)
|
||||
return dict((key, response) for (key, _), response in zip(items, responses))
|
||||
case _:
|
||||
raise Api.MisusedApi('that should not happen')
|
||||
raise Api.MisusedApi("that should not happen")
|
||||
|
||||
def _sub(self, request: dict) -> Self:
|
||||
return UserApi(self.session, request, self.user_id)
|
||||
@ -120,43 +118,43 @@ class UserApi:
|
||||
sub._key = key
|
||||
return sub
|
||||
|
||||
async def to_guild_api(self, guild_id: int) -> 'GuildApi':
|
||||
async def to_guild_api(self, guild_id: int) -> "GuildApi":
|
||||
guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(guild_id)
|
||||
if guild is None:
|
||||
raise UserApi.UnknownMember('unknown guild')
|
||||
raise UserApi.UnknownMember("unknown guild")
|
||||
member = guild.get_member(self.user_id) or await guild.fetch_member(self.user_id)
|
||||
if member is None:
|
||||
raise UserApi.UnknownMember('unknown member of a guild')
|
||||
raise UserApi.UnknownMember("unknown member of a guild")
|
||||
return GuildApi(self, member)
|
||||
|
||||
async def to_operator_api(self) -> 'OperatorApi':
|
||||
async def to_operator_api(self) -> "OperatorApi":
|
||||
if not self.pi.is_operator(self.user_id):
|
||||
raise UserApi.UnknownMember('not an operator')
|
||||
raise UserApi.UnknownMember("not an operator")
|
||||
return OperatorApi(self)
|
||||
|
||||
def _api_text(self) -> str:
|
||||
return 'user api'
|
||||
return "user api"
|
||||
|
||||
async def _fall_through_api(self) -> ResponseType:
|
||||
match self.request:
|
||||
case {'type': '?'}:
|
||||
return f'this is {self._api_text()}'
|
||||
case {'type': '*', 'requests': list() | dict() as requests}:
|
||||
case {"type": "?"}:
|
||||
return f"this is {self._api_text()}"
|
||||
case {"type": "*", "requests": list() | dict() as requests}:
|
||||
return await self.subs(requests)
|
||||
case _:
|
||||
raise Api.UnknownApi(f'unknown {self._api_text()}')
|
||||
raise Api.UnknownApi(f"unknown {self._api_text()}")
|
||||
|
||||
async def _api(self) -> ResponseType:
|
||||
match self.request:
|
||||
case {'guild': str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100:
|
||||
self.request.pop('guild')
|
||||
case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100:
|
||||
self.request.pop("guild")
|
||||
return await (await self.to_guild_api(int(guild_id_str))).api()
|
||||
case {'operator': _}:
|
||||
self.request.pop('operator')
|
||||
case {"operator": _}:
|
||||
self.request.pop("operator")
|
||||
return await (await self.to_operator_api()).api()
|
||||
case {'type': 'ping', 't': (float() | int()) as t}:
|
||||
case {"type": "ping", "t": (float() | int()) as t}:
|
||||
return time.time() - t
|
||||
case {'type': 'guilds'}:
|
||||
case {"type": "guilds"}:
|
||||
guilds = []
|
||||
for guild in self.client.guilds:
|
||||
if guild.get_member(self.user_id) is not None:
|
||||
@ -172,10 +170,10 @@ class UserApi:
|
||||
except Explicit as e:
|
||||
raise Api.ExplicitFailure(e) from e
|
||||
except Api.MisusedApi as e:
|
||||
catches = self.request.get('catches', {})
|
||||
catches = self.request.get("catches", {})
|
||||
if len(e.args) and (key := e.args[0]) in catches:
|
||||
return catches[key]
|
||||
if '*' in catches:
|
||||
if "*" in catches:
|
||||
return e.json()
|
||||
raise
|
||||
|
||||
@ -183,50 +181,48 @@ class UserApi:
|
||||
class GuildApi(UserApi):
|
||||
class VoiceNotConnected(Api.MisusedApi):
|
||||
def json(self) -> dict:
|
||||
return super().json() | {'notconnected': None}
|
||||
return super().json() | {"notconnected": None}
|
||||
|
||||
def __init__(self, api: UserApi, member: discord.Member) -> None:
|
||||
super().__init__(api.session, api.request, member.id)
|
||||
self.member = member
|
||||
self.guild = member.guild
|
||||
|
||||
async def to_voice_api(self) -> 'VoiceApi':
|
||||
async def to_voice_api(self) -> "VoiceApi":
|
||||
voice = self.member.voice
|
||||
if voice is None:
|
||||
raise GuildApi.VoiceNotConnected('you are not connected to voice')
|
||||
raise GuildApi.VoiceNotConnected("you are not connected to voice")
|
||||
channel = voice.channel
|
||||
if channel is None:
|
||||
raise GuildApi.VoiceNotConnected('you are not connected to a voice channel')
|
||||
raise GuildApi.VoiceNotConnected("you are not connected to a voice channel")
|
||||
if self.client.user is None:
|
||||
raise GuildApi.VoiceNotConnected('bot client user not initialised')
|
||||
raise GuildApi.VoiceNotConnected("bot client user not initialised")
|
||||
if self.client.user.id not in channel.voice_states:
|
||||
raise GuildApi.VoiceNotConnected('bot not connected')
|
||||
raise GuildApi.VoiceNotConnected("bot not connected")
|
||||
return VoiceApi(self, channel)
|
||||
|
||||
def _sub(self, request: dict) -> Self:
|
||||
return GuildApi(super()._sub(request), self.member)
|
||||
|
||||
def _api_text(self) -> str:
|
||||
return 'guild api'
|
||||
return "guild api"
|
||||
|
||||
async def _api(self) -> ResponseType:
|
||||
match self.request:
|
||||
case {'voice': _}:
|
||||
self.request.pop('voice')
|
||||
case {"voice": _}:
|
||||
self.request.pop("voice")
|
||||
return await (await self.to_voice_api()).api()
|
||||
case _:
|
||||
return await self._fall_through_api()
|
||||
|
||||
|
||||
class VoiceApi(GuildApi):
|
||||
def __init__(
|
||||
self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel
|
||||
) -> None:
|
||||
def __init__(self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel) -> None:
|
||||
super().__init__(api, api.member)
|
||||
self.channel = channel
|
||||
self.mainservice = self.pi.mainservice
|
||||
|
||||
async def to_main_api(self) -> 'MainApi':
|
||||
async def to_main_api(self) -> "MainApi":
|
||||
vc = await self.mainservice.raw_vc_for_member(self.member)
|
||||
main = await self.mainservice.mode(create=False, force_play=False).main_for_raw_vc(vc)
|
||||
return MainApi(self, vc, main)
|
||||
@ -235,21 +231,19 @@ class VoiceApi(GuildApi):
|
||||
return VoiceApi(super()._sub(request), self.channel)
|
||||
|
||||
def _api_text(self) -> str:
|
||||
return 'voice api'
|
||||
return "voice api"
|
||||
|
||||
async def _api(self) -> ResponseType:
|
||||
match self.request:
|
||||
case {'main': _}:
|
||||
self.request.pop('main')
|
||||
case {"main": _}:
|
||||
self.request.pop("main")
|
||||
return await (await self.to_main_api()).api()
|
||||
case _:
|
||||
return await self._fall_through_api()
|
||||
|
||||
|
||||
class MainApi(VoiceApi):
|
||||
def __init__(
|
||||
self, api: VoiceApi, vc: discord.VoiceClient, main: MainAudio
|
||||
) -> None:
|
||||
def __init__(self, api: VoiceApi, vc: discord.VoiceClient, main: MainAudio) -> None:
|
||||
super().__init__(api, api.channel)
|
||||
self.vc = vc
|
||||
self.main = main
|
||||
@ -258,20 +252,20 @@ class MainApi(VoiceApi):
|
||||
return MainApi(super()._sub(request), self.vc, self.main)
|
||||
|
||||
def _api_text(self) -> str:
|
||||
return 'main api'
|
||||
return "main api"
|
||||
|
||||
async def _api(self) -> ResponseType:
|
||||
match self.request:
|
||||
case {'type': 'volume'}:
|
||||
case {"type": "volume"}:
|
||||
return self.main.volume
|
||||
case {'type': 'playing'}:
|
||||
case {"type": "playing"}:
|
||||
return self.vc.is_playing()
|
||||
case {'type': 'paused'}:
|
||||
case {"type": "paused"}:
|
||||
return self.vc.is_paused()
|
||||
case {'type': 'queueformat'}:
|
||||
case {"type": "queueformat"}:
|
||||
return await self.main.queue.format()
|
||||
case {'type': 'queuejson'}:
|
||||
return await self.main.queue.pubjson(self.member, self.request.get('limit', 1000))
|
||||
case {"type": "queuejson"}:
|
||||
return await self.main.queue.pubjson(self.member, self.request.get("limit", 1000))
|
||||
case _:
|
||||
return await self._fall_through_api()
|
||||
|
||||
@ -287,30 +281,30 @@ class OperatorApi(UserApi):
|
||||
return OperatorApi(super()._sub(request))
|
||||
|
||||
def _api_text(self) -> str:
|
||||
return 'operator api'
|
||||
return "operator api"
|
||||
|
||||
async def _api(self) -> ResponseType:
|
||||
match self.request:
|
||||
case {'target': str() as targetname}:
|
||||
case {"target": str() as targetname}:
|
||||
return await InstrumentationApi(self, targetname).api()
|
||||
case {'type': 'resetmonitoring'}:
|
||||
case {"type": "resetmonitoring"}:
|
||||
return self.pi.mainservice.pmonitoring.reset()
|
||||
case {'type': 'guilds'}:
|
||||
case {"type": "guilds"}:
|
||||
guilds = []
|
||||
for guild in self.client.guilds:
|
||||
if self._guild_visible(guild):
|
||||
guilds.append(
|
||||
{
|
||||
'id': str(guild.id),
|
||||
'member_count': guild.member_count,
|
||||
'name': guild.name,
|
||||
"id": str(guild.id),
|
||||
"member_count": guild.member_count,
|
||||
"name": guild.name,
|
||||
}
|
||||
)
|
||||
return guilds
|
||||
case {'type': 'sleep', 'duration': (float() | int()) as duration, 'echo': _ as echo}:
|
||||
case {"type": "sleep", "duration": (float() | int()) as duration, "echo": _ as echo}:
|
||||
await asyncio.sleep(duration)
|
||||
return echo
|
||||
case {'type': 'pool'}:
|
||||
case {"type": "pool"}:
|
||||
return self.pi.mainservice.pool_json()
|
||||
case _:
|
||||
return await self._fall_through_api()
|
||||
@ -319,7 +313,7 @@ class OperatorApi(UserApi):
|
||||
class InstrumentationApi(OperatorApi):
|
||||
class UnknownTarget(Api.UnknownApi):
|
||||
def json(self) -> dict:
|
||||
return super().json() | {'unknowntarget': None}
|
||||
return super().json() | {"unknowntarget": None}
|
||||
|
||||
def __init__(self, api: OperatorApi, targetname: str) -> None:
|
||||
super().__init__(api)
|
||||
@ -327,20 +321,18 @@ class InstrumentationApi(OperatorApi):
|
||||
self.targetname = targetname
|
||||
target_tuple = self.targets.targets.get(targetname, None)
|
||||
if target_tuple is None:
|
||||
raise InstrumentationApi.UnknownTarget('unknown target', targetname)
|
||||
raise InstrumentationApi.UnknownTarget("unknown target", targetname)
|
||||
self.target, self.methodname = target_tuple.value
|
||||
|
||||
def _sub(self, request: dict) -> Self:
|
||||
return InstrumentationApi(super()._sub(request), self.targetname)
|
||||
|
||||
def _api_text(self) -> str:
|
||||
return 'instrumentation api'
|
||||
return "instrumentation api"
|
||||
|
||||
async def _api(self) -> ResponseType:
|
||||
match self.request:
|
||||
case {
|
||||
'type': str() as instrumentationname
|
||||
} if (
|
||||
case {"type": str() as instrumentationname} if (
|
||||
instrumentation_factory := self.targets.instrumentations.get(instrumentationname)
|
||||
) is not None:
|
||||
try:
|
||||
@ -355,7 +347,7 @@ class InstrumentationApi(OperatorApi):
|
||||
)
|
||||
except KeyError as e:
|
||||
raise InstrumentationApi.UnknownTarget(
|
||||
'binding failed', self.targetname, instrumentationname, str(e)
|
||||
"binding failed", self.targetname, instrumentationname, str(e)
|
||||
) from e
|
||||
if not isinstance(instrumentation, JsonLike):
|
||||
raise TypeError
|
||||
|
284
v6d3music/app.py
284
v6d3music/app.py
@ -1,244 +1,34 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import urllib.parse
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
from ptvp35 import *
|
||||
from v6d0auth.appfactory import *
|
||||
from v6d0auth.run_app import *
|
||||
from v6d1tokens.client import *
|
||||
from v6d3music.api import *
|
||||
from v6d3music.config import auth_redirect, myroot
|
||||
from v6d3music.utils.bytes_hash import *
|
||||
from v6d0auth.appfactory import AppFactory
|
||||
from v6d0auth.run_app import start_app
|
||||
from v6d1tokens.client import request_token
|
||||
from v6d3music.api import Api
|
||||
from v6d3music.core.set_config import set_config
|
||||
|
||||
__all__ = ('AppContext',)
|
||||
|
||||
T = TypeVar('T')
|
||||
TKey = TypeVar('TKey', bound=Hashable)
|
||||
|
||||
|
||||
class CachedEntry(Generic[T]):
|
||||
def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
|
||||
self.__value: T = value
|
||||
self.__getter = getter
|
||||
self.__task: asyncio.Future[T] = asyncio.Future()
|
||||
self.__task.set_result(value)
|
||||
|
||||
async def _set(self) -> T:
|
||||
self.__value = await self.__getter()
|
||||
return self.__value
|
||||
|
||||
def get_nowait(self) -> T:
|
||||
if self.__task.done():
|
||||
self.__task = asyncio.create_task(self._set())
|
||||
return self.__value
|
||||
|
||||
async def get(self) -> T:
|
||||
if self.__task.done():
|
||||
self.__task = asyncio.create_task(self._set())
|
||||
return await self.__task
|
||||
|
||||
|
||||
class CachedDictionary(Generic[TKey, T]):
|
||||
def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
|
||||
self.__factory = factory
|
||||
self.__entries: dict[TKey, CachedEntry[T]] = {}
|
||||
|
||||
def entry(self, key: TKey, default: T) -> CachedEntry[T]:
|
||||
if key not in self.__entries:
|
||||
self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key))
|
||||
return self.__entries[key]
|
||||
__all__ = ("AppContext",)
|
||||
|
||||
|
||||
class MusicAppFactory(AppFactory):
|
||||
def __init__(
|
||||
self,
|
||||
secret: str,
|
||||
api: Api,
|
||||
db: DbConnection
|
||||
):
|
||||
def __init__(self, secret: str, api: Api):
|
||||
self.secret = secret
|
||||
self.redirect = auth_redirect
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self._api = api
|
||||
self.db = db
|
||||
self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary(
|
||||
self._token_client
|
||||
)
|
||||
|
||||
def auth_link(self) -> str:
|
||||
client_id = self._api.user_id()
|
||||
if client_id is None:
|
||||
return ''
|
||||
else:
|
||||
return f'https://discord.com/api/oauth2/authorize?client_id={client_id}' \
|
||||
f'&redirect_uri={urllib.parse.quote(self.redirect)}&response_type=code&scope=identify'
|
||||
|
||||
async def code_token(self, code: str) -> dict:
|
||||
client_id = self._api.user_id()
|
||||
assert client_id is not None
|
||||
data = {
|
||||
'client_id': str(client_id),
|
||||
'client_secret': self.secret,
|
||||
'grant_type': 'authorization_code',
|
||||
'code': code,
|
||||
'redirect_uri': self.redirect
|
||||
}
|
||||
headers = {
|
||||
'Content-Type': 'application/x-www-form-urlencoded'
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post('https://discord.com/api/oauth2/token', data=data, headers=headers) as response:
|
||||
return await response.json()
|
||||
|
||||
async def _token_client(self, access_token: str) -> dict | None:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}'
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get('https://discord.com/api/oauth2/@me', headers=headers) as response:
|
||||
return await response.json()
|
||||
|
||||
async def token_client(self, access_token: str) -> dict | None:
|
||||
return self._token_clients.entry(access_token, None).get_nowait()
|
||||
|
||||
async def session_client(self, data: dict) -> dict | None:
|
||||
match data:
|
||||
case {'token': {'access_token': str() as access_token}}:
|
||||
pass
|
||||
case _:
|
||||
return None
|
||||
return await self.token_client(access_token)
|
||||
|
||||
@classmethod
|
||||
def client_status(cls, sclient: dict) -> dict:
|
||||
user = cls.client_user(sclient)
|
||||
return {
|
||||
'expires': sclient.get('expires'),
|
||||
'user': (None if user is None else cls.user_status(user)),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def user_status(cls, user: dict) -> dict:
|
||||
return {
|
||||
'avatar': cls.user_avatar_url(user),
|
||||
'id': cls.user_id(user),
|
||||
'username': cls.user_username_full(user)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def user_username_full(cls, user: dict) -> str | None:
|
||||
match user:
|
||||
case {'username': str() as username, 'discriminator': str() as discriminator}:
|
||||
return f'{username}#{discriminator}'
|
||||
case _:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def client_user(cls, sclient: dict) -> dict | None:
|
||||
return sclient.get('user')
|
||||
|
||||
@classmethod
|
||||
def user_id(cls, user: dict) -> str | int | None:
|
||||
return user.get('id')
|
||||
|
||||
@classmethod
|
||||
def user_avatar(cls, user: dict) -> str | None:
|
||||
return user.get('avatar')
|
||||
|
||||
@classmethod
|
||||
def user_avatar_url(cls, user: dict) -> str | None:
|
||||
cid = cls.user_id(user)
|
||||
if cid is None:
|
||||
return None
|
||||
avatar = cls.user_avatar(user)
|
||||
if avatar is None:
|
||||
return None
|
||||
return f'https://cdn.discordapp.com/avatars/{cid}/{avatar}.png'
|
||||
|
||||
async def session_status(self, session: str) -> dict:
|
||||
data = self.session_data(session)
|
||||
sclient = await self.session_client(data)
|
||||
return {
|
||||
'code_set': data.get('code') is not None,
|
||||
'token_set': data.get('token') is not None,
|
||||
'client': (None if sclient is None else self.client_status(sclient))
|
||||
}
|
||||
|
||||
async def session_queue(self, session: str):
|
||||
data = self.session_data(session)
|
||||
sclient = await self.session_client(data)
|
||||
if sclient is None:
|
||||
return None
|
||||
user = self.client_user(sclient)
|
||||
if user is None:
|
||||
return None
|
||||
cid = self.user_id(user)
|
||||
return cid
|
||||
|
||||
def session_data(self, session: str | None) -> dict:
|
||||
if session is None:
|
||||
return {}
|
||||
data = self.db.get(session, {})
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
return data
|
||||
|
||||
def define_routes(self, routes: web.RouteTableDef) -> None:
|
||||
@routes.get('/authlink/')
|
||||
async def authlink(_request: web.Request) -> web.StreamResponse:
|
||||
return web.Response(text=self.auth_link())
|
||||
|
||||
@routes.get('/auth/')
|
||||
async def auth(request: web.Request) -> web.StreamResponse:
|
||||
session = request.query.get('session')
|
||||
state = request.query.get('state')
|
||||
code = request.query.get('code')
|
||||
match session, state, code:
|
||||
case str() as session, str() as state, str() as code:
|
||||
if bytes_hash(session.encode()) != state:
|
||||
raise web.HTTPBadRequest
|
||||
data = self.session_data(session)
|
||||
data['code'] = code
|
||||
data['token'] = await self.code_token(code)
|
||||
await self.db.set(session, data)
|
||||
return web.HTTPFound('/')
|
||||
case _:
|
||||
raise web.HTTPBadRequest
|
||||
|
||||
@routes.get('/state/')
|
||||
async def get_state(request: web.Request) -> web.Response:
|
||||
session = str(request.query.get('session'))
|
||||
return web.json_response(
|
||||
data=f'{bytes_hash(session.encode())}'
|
||||
)
|
||||
|
||||
@routes.get('/status/')
|
||||
async def status(request: web.Request) -> web.Response:
|
||||
session = str(request.query.get('session'))
|
||||
return web.json_response(
|
||||
data=await self.session_status(session)
|
||||
)
|
||||
|
||||
@routes.post('/api/')
|
||||
@routes.post("/api/")
|
||||
async def api(request: web.Request) -> web.Response:
|
||||
session = request.query.get('session')
|
||||
data = self.session_data(session)
|
||||
sclient = await self.session_client(data)
|
||||
if sclient is None:
|
||||
raise web.HTTPUnauthorized
|
||||
user = self.client_user(sclient)
|
||||
if user is None:
|
||||
raise web.HTTPUnauthorized
|
||||
user_id = self.user_id(user)
|
||||
async with aiohttp.ClientSession() as s, s.get(
|
||||
"http://sessionservice/userid/", params={"session": request.query.get("session")}
|
||||
) as response:
|
||||
user_id: int | None = await response.json()
|
||||
if user_id is None:
|
||||
raise web.HTTPUnauthorized
|
||||
user_id = int(user_id)
|
||||
if not isinstance(user_id, int):
|
||||
raise TypeError
|
||||
jsr = await request.json()
|
||||
assert isinstance(jsr, dict)
|
||||
try:
|
||||
@ -246,12 +36,6 @@ class MusicAppFactory(AppFactory):
|
||||
except Api.MisusedApi as e:
|
||||
return web.json_response(e.json(), status=404)
|
||||
|
||||
@routes.get('/whaturl/')
|
||||
async def whaturl(request: web.Request) -> web.StreamResponse:
|
||||
if request.headers.get('X-Forwarded-Proto') == 'https':
|
||||
request = request.clone(scheme='https')
|
||||
return web.json_response(str(request.url))
|
||||
|
||||
|
||||
class AppContext:
|
||||
def __init__(self, api: Api) -> None:
|
||||
@ -259,36 +43,26 @@ class AppContext:
|
||||
|
||||
async def start(self) -> tuple[web.Application, asyncio.Task[None]] | None:
|
||||
try:
|
||||
factory = MusicAppFactory(
|
||||
await request_token('music-client', 'token'),
|
||||
self.api,
|
||||
self.__db
|
||||
)
|
||||
factory = MusicAppFactory(await request_token("music-client", "token"), self.api)
|
||||
await set_config("secret", factory.secret)
|
||||
except aiohttp.ClientConnectorError:
|
||||
print('no web app (likely due to no token)')
|
||||
print("no web app (likely due to no token)")
|
||||
else:
|
||||
app = factory.app()
|
||||
task = asyncio.create_task(start_app(app))
|
||||
return app, task
|
||||
|
||||
async def __aenter__(self) -> 'AppContext':
|
||||
async with AsyncExitStack() as es:
|
||||
self.__db = await es.enter_async_context(DbFactory(myroot / 'session.db', kvfactory=KVJson()))
|
||||
self.__task: asyncio.Task[
|
||||
tuple[web.Application, asyncio.Task[None]] | None
|
||||
] = asyncio.create_task(self.start())
|
||||
self.__es = es.pop_all()
|
||||
return self
|
||||
async def __aenter__(self) -> "AppContext":
|
||||
self.__task: asyncio.Task[tuple[web.Application, asyncio.Task[None]] | None] = asyncio.create_task(self.start())
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
async with self.__es:
|
||||
if self.__task.done():
|
||||
result = await self.__task
|
||||
if result is not None:
|
||||
app, task = result
|
||||
await task
|
||||
await app.shutdown()
|
||||
await app.cleanup()
|
||||
else:
|
||||
self.__task.cancel()
|
||||
del self.__es
|
||||
if self.__task.done():
|
||||
result = await self.__task
|
||||
if result is not None:
|
||||
app, task = result
|
||||
await task
|
||||
await app.shutdown()
|
||||
await app.cleanup()
|
||||
else:
|
||||
self.__task.cancel()
|
||||
|
@ -2,7 +2,6 @@ import os
|
||||
|
||||
from v6d0auth.config import root
|
||||
|
||||
prefix = os.getenv('v6prefix', '?/')
|
||||
auth_redirect = os.getenv('v6redirect', 'https://music.parrrate.ru/auth/')
|
||||
myroot = root / 'v6d3music'
|
||||
prefix = os.getenv("v6prefix", "?/")
|
||||
myroot = root / "v6d3music"
|
||||
myroot.mkdir(exist_ok=True)
|
||||
|
24
v6d3music/core/set_config.py
Normal file
24
v6d3music/core/set_config.py
Normal file
@ -0,0 +1,24 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def repeat(repeated: Callable[[], Coroutine[Any, Any, T]]) -> T:
|
||||
for _ in range(60):
|
||||
try:
|
||||
return await repeated()
|
||||
except aiohttp.ClientConnectorError:
|
||||
await asyncio.sleep(1)
|
||||
raise RuntimeError("cannot reach sessionservice")
|
||||
|
||||
|
||||
async def set_config(key: str, value: Any) -> None:
|
||||
json = {"key": key, "value": value}
|
||||
async def call() -> None:
|
||||
async with aiohttp.ClientSession() as s, s.post("http://sessionservice/config/", json=json) as response:
|
||||
if response.status != 200:
|
||||
raise RuntimeError("config request failed")
|
||||
await repeat(call)
|
@ -6,13 +6,13 @@ import time
|
||||
from traceback import print_exc
|
||||
|
||||
import discord
|
||||
from v6d2ctx.integration.event import *
|
||||
from v6d2ctx.integration.targets import *
|
||||
|
||||
from ptvp35 import *
|
||||
from rainbowadn.instrument import Instrumentation
|
||||
from v6d1tokens.client import *
|
||||
from v6d2ctx.handle_content import *
|
||||
from v6d2ctx.integration.event import *
|
||||
from v6d2ctx.integration.targets import *
|
||||
from v6d2ctx.pain import *
|
||||
from v6d2ctx.serve import *
|
||||
from v6d3music.api import *
|
||||
@ -22,6 +22,7 @@ from v6d3music.config import prefix
|
||||
from v6d3music.core.caching import *
|
||||
from v6d3music.core.default_effects import *
|
||||
from v6d3music.core.mainservice import *
|
||||
from v6d3music.core.set_config import set_config
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
@ -47,7 +48,7 @@ _client = MusicClient(
|
||||
)
|
||||
|
||||
|
||||
banned_guilds = set(map(int, filter(bool, map(str.strip, os.getenv('banned_guilds', '').split(':')))))
|
||||
banned_guilds = set(map(int, filter(bool, map(str.strip, os.getenv("banned_guilds", "").split(":")))))
|
||||
|
||||
|
||||
def guild_allowed(guild: discord.Guild | None) -> bool:
|
||||
@ -71,31 +72,36 @@ def register_handlers(client: discord.Client, mainservice: MainService):
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
print('ready')
|
||||
print("ready")
|
||||
if client.user is None:
|
||||
raise RuntimeError
|
||||
await set_config("user-id", client.user.id)
|
||||
await set_config("ready", True)
|
||||
await client.change_presence(
|
||||
activity=discord.Game(
|
||||
name='феноменально',
|
||||
name="феноменально",
|
||||
)
|
||||
)
|
||||
await mainservice.restore()
|
||||
print("ready (startup finished)")
|
||||
|
||||
|
||||
class UpgradeABMInit(Instrumentation):
|
||||
def __init__(self):
|
||||
super().__init__(ABlockMonitor, '__init__')
|
||||
super().__init__(ABlockMonitor, "__init__")
|
||||
|
||||
def instrument(self, method, abm, *, threshold=0.0, delta=10.0, interval=0.0):
|
||||
print('created upgraded')
|
||||
print("created upgraded")
|
||||
method(abm, threshold=threshold, delta=delta, interval=interval)
|
||||
abm.threshold = threshold
|
||||
|
||||
|
||||
class UpgradeABMTask(Instrumentation):
|
||||
def __init__(self):
|
||||
super().__init__(ABlockMonitor, '_monitor')
|
||||
super().__init__(ABlockMonitor, "_monitor")
|
||||
|
||||
async def instrument(self, _, abm):
|
||||
print('started upgraded')
|
||||
print("started upgraded")
|
||||
while True:
|
||||
delta = abm.delta
|
||||
t = time.time()
|
||||
@ -104,8 +110,7 @@ class UpgradeABMTask(Instrumentation):
|
||||
delay = spent - delta
|
||||
if delay > abm.threshold:
|
||||
abm.threshold = delay
|
||||
print(
|
||||
f'upgraded block monitor reached new peak delay {delay:.4f}')
|
||||
print(f"upgraded block monitor reached new peak delay {delay:.4f}")
|
||||
interval = abm.interval
|
||||
if interval > 0:
|
||||
await asyncio.sleep(interval)
|
||||
@ -130,45 +135,47 @@ class PathPrint(Instrumentation):
|
||||
print(self.pref, db._DbConnection__path) # type: ignore
|
||||
except Exception:
|
||||
from traceback import print_exc
|
||||
|
||||
print_exc()
|
||||
return result
|
||||
|
||||
|
||||
def _db_ee() -> contextlib.ExitStack:
|
||||
with contextlib.ExitStack() as es:
|
||||
es.enter_context(PathPrint('_initialize', 'open :'))
|
||||
es.enter_context(PathPrint('aclose', 'close:'))
|
||||
es.enter_context(PathPrint("_initialize", "open :"))
|
||||
es.enter_context(PathPrint("aclose", "close:"))
|
||||
return es.pop_all()
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
async def amain(client: discord.Client):
|
||||
roles = {key: value for key, value in os.environ.items() if key.startswith('roles')}
|
||||
roles = {key: value for key, value in os.environ.items() if key.startswith("roles")}
|
||||
async with (
|
||||
client,
|
||||
DefaultEffects() as defaulteffects,
|
||||
MainService(Targets(), defaulteffects, client, Events()) as mainservice,
|
||||
AppContext(Api(mainservice, roles)),
|
||||
ABlockMonitor(delta=0.5)
|
||||
ABlockMonitor(delta=0.5),
|
||||
):
|
||||
register_handlers(client, mainservice)
|
||||
if 'guerilla' in sys.argv:
|
||||
if "guerilla" in sys.argv:
|
||||
from pathlib import Path
|
||||
tokenpath = Path('.token.txt')
|
||||
|
||||
tokenpath = Path(".token.txt")
|
||||
if tokenpath.exists():
|
||||
token = tokenpath.read_text()
|
||||
else:
|
||||
token = input('token:')
|
||||
token = input("token:")
|
||||
tokenpath.write_text(token)
|
||||
elif (token_ := os.getenv('trial_token')):
|
||||
elif token_ := os.getenv("trial_token"):
|
||||
token = token_
|
||||
else:
|
||||
token = await request_token('music', 'token')
|
||||
token = await request_token("music", "token")
|
||||
await client.login(token)
|
||||
if os.getenv('v6tor', None) is None:
|
||||
print('no tor')
|
||||
if os.getenv("v6tor", None) is None:
|
||||
print("no tor")
|
||||
await client.connect()
|
||||
print('exited')
|
||||
print("exited")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user