From 07bae21312c16283c701b27ad7c2c12772e9545e Mon Sep 17 00:00:00 2001 From: timofey Date: Tue, 28 Feb 2023 12:32:19 +0000 Subject: [PATCH] sessionservice --- setup.py | 2 +- v6d3music/api.py | 156 +++++++++---------- v6d3music/app.py | 284 ++++------------------------------- v6d3music/config.py | 5 +- v6d3music/core/set_config.py | 24 +++ v6d3music/main.py | 53 ++++--- 6 files changed, 160 insertions(+), 364 deletions(-) create mode 100644 v6d3music/core/set_config.py diff --git a/setup.py b/setup.py index c28d08b..6a6eff3 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( packages=['v6d3music'], url='', license='', - author='PARRRATE T&V', + author='PARRRATE TNV', author_email='', description='' ) diff --git a/v6d3music/api.py b/v6d3music/api.py index 613fdcd..8a3e089 100644 --- a/v6d3music/api.py +++ b/v6d3music/api.py @@ -3,41 +3,41 @@ import time import discord from typing_extensions import Self -from v6d2ctx.integration.responsetype import * -from v6d2ctx.integration.targets import * from rainbowadn.instrument import Instrumentation from v6d2ctx.context import * +from v6d2ctx.integration.responsetype import * +from v6d2ctx.integration.targets import * from v6d3music.core.mainaudio import * from v6d3music.core.mainservice import * -__all__ = ('Api',) +__all__ = ("Api",) class Api: class MisusedApi(Exception): def json(self) -> dict: - return {'error': list(map(str, self.args)), 'errormessage': str(self)} + return {"error": list(map(str, self.args)), "errormessage": str(self)} class UnknownApi(MisusedApi): def json(self) -> dict: - return super().json() | {'unknownapi': None} + return super().json() | {"unknownapi": None} class ExplicitFailure(MisusedApi): def __init__(self, explicit: Explicit) -> None: super().__init__(*explicit.args) def json(self) -> dict: - return super().json() | {'explicit': None} + return super().json() | {"explicit": None} def __init__(self, mainservice: MainService, roles: dict[str, str]) -> None: self.mainservice = mainservice self.client = mainservice.client self.roles = roles self.targets = mainservice.targets - self.targets.register_instance(self, 'api', Async) - self.targets.register_instrumentation('Count', lambda t, n: Count(t, n)) - self.targets.register_instrumentation('Concurrency', lambda t, n: Concurrency(t, n), Async) + self.targets.register_instance(self, "api", Async) + self.targets.register_instrumentation("Count", lambda t, n: Count(t, n)) + self.targets.register_instrumentation("Concurrency", lambda t, n: Concurrency(t, n), Async) def user_id(self) -> int | None: if self.client.user is None: @@ -46,15 +46,15 @@ class Api: return self.client.user.id def is_operator(self, user_id: int) -> bool: - return '(operator)' in self.roles.get(f'roles{user_id}', '') + return "(operator)" in self.roles.get(f"roles{user_id}", "") async def api(self, request: dict, user_id: int) -> ResponseType: response = await UserApi(ApiSession(self), request, user_id).api() match response, request: - case {'time': _}, _: + case {"time": _}, _: pass - case dict() as d, {'time': _}: - response = d | {'time': time.time()} + case dict() as d, {"time": _}: + response = d | {"time": time.time()} return response @@ -65,7 +65,7 @@ class ApiSession: def api(self): if self.__complexity <= 0: - raise Api.MisusedApi('hit complexity limit') + raise Api.MisusedApi("hit complexity limit") self.__complexity -= 1 return self.__api @@ -73,7 +73,7 @@ class ApiSession: class UserApi: class UnknownMember(Api.MisusedApi): def json(self) -> dict: - return super().json() | {'unknownmember': None} + return super().json() | {"unknownmember": None} def __init__(self, session: ApiSession, request: dict, user_id: int) -> None: self.session = session @@ -86,21 +86,19 @@ class UserApi: async def subs(self, requests: list[dict] | dict[str, dict]) -> ResponseType: match self.request: - case {'idkey': str() as idkey}: + case {"idkey": str() as idkey}: pass case _: - idkey = 'type' + idkey = "type" match self.request: - case {'idbase': dict() as base}: + case {"idbase": dict() as base}: pass case _: base = {} match requests: case list(): return list( - await asyncio.gather( - *(self.sub(request, key).api() for (key, request) in enumerate(requests)) - ) + await asyncio.gather(*(self.sub(request, key).api() for (key, request) in enumerate(requests))) ) case dict(): items = list(requests.items()) @@ -109,7 +107,7 @@ class UserApi: ) return dict((key, response) for (key, _), response in zip(items, responses)) case _: - raise Api.MisusedApi('that should not happen') + raise Api.MisusedApi("that should not happen") def _sub(self, request: dict) -> Self: return UserApi(self.session, request, self.user_id) @@ -120,43 +118,43 @@ class UserApi: sub._key = key return sub - async def to_guild_api(self, guild_id: int) -> 'GuildApi': + async def to_guild_api(self, guild_id: int) -> "GuildApi": guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(guild_id) if guild is None: - raise UserApi.UnknownMember('unknown guild') + raise UserApi.UnknownMember("unknown guild") member = guild.get_member(self.user_id) or await guild.fetch_member(self.user_id) if member is None: - raise UserApi.UnknownMember('unknown member of a guild') + raise UserApi.UnknownMember("unknown member of a guild") return GuildApi(self, member) - async def to_operator_api(self) -> 'OperatorApi': + async def to_operator_api(self) -> "OperatorApi": if not self.pi.is_operator(self.user_id): - raise UserApi.UnknownMember('not an operator') + raise UserApi.UnknownMember("not an operator") return OperatorApi(self) def _api_text(self) -> str: - return 'user api' + return "user api" async def _fall_through_api(self) -> ResponseType: match self.request: - case {'type': '?'}: - return f'this is {self._api_text()}' - case {'type': '*', 'requests': list() | dict() as requests}: + case {"type": "?"}: + return f"this is {self._api_text()}" + case {"type": "*", "requests": list() | dict() as requests}: return await self.subs(requests) case _: - raise Api.UnknownApi(f'unknown {self._api_text()}') + raise Api.UnknownApi(f"unknown {self._api_text()}") async def _api(self) -> ResponseType: match self.request: - case {'guild': str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100: - self.request.pop('guild') + case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100: + self.request.pop("guild") return await (await self.to_guild_api(int(guild_id_str))).api() - case {'operator': _}: - self.request.pop('operator') + case {"operator": _}: + self.request.pop("operator") return await (await self.to_operator_api()).api() - case {'type': 'ping', 't': (float() | int()) as t}: + case {"type": "ping", "t": (float() | int()) as t}: return time.time() - t - case {'type': 'guilds'}: + case {"type": "guilds"}: guilds = [] for guild in self.client.guilds: if guild.get_member(self.user_id) is not None: @@ -172,10 +170,10 @@ class UserApi: except Explicit as e: raise Api.ExplicitFailure(e) from e except Api.MisusedApi as e: - catches = self.request.get('catches', {}) + catches = self.request.get("catches", {}) if len(e.args) and (key := e.args[0]) in catches: return catches[key] - if '*' in catches: + if "*" in catches: return e.json() raise @@ -183,50 +181,48 @@ class UserApi: class GuildApi(UserApi): class VoiceNotConnected(Api.MisusedApi): def json(self) -> dict: - return super().json() | {'notconnected': None} + return super().json() | {"notconnected": None} def __init__(self, api: UserApi, member: discord.Member) -> None: super().__init__(api.session, api.request, member.id) self.member = member self.guild = member.guild - async def to_voice_api(self) -> 'VoiceApi': + async def to_voice_api(self) -> "VoiceApi": voice = self.member.voice if voice is None: - raise GuildApi.VoiceNotConnected('you are not connected to voice') + raise GuildApi.VoiceNotConnected("you are not connected to voice") channel = voice.channel if channel is None: - raise GuildApi.VoiceNotConnected('you are not connected to a voice channel') + raise GuildApi.VoiceNotConnected("you are not connected to a voice channel") if self.client.user is None: - raise GuildApi.VoiceNotConnected('bot client user not initialised') + raise GuildApi.VoiceNotConnected("bot client user not initialised") if self.client.user.id not in channel.voice_states: - raise GuildApi.VoiceNotConnected('bot not connected') + raise GuildApi.VoiceNotConnected("bot not connected") return VoiceApi(self, channel) def _sub(self, request: dict) -> Self: return GuildApi(super()._sub(request), self.member) def _api_text(self) -> str: - return 'guild api' + return "guild api" async def _api(self) -> ResponseType: match self.request: - case {'voice': _}: - self.request.pop('voice') + case {"voice": _}: + self.request.pop("voice") return await (await self.to_voice_api()).api() case _: return await self._fall_through_api() class VoiceApi(GuildApi): - def __init__( - self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel - ) -> None: + def __init__(self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel) -> None: super().__init__(api, api.member) self.channel = channel self.mainservice = self.pi.mainservice - async def to_main_api(self) -> 'MainApi': + async def to_main_api(self) -> "MainApi": vc = await self.mainservice.raw_vc_for_member(self.member) main = await self.mainservice.mode(create=False, force_play=False).main_for_raw_vc(vc) return MainApi(self, vc, main) @@ -235,21 +231,19 @@ class VoiceApi(GuildApi): return VoiceApi(super()._sub(request), self.channel) def _api_text(self) -> str: - return 'voice api' + return "voice api" async def _api(self) -> ResponseType: match self.request: - case {'main': _}: - self.request.pop('main') + case {"main": _}: + self.request.pop("main") return await (await self.to_main_api()).api() case _: return await self._fall_through_api() class MainApi(VoiceApi): - def __init__( - self, api: VoiceApi, vc: discord.VoiceClient, main: MainAudio - ) -> None: + def __init__(self, api: VoiceApi, vc: discord.VoiceClient, main: MainAudio) -> None: super().__init__(api, api.channel) self.vc = vc self.main = main @@ -258,20 +252,20 @@ class MainApi(VoiceApi): return MainApi(super()._sub(request), self.vc, self.main) def _api_text(self) -> str: - return 'main api' + return "main api" async def _api(self) -> ResponseType: match self.request: - case {'type': 'volume'}: + case {"type": "volume"}: return self.main.volume - case {'type': 'playing'}: + case {"type": "playing"}: return self.vc.is_playing() - case {'type': 'paused'}: + case {"type": "paused"}: return self.vc.is_paused() - case {'type': 'queueformat'}: + case {"type": "queueformat"}: return await self.main.queue.format() - case {'type': 'queuejson'}: - return await self.main.queue.pubjson(self.member, self.request.get('limit', 1000)) + case {"type": "queuejson"}: + return await self.main.queue.pubjson(self.member, self.request.get("limit", 1000)) case _: return await self._fall_through_api() @@ -287,30 +281,30 @@ class OperatorApi(UserApi): return OperatorApi(super()._sub(request)) def _api_text(self) -> str: - return 'operator api' + return "operator api" async def _api(self) -> ResponseType: match self.request: - case {'target': str() as targetname}: + case {"target": str() as targetname}: return await InstrumentationApi(self, targetname).api() - case {'type': 'resetmonitoring'}: + case {"type": "resetmonitoring"}: return self.pi.mainservice.pmonitoring.reset() - case {'type': 'guilds'}: + case {"type": "guilds"}: guilds = [] for guild in self.client.guilds: if self._guild_visible(guild): guilds.append( { - 'id': str(guild.id), - 'member_count': guild.member_count, - 'name': guild.name, + "id": str(guild.id), + "member_count": guild.member_count, + "name": guild.name, } ) return guilds - case {'type': 'sleep', 'duration': (float() | int()) as duration, 'echo': _ as echo}: + case {"type": "sleep", "duration": (float() | int()) as duration, "echo": _ as echo}: await asyncio.sleep(duration) return echo - case {'type': 'pool'}: + case {"type": "pool"}: return self.pi.mainservice.pool_json() case _: return await self._fall_through_api() @@ -319,7 +313,7 @@ class OperatorApi(UserApi): class InstrumentationApi(OperatorApi): class UnknownTarget(Api.UnknownApi): def json(self) -> dict: - return super().json() | {'unknowntarget': None} + return super().json() | {"unknowntarget": None} def __init__(self, api: OperatorApi, targetname: str) -> None: super().__init__(api) @@ -327,20 +321,18 @@ class InstrumentationApi(OperatorApi): self.targetname = targetname target_tuple = self.targets.targets.get(targetname, None) if target_tuple is None: - raise InstrumentationApi.UnknownTarget('unknown target', targetname) + raise InstrumentationApi.UnknownTarget("unknown target", targetname) self.target, self.methodname = target_tuple.value def _sub(self, request: dict) -> Self: return InstrumentationApi(super()._sub(request), self.targetname) def _api_text(self) -> str: - return 'instrumentation api' + return "instrumentation api" async def _api(self) -> ResponseType: match self.request: - case { - 'type': str() as instrumentationname - } if ( + case {"type": str() as instrumentationname} if ( instrumentation_factory := self.targets.instrumentations.get(instrumentationname) ) is not None: try: @@ -355,7 +347,7 @@ class InstrumentationApi(OperatorApi): ) except KeyError as e: raise InstrumentationApi.UnknownTarget( - 'binding failed', self.targetname, instrumentationname, str(e) + "binding failed", self.targetname, instrumentationname, str(e) ) from e if not isinstance(instrumentation, JsonLike): raise TypeError diff --git a/v6d3music/app.py b/v6d3music/app.py index edae2f6..8c5f566 100644 --- a/v6d3music/app.py +++ b/v6d3music/app.py @@ -1,244 +1,34 @@ import asyncio -import functools -import urllib.parse -from contextlib import AsyncExitStack -from pathlib import Path -from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar import aiohttp from aiohttp import web -from ptvp35 import * -from v6d0auth.appfactory import * -from v6d0auth.run_app import * -from v6d1tokens.client import * -from v6d3music.api import * -from v6d3music.config import auth_redirect, myroot -from v6d3music.utils.bytes_hash import * +from v6d0auth.appfactory import AppFactory +from v6d0auth.run_app import start_app +from v6d1tokens.client import request_token +from v6d3music.api import Api +from v6d3music.core.set_config import set_config -__all__ = ('AppContext',) - -T = TypeVar('T') -TKey = TypeVar('TKey', bound=Hashable) - - -class CachedEntry(Generic[T]): - def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None: - self.__value: T = value - self.__getter = getter - self.__task: asyncio.Future[T] = asyncio.Future() - self.__task.set_result(value) - - async def _set(self) -> T: - self.__value = await self.__getter() - return self.__value - - def get_nowait(self) -> T: - if self.__task.done(): - self.__task = asyncio.create_task(self._set()) - return self.__value - - async def get(self) -> T: - if self.__task.done(): - self.__task = asyncio.create_task(self._set()) - return await self.__task - - -class CachedDictionary(Generic[TKey, T]): - def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None: - self.__factory = factory - self.__entries: dict[TKey, CachedEntry[T]] = {} - - def entry(self, key: TKey, default: T) -> CachedEntry[T]: - if key not in self.__entries: - self.__entries[key] = CachedEntry(default, functools.partial(self.__factory, key)) - return self.__entries[key] +__all__ = ("AppContext",) class MusicAppFactory(AppFactory): - def __init__( - self, - secret: str, - api: Api, - db: DbConnection - ): + def __init__(self, secret: str, api: Api): self.secret = secret - self.redirect = auth_redirect self.loop = asyncio.get_running_loop() self._api = api - self.db = db - self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary( - self._token_client - ) - - def auth_link(self) -> str: - client_id = self._api.user_id() - if client_id is None: - return '' - else: - return f'https://discord.com/api/oauth2/authorize?client_id={client_id}' \ - f'&redirect_uri={urllib.parse.quote(self.redirect)}&response_type=code&scope=identify' - - async def code_token(self, code: str) -> dict: - client_id = self._api.user_id() - assert client_id is not None - data = { - 'client_id': str(client_id), - 'client_secret': self.secret, - 'grant_type': 'authorization_code', - 'code': code, - 'redirect_uri': self.redirect - } - headers = { - 'Content-Type': 'application/x-www-form-urlencoded' - } - async with aiohttp.ClientSession() as session: - async with session.post('https://discord.com/api/oauth2/token', data=data, headers=headers) as response: - return await response.json() - - async def _token_client(self, access_token: str) -> dict | None: - headers = { - 'Authorization': f'Bearer {access_token}' - } - async with aiohttp.ClientSession() as session: - async with session.get('https://discord.com/api/oauth2/@me', headers=headers) as response: - return await response.json() - - async def token_client(self, access_token: str) -> dict | None: - return self._token_clients.entry(access_token, None).get_nowait() - - async def session_client(self, data: dict) -> dict | None: - match data: - case {'token': {'access_token': str() as access_token}}: - pass - case _: - return None - return await self.token_client(access_token) - - @classmethod - def client_status(cls, sclient: dict) -> dict: - user = cls.client_user(sclient) - return { - 'expires': sclient.get('expires'), - 'user': (None if user is None else cls.user_status(user)), - } - - @classmethod - def user_status(cls, user: dict) -> dict: - return { - 'avatar': cls.user_avatar_url(user), - 'id': cls.user_id(user), - 'username': cls.user_username_full(user) - } - - @classmethod - def user_username_full(cls, user: dict) -> str | None: - match user: - case {'username': str() as username, 'discriminator': str() as discriminator}: - return f'{username}#{discriminator}' - case _: - return None - - @classmethod - def client_user(cls, sclient: dict) -> dict | None: - return sclient.get('user') - - @classmethod - def user_id(cls, user: dict) -> str | int | None: - return user.get('id') - - @classmethod - def user_avatar(cls, user: dict) -> str | None: - return user.get('avatar') - - @classmethod - def user_avatar_url(cls, user: dict) -> str | None: - cid = cls.user_id(user) - if cid is None: - return None - avatar = cls.user_avatar(user) - if avatar is None: - return None - return f'https://cdn.discordapp.com/avatars/{cid}/{avatar}.png' - - async def session_status(self, session: str) -> dict: - data = self.session_data(session) - sclient = await self.session_client(data) - return { - 'code_set': data.get('code') is not None, - 'token_set': data.get('token') is not None, - 'client': (None if sclient is None else self.client_status(sclient)) - } - - async def session_queue(self, session: str): - data = self.session_data(session) - sclient = await self.session_client(data) - if sclient is None: - return None - user = self.client_user(sclient) - if user is None: - return None - cid = self.user_id(user) - return cid - - def session_data(self, session: str | None) -> dict: - if session is None: - return {} - data = self.db.get(session, {}) - if not isinstance(data, dict): - return {} - return data def define_routes(self, routes: web.RouteTableDef) -> None: - @routes.get('/authlink/') - async def authlink(_request: web.Request) -> web.StreamResponse: - return web.Response(text=self.auth_link()) - - @routes.get('/auth/') - async def auth(request: web.Request) -> web.StreamResponse: - session = request.query.get('session') - state = request.query.get('state') - code = request.query.get('code') - match session, state, code: - case str() as session, str() as state, str() as code: - if bytes_hash(session.encode()) != state: - raise web.HTTPBadRequest - data = self.session_data(session) - data['code'] = code - data['token'] = await self.code_token(code) - await self.db.set(session, data) - return web.HTTPFound('/') - case _: - raise web.HTTPBadRequest - - @routes.get('/state/') - async def get_state(request: web.Request) -> web.Response: - session = str(request.query.get('session')) - return web.json_response( - data=f'{bytes_hash(session.encode())}' - ) - - @routes.get('/status/') - async def status(request: web.Request) -> web.Response: - session = str(request.query.get('session')) - return web.json_response( - data=await self.session_status(session) - ) - - @routes.post('/api/') + @routes.post("/api/") async def api(request: web.Request) -> web.Response: - session = request.query.get('session') - data = self.session_data(session) - sclient = await self.session_client(data) - if sclient is None: - raise web.HTTPUnauthorized - user = self.client_user(sclient) - if user is None: - raise web.HTTPUnauthorized - user_id = self.user_id(user) + async with aiohttp.ClientSession() as s, s.get( + "http://sessionservice/userid/", params={"session": request.query.get("session")} + ) as response: + user_id: int | None = await response.json() if user_id is None: raise web.HTTPUnauthorized - user_id = int(user_id) + if not isinstance(user_id, int): + raise TypeError jsr = await request.json() assert isinstance(jsr, dict) try: @@ -246,12 +36,6 @@ class MusicAppFactory(AppFactory): except Api.MisusedApi as e: return web.json_response(e.json(), status=404) - @routes.get('/whaturl/') - async def whaturl(request: web.Request) -> web.StreamResponse: - if request.headers.get('X-Forwarded-Proto') == 'https': - request = request.clone(scheme='https') - return web.json_response(str(request.url)) - class AppContext: def __init__(self, api: Api) -> None: @@ -259,36 +43,26 @@ class AppContext: async def start(self) -> tuple[web.Application, asyncio.Task[None]] | None: try: - factory = MusicAppFactory( - await request_token('music-client', 'token'), - self.api, - self.__db - ) + factory = MusicAppFactory(await request_token("music-client", "token"), self.api) + await set_config("secret", factory.secret) except aiohttp.ClientConnectorError: - print('no web app (likely due to no token)') + print("no web app (likely due to no token)") else: app = factory.app() task = asyncio.create_task(start_app(app)) return app, task - async def __aenter__(self) -> 'AppContext': - async with AsyncExitStack() as es: - self.__db = await es.enter_async_context(DbFactory(myroot / 'session.db', kvfactory=KVJson())) - self.__task: asyncio.Task[ - tuple[web.Application, asyncio.Task[None]] | None - ] = asyncio.create_task(self.start()) - self.__es = es.pop_all() - return self + async def __aenter__(self) -> "AppContext": + self.__task: asyncio.Task[tuple[web.Application, asyncio.Task[None]] | None] = asyncio.create_task(self.start()) + return self async def __aexit__(self, exc_type, exc_val, exc_tb): - async with self.__es: - if self.__task.done(): - result = await self.__task - if result is not None: - app, task = result - await task - await app.shutdown() - await app.cleanup() - else: - self.__task.cancel() - del self.__es + if self.__task.done(): + result = await self.__task + if result is not None: + app, task = result + await task + await app.shutdown() + await app.cleanup() + else: + self.__task.cancel() diff --git a/v6d3music/config.py b/v6d3music/config.py index 9c429f8..81046d9 100644 --- a/v6d3music/config.py +++ b/v6d3music/config.py @@ -2,7 +2,6 @@ import os from v6d0auth.config import root -prefix = os.getenv('v6prefix', '?/') -auth_redirect = os.getenv('v6redirect', 'https://music.parrrate.ru/auth/') -myroot = root / 'v6d3music' +prefix = os.getenv("v6prefix", "?/") +myroot = root / "v6d3music" myroot.mkdir(exist_ok=True) diff --git a/v6d3music/core/set_config.py b/v6d3music/core/set_config.py new file mode 100644 index 0000000..d4db44a --- /dev/null +++ b/v6d3music/core/set_config.py @@ -0,0 +1,24 @@ +import asyncio +from typing import Any, Callable, Coroutine, TypeVar + +import aiohttp + +T = TypeVar("T") + + +async def repeat(repeated: Callable[[], Coroutine[Any, Any, T]]) -> T: + for _ in range(60): + try: + return await repeated() + except aiohttp.ClientConnectorError: + await asyncio.sleep(1) + raise RuntimeError("cannot reach sessionservice") + + +async def set_config(key: str, value: Any) -> None: + json = {"key": key, "value": value} + async def call() -> None: + async with aiohttp.ClientSession() as s, s.post("http://sessionservice/config/", json=json) as response: + if response.status != 200: + raise RuntimeError("config request failed") + await repeat(call) diff --git a/v6d3music/main.py b/v6d3music/main.py index 329e776..5a91ce4 100644 --- a/v6d3music/main.py +++ b/v6d3music/main.py @@ -6,13 +6,13 @@ import time from traceback import print_exc import discord -from v6d2ctx.integration.event import * -from v6d2ctx.integration.targets import * from ptvp35 import * from rainbowadn.instrument import Instrumentation from v6d1tokens.client import * from v6d2ctx.handle_content import * +from v6d2ctx.integration.event import * +from v6d2ctx.integration.targets import * from v6d2ctx.pain import * from v6d2ctx.serve import * from v6d3music.api import * @@ -22,6 +22,7 @@ from v6d3music.config import prefix from v6d3music.core.caching import * from v6d3music.core.default_effects import * from v6d3music.core.mainservice import * +from v6d3music.core.set_config import set_config loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -47,7 +48,7 @@ _client = MusicClient( ) -banned_guilds = set(map(int, filter(bool, map(str.strip, os.getenv('banned_guilds', '').split(':'))))) +banned_guilds = set(map(int, filter(bool, map(str.strip, os.getenv("banned_guilds", "").split(":"))))) def guild_allowed(guild: discord.Guild | None) -> bool: @@ -71,31 +72,36 @@ def register_handlers(client: discord.Client, mainservice: MainService): @client.event async def on_ready(): - print('ready') + print("ready") + if client.user is None: + raise RuntimeError + await set_config("user-id", client.user.id) + await set_config("ready", True) await client.change_presence( activity=discord.Game( - name='феноменально', + name="феноменально", ) ) await mainservice.restore() + print("ready (startup finished)") class UpgradeABMInit(Instrumentation): def __init__(self): - super().__init__(ABlockMonitor, '__init__') + super().__init__(ABlockMonitor, "__init__") def instrument(self, method, abm, *, threshold=0.0, delta=10.0, interval=0.0): - print('created upgraded') + print("created upgraded") method(abm, threshold=threshold, delta=delta, interval=interval) abm.threshold = threshold class UpgradeABMTask(Instrumentation): def __init__(self): - super().__init__(ABlockMonitor, '_monitor') + super().__init__(ABlockMonitor, "_monitor") async def instrument(self, _, abm): - print('started upgraded') + print("started upgraded") while True: delta = abm.delta t = time.time() @@ -104,8 +110,7 @@ class UpgradeABMTask(Instrumentation): delay = spent - delta if delay > abm.threshold: abm.threshold = delay - print( - f'upgraded block monitor reached new peak delay {delay:.4f}') + print(f"upgraded block monitor reached new peak delay {delay:.4f}") interval = abm.interval if interval > 0: await asyncio.sleep(interval) @@ -130,45 +135,47 @@ class PathPrint(Instrumentation): print(self.pref, db._DbConnection__path) # type: ignore except Exception: from traceback import print_exc + print_exc() return result def _db_ee() -> contextlib.ExitStack: with contextlib.ExitStack() as es: - es.enter_context(PathPrint('_initialize', 'open :')) - es.enter_context(PathPrint('aclose', 'close:')) + es.enter_context(PathPrint("_initialize", "open :")) + es.enter_context(PathPrint("aclose", "close:")) return es.pop_all() raise RuntimeError async def amain(client: discord.Client): - roles = {key: value for key, value in os.environ.items() if key.startswith('roles')} + roles = {key: value for key, value in os.environ.items() if key.startswith("roles")} async with ( client, DefaultEffects() as defaulteffects, MainService(Targets(), defaulteffects, client, Events()) as mainservice, AppContext(Api(mainservice, roles)), - ABlockMonitor(delta=0.5) + ABlockMonitor(delta=0.5), ): register_handlers(client, mainservice) - if 'guerilla' in sys.argv: + if "guerilla" in sys.argv: from pathlib import Path - tokenpath = Path('.token.txt') + + tokenpath = Path(".token.txt") if tokenpath.exists(): token = tokenpath.read_text() else: - token = input('token:') + token = input("token:") tokenpath.write_text(token) - elif (token_ := os.getenv('trial_token')): + elif token_ := os.getenv("trial_token"): token = token_ else: - token = await request_token('music', 'token') + token = await request_token("music", "token") await client.login(token) - if os.getenv('v6tor', None) is None: - print('no tor') + if os.getenv("v6tor", None) is None: + print("no tor") await client.connect() - print('exited') + print("exited") def main() -> None: