diff --git a/v6d3music/api.py b/v6d3music/api.py index 1264efe..4c61d6a 100644 --- a/v6d3music/api.py +++ b/v6d3music/api.py @@ -39,7 +39,9 @@ class Api: self.targets = mainservice.targets self.targets.register_instance(self, "api", Async) self.targets.register_instrumentation("Count", lambda t, n: Count(t, n)) - self.targets.register_instrumentation("Concurrency", lambda t, n: Concurrency(t, n), Async) + self.targets.register_instrumentation( + "Concurrency", lambda t, n: Concurrency(t, n), Async + ) def user_id(self) -> int | None: if self.client.user is None: @@ -100,31 +102,45 @@ class UserApi: match requests: case list(): return list( - await asyncio.gather(*(self.sub(request, key).api() for (key, request) in enumerate(requests))) + await asyncio.gather( + *( + self.sub(request, key).api() + for (key, request) in enumerate(requests) + ) + ) ) case dict(): items = list(requests.items()) responses = await asyncio.gather( - *(self.sub({idkey: key} | base | request, key).api() for key, request in items) + *( + self.sub({idkey: key} | base | request, key).api() + for key, request in items + ) + ) + return dict( + (key, response) for (key, _), response in zip(items, responses) ) - return dict((key, response) for (key, _), response in zip(items, responses)) case _: raise Api.MisusedApi("that should not happen") - def _sub(self, request: dict) -> Self: + def _sub(self, request: dict) -> UserApi: return UserApi(self.session, request, self.user_id) - def sub(self, request: dict, key: str | int) -> Self: + def sub(self, request: dict, key: str | int) -> UserApi: sub = self._sub(request) sub._parent = self sub._key = key return sub async def to_guild_api(self, guild_id: int) -> GuildApi: - guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(guild_id) + guild = self.client.get_guild(guild_id) or await self.client.fetch_guild( + guild_id + ) if guild is None: raise UserApi.UnknownMember("unknown guild") - member = guild.get_member(self.user_id) or await guild.fetch_member(self.user_id) + member = guild.get_member(self.user_id) or await guild.fetch_member( + self.user_id + ) if member is None: raise UserApi.UnknownMember("unknown member of a guild") return GuildApi(self, member) @@ -148,7 +164,9 @@ class UserApi: async def _api(self) -> ResponseType: match self.request: - case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100: + case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len( + guild_id_str + ) < 100: self.request.pop("guild") return await (await self.to_guild_api(int(guild_id_str))).api() case {"operator": _}: @@ -203,7 +221,7 @@ class GuildApi(UserApi): raise GuildApi.VoiceNotConnected("bot not connected") return VoiceApi(self, channel) - def _sub(self, request: dict) -> Self: + def _sub(self, request: dict) -> GuildApi: return GuildApi(super()._sub(request), self.member) def _api_text(self) -> str: @@ -219,17 +237,21 @@ class GuildApi(UserApi): class VoiceApi(GuildApi): - def __init__(self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel) -> None: + def __init__( + self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel + ) -> None: super().__init__(api, api.member) self.channel = channel self.mainservice = self.pi.mainservice async def to_main_api(self) -> "MainApi": vc = await self.mainservice.raw_vc_for_member(self.member) - main = await self.mainservice.mode(create=False, force_play=False).main_for_raw_vc(vc) + main = await self.mainservice.mode( + create=False, force_play=False + ).main_for_raw_vc(vc) return MainApi(self, vc, main) - def _sub(self, request: dict) -> Self: + def _sub(self, request: dict) -> VoiceApi: return VoiceApi(super()._sub(request), self.channel) def _api_text(self) -> str: @@ -250,7 +272,7 @@ class MainApi(VoiceApi): self.vc = vc self.main = main - def _sub(self, request: dict) -> Self: + def _sub(self, request: dict) -> MainApi: return MainApi(super()._sub(request), self.vc, self.main) def _api_text(self) -> str: @@ -267,7 +289,9 @@ class MainApi(VoiceApi): case {"type": "queueformat"}: return await self.main.queue.format() case {"type": "queuejson"}: - return await self.main.queue.pubjson(self.member, self.request.get("limit", 1000)) + return await self.main.queue.pubjson( + self.member, self.request.get("limit", 1000) + ) case _: return await self._fall_through_api() @@ -279,12 +303,18 @@ class OperatorApi(UserApi): def _guild_visible(self, guild: discord.Guild) -> bool: return True - def _sub(self, request: dict) -> Self: + def _sub(self, request: dict) -> OperatorApi: return OperatorApi(super()._sub(request)) def _api_text(self) -> str: return "operator api" + def _forceclear(self, guildid: int) -> None: + guild = self.client.get_guild(guildid) + if guild is None: + raise KeyError + self.pi.mainservice.mains[guild].queue.forceclear() + async def _api(self) -> ResponseType: match self.request: case {"target": str() as targetname}: @@ -303,11 +333,17 @@ class OperatorApi(UserApi): } ) return guilds - case {"type": "sleep", "duration": (float() | int()) as duration, "echo": _ as echo}: + case { + "type": "sleep", + "duration": (float() | int()) as duration, + "echo": _ as echo, + }: await asyncio.sleep(duration) return echo case {"type": "pool"}: return self.pi.mainservice.pool_json() + case {"type": "forceclear", "guildid": str() as guildid}: + return self._forceclear(int(guildid)) case _: return await self._fall_through_api() @@ -326,7 +362,7 @@ class InstrumentationApi(OperatorApi): raise InstrumentationApi.UnknownTarget("unknown target", targetname) self.target, self.methodname = target_tuple.value - def _sub(self, request: dict) -> Self: + def _sub(self, request: dict) -> InstrumentationApi: return InstrumentationApi(super()._sub(request), self.targetname) def _api_text(self) -> str: @@ -335,16 +371,20 @@ class InstrumentationApi(OperatorApi): async def _api(self) -> ResponseType: match self.request: case {"type": str() as instrumentationname} if ( - instrumentation_factory := self.targets.instrumentations.get(instrumentationname) + instrumentation_factory := self.targets.instrumentations.get( + instrumentationname + ) ) is not None: try: - instrumentation: Instrumentation = await self.pi.mainservice.pmonitoring.get( - self.targets.get_factory( - self.targetname, - self.target, - self.methodname, - instrumentationname, - instrumentation_factory.value, + instrumentation: Instrumentation = ( + await self.pi.mainservice.pmonitoring.get( + self.targets.get_factory( + self.targetname, + self.target, + self.methodname, + instrumentationname, + instrumentation_factory.value, + ) ) ) except KeyError as e: