forceclear

This commit is contained in:
AF 2023-12-27 04:33:55 +00:00
parent 7923f24dec
commit 79695fe142

View File

@ -39,7 +39,9 @@ class Api:
self.targets = mainservice.targets self.targets = mainservice.targets
self.targets.register_instance(self, "api", Async) self.targets.register_instance(self, "api", Async)
self.targets.register_instrumentation("Count", lambda t, n: Count(t, n)) self.targets.register_instrumentation("Count", lambda t, n: Count(t, n))
self.targets.register_instrumentation("Concurrency", lambda t, n: Concurrency(t, n), Async) self.targets.register_instrumentation(
"Concurrency", lambda t, n: Concurrency(t, n), Async
)
def user_id(self) -> int | None: def user_id(self) -> int | None:
if self.client.user is None: if self.client.user is None:
@ -100,31 +102,45 @@ class UserApi:
match requests: match requests:
case list(): case list():
return list( return list(
await asyncio.gather(*(self.sub(request, key).api() for (key, request) in enumerate(requests))) await asyncio.gather(
*(
self.sub(request, key).api()
for (key, request) in enumerate(requests)
)
)
) )
case dict(): case dict():
items = list(requests.items()) items = list(requests.items())
responses = await asyncio.gather( responses = await asyncio.gather(
*(self.sub({idkey: key} | base | request, key).api() for key, request in items) *(
self.sub({idkey: key} | base | request, key).api()
for key, request in items
)
)
return dict(
(key, response) for (key, _), response in zip(items, responses)
) )
return dict((key, response) for (key, _), response in zip(items, responses))
case _: case _:
raise Api.MisusedApi("that should not happen") raise Api.MisusedApi("that should not happen")
def _sub(self, request: dict) -> Self: def _sub(self, request: dict) -> UserApi:
return UserApi(self.session, request, self.user_id) return UserApi(self.session, request, self.user_id)
def sub(self, request: dict, key: str | int) -> Self: def sub(self, request: dict, key: str | int) -> UserApi:
sub = self._sub(request) sub = self._sub(request)
sub._parent = self sub._parent = self
sub._key = key sub._key = key
return sub return sub
async def to_guild_api(self, guild_id: int) -> GuildApi: async def to_guild_api(self, guild_id: int) -> GuildApi:
guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(guild_id) guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(
guild_id
)
if guild is None: if guild is None:
raise UserApi.UnknownMember("unknown guild") raise UserApi.UnknownMember("unknown guild")
member = guild.get_member(self.user_id) or await guild.fetch_member(self.user_id) member = guild.get_member(self.user_id) or await guild.fetch_member(
self.user_id
)
if member is None: if member is None:
raise UserApi.UnknownMember("unknown member of a guild") raise UserApi.UnknownMember("unknown member of a guild")
return GuildApi(self, member) return GuildApi(self, member)
@ -148,7 +164,9 @@ class UserApi:
async def _api(self) -> ResponseType: async def _api(self) -> ResponseType:
match self.request: match self.request:
case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100: case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len(
guild_id_str
) < 100:
self.request.pop("guild") self.request.pop("guild")
return await (await self.to_guild_api(int(guild_id_str))).api() return await (await self.to_guild_api(int(guild_id_str))).api()
case {"operator": _}: case {"operator": _}:
@ -203,7 +221,7 @@ class GuildApi(UserApi):
raise GuildApi.VoiceNotConnected("bot not connected") raise GuildApi.VoiceNotConnected("bot not connected")
return VoiceApi(self, channel) return VoiceApi(self, channel)
def _sub(self, request: dict) -> Self: def _sub(self, request: dict) -> GuildApi:
return GuildApi(super()._sub(request), self.member) return GuildApi(super()._sub(request), self.member)
def _api_text(self) -> str: def _api_text(self) -> str:
@ -219,17 +237,21 @@ class GuildApi(UserApi):
class VoiceApi(GuildApi): class VoiceApi(GuildApi):
def __init__(self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel) -> None: def __init__(
self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel
) -> None:
super().__init__(api, api.member) super().__init__(api, api.member)
self.channel = channel self.channel = channel
self.mainservice = self.pi.mainservice self.mainservice = self.pi.mainservice
async def to_main_api(self) -> "MainApi": async def to_main_api(self) -> "MainApi":
vc = await self.mainservice.raw_vc_for_member(self.member) vc = await self.mainservice.raw_vc_for_member(self.member)
main = await self.mainservice.mode(create=False, force_play=False).main_for_raw_vc(vc) main = await self.mainservice.mode(
create=False, force_play=False
).main_for_raw_vc(vc)
return MainApi(self, vc, main) return MainApi(self, vc, main)
def _sub(self, request: dict) -> Self: def _sub(self, request: dict) -> VoiceApi:
return VoiceApi(super()._sub(request), self.channel) return VoiceApi(super()._sub(request), self.channel)
def _api_text(self) -> str: def _api_text(self) -> str:
@ -250,7 +272,7 @@ class MainApi(VoiceApi):
self.vc = vc self.vc = vc
self.main = main self.main = main
def _sub(self, request: dict) -> Self: def _sub(self, request: dict) -> MainApi:
return MainApi(super()._sub(request), self.vc, self.main) return MainApi(super()._sub(request), self.vc, self.main)
def _api_text(self) -> str: def _api_text(self) -> str:
@ -267,7 +289,9 @@ class MainApi(VoiceApi):
case {"type": "queueformat"}: case {"type": "queueformat"}:
return await self.main.queue.format() return await self.main.queue.format()
case {"type": "queuejson"}: case {"type": "queuejson"}:
return await self.main.queue.pubjson(self.member, self.request.get("limit", 1000)) return await self.main.queue.pubjson(
self.member, self.request.get("limit", 1000)
)
case _: case _:
return await self._fall_through_api() return await self._fall_through_api()
@ -279,12 +303,18 @@ class OperatorApi(UserApi):
def _guild_visible(self, guild: discord.Guild) -> bool: def _guild_visible(self, guild: discord.Guild) -> bool:
return True return True
def _sub(self, request: dict) -> Self: def _sub(self, request: dict) -> OperatorApi:
return OperatorApi(super()._sub(request)) return OperatorApi(super()._sub(request))
def _api_text(self) -> str: def _api_text(self) -> str:
return "operator api" return "operator api"
def _forceclear(self, guildid: int) -> None:
guild = self.client.get_guild(guildid)
if guild is None:
raise KeyError
self.pi.mainservice.mains[guild].queue.forceclear()
async def _api(self) -> ResponseType: async def _api(self) -> ResponseType:
match self.request: match self.request:
case {"target": str() as targetname}: case {"target": str() as targetname}:
@ -303,11 +333,17 @@ class OperatorApi(UserApi):
} }
) )
return guilds return guilds
case {"type": "sleep", "duration": (float() | int()) as duration, "echo": _ as echo}: case {
"type": "sleep",
"duration": (float() | int()) as duration,
"echo": _ as echo,
}:
await asyncio.sleep(duration) await asyncio.sleep(duration)
return echo return echo
case {"type": "pool"}: case {"type": "pool"}:
return self.pi.mainservice.pool_json() return self.pi.mainservice.pool_json()
case {"type": "forceclear", "guildid": str() as guildid}:
return self._forceclear(int(guildid))
case _: case _:
return await self._fall_through_api() return await self._fall_through_api()
@ -326,7 +362,7 @@ class InstrumentationApi(OperatorApi):
raise InstrumentationApi.UnknownTarget("unknown target", targetname) raise InstrumentationApi.UnknownTarget("unknown target", targetname)
self.target, self.methodname = target_tuple.value self.target, self.methodname = target_tuple.value
def _sub(self, request: dict) -> Self: def _sub(self, request: dict) -> InstrumentationApi:
return InstrumentationApi(super()._sub(request), self.targetname) return InstrumentationApi(super()._sub(request), self.targetname)
def _api_text(self) -> str: def _api_text(self) -> str:
@ -335,10 +371,13 @@ class InstrumentationApi(OperatorApi):
async def _api(self) -> ResponseType: async def _api(self) -> ResponseType:
match self.request: match self.request:
case {"type": str() as instrumentationname} if ( case {"type": str() as instrumentationname} if (
instrumentation_factory := self.targets.instrumentations.get(instrumentationname) instrumentation_factory := self.targets.instrumentations.get(
instrumentationname
)
) is not None: ) is not None:
try: try:
instrumentation: Instrumentation = await self.pi.mainservice.pmonitoring.get( instrumentation: Instrumentation = (
await self.pi.mainservice.pmonitoring.get(
self.targets.get_factory( self.targets.get_factory(
self.targetname, self.targetname,
self.target, self.target,
@ -347,6 +386,7 @@ class InstrumentationApi(OperatorApi):
instrumentation_factory.value, instrumentation_factory.value,
) )
) )
)
except KeyError as e: except KeyError as e:
raise InstrumentationApi.UnknownTarget( raise InstrumentationApi.UnknownTarget(
"binding failed", self.targetname, instrumentationname, str(e) "binding failed", self.targetname, instrumentationname, str(e)