forceclear

This commit is contained in:
AF 2023-12-27 04:33:55 +00:00
parent 7923f24dec
commit 79695fe142

View File

@ -39,7 +39,9 @@ class Api:
self.targets = mainservice.targets
self.targets.register_instance(self, "api", Async)
self.targets.register_instrumentation("Count", lambda t, n: Count(t, n))
self.targets.register_instrumentation("Concurrency", lambda t, n: Concurrency(t, n), Async)
self.targets.register_instrumentation(
"Concurrency", lambda t, n: Concurrency(t, n), Async
)
def user_id(self) -> int | None:
if self.client.user is None:
@ -100,31 +102,45 @@ class UserApi:
match requests:
case list():
return list(
await asyncio.gather(*(self.sub(request, key).api() for (key, request) in enumerate(requests)))
await asyncio.gather(
*(
self.sub(request, key).api()
for (key, request) in enumerate(requests)
)
)
)
case dict():
items = list(requests.items())
responses = await asyncio.gather(
*(self.sub({idkey: key} | base | request, key).api() for key, request in items)
*(
self.sub({idkey: key} | base | request, key).api()
for key, request in items
)
)
return dict(
(key, response) for (key, _), response in zip(items, responses)
)
return dict((key, response) for (key, _), response in zip(items, responses))
case _:
raise Api.MisusedApi("that should not happen")
def _sub(self, request: dict) -> Self:
def _sub(self, request: dict) -> UserApi:
return UserApi(self.session, request, self.user_id)
def sub(self, request: dict, key: str | int) -> Self:
def sub(self, request: dict, key: str | int) -> UserApi:
sub = self._sub(request)
sub._parent = self
sub._key = key
return sub
async def to_guild_api(self, guild_id: int) -> GuildApi:
guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(guild_id)
guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(
guild_id
)
if guild is None:
raise UserApi.UnknownMember("unknown guild")
member = guild.get_member(self.user_id) or await guild.fetch_member(self.user_id)
member = guild.get_member(self.user_id) or await guild.fetch_member(
self.user_id
)
if member is None:
raise UserApi.UnknownMember("unknown member of a guild")
return GuildApi(self, member)
@ -148,7 +164,9 @@ class UserApi:
async def _api(self) -> ResponseType:
match self.request:
case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100:
case {"guild": str() as guild_id_str} if guild_id_str.isdecimal() and len(
guild_id_str
) < 100:
self.request.pop("guild")
return await (await self.to_guild_api(int(guild_id_str))).api()
case {"operator": _}:
@ -203,7 +221,7 @@ class GuildApi(UserApi):
raise GuildApi.VoiceNotConnected("bot not connected")
return VoiceApi(self, channel)
def _sub(self, request: dict) -> Self:
def _sub(self, request: dict) -> GuildApi:
return GuildApi(super()._sub(request), self.member)
def _api_text(self) -> str:
@ -219,17 +237,21 @@ class GuildApi(UserApi):
class VoiceApi(GuildApi):
def __init__(self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel) -> None:
def __init__(
self, api: GuildApi, channel: discord.VoiceChannel | discord.StageChannel
) -> None:
super().__init__(api, api.member)
self.channel = channel
self.mainservice = self.pi.mainservice
async def to_main_api(self) -> "MainApi":
vc = await self.mainservice.raw_vc_for_member(self.member)
main = await self.mainservice.mode(create=False, force_play=False).main_for_raw_vc(vc)
main = await self.mainservice.mode(
create=False, force_play=False
).main_for_raw_vc(vc)
return MainApi(self, vc, main)
def _sub(self, request: dict) -> Self:
def _sub(self, request: dict) -> VoiceApi:
return VoiceApi(super()._sub(request), self.channel)
def _api_text(self) -> str:
@ -250,7 +272,7 @@ class MainApi(VoiceApi):
self.vc = vc
self.main = main
def _sub(self, request: dict) -> Self:
def _sub(self, request: dict) -> MainApi:
return MainApi(super()._sub(request), self.vc, self.main)
def _api_text(self) -> str:
@ -267,7 +289,9 @@ class MainApi(VoiceApi):
case {"type": "queueformat"}:
return await self.main.queue.format()
case {"type": "queuejson"}:
return await self.main.queue.pubjson(self.member, self.request.get("limit", 1000))
return await self.main.queue.pubjson(
self.member, self.request.get("limit", 1000)
)
case _:
return await self._fall_through_api()
@ -279,12 +303,18 @@ class OperatorApi(UserApi):
def _guild_visible(self, guild: discord.Guild) -> bool:
return True
def _sub(self, request: dict) -> Self:
def _sub(self, request: dict) -> OperatorApi:
return OperatorApi(super()._sub(request))
def _api_text(self) -> str:
return "operator api"
def _forceclear(self, guildid: int) -> None:
guild = self.client.get_guild(guildid)
if guild is None:
raise KeyError
self.pi.mainservice.mains[guild].queue.forceclear()
async def _api(self) -> ResponseType:
match self.request:
case {"target": str() as targetname}:
@ -303,11 +333,17 @@ class OperatorApi(UserApi):
}
)
return guilds
case {"type": "sleep", "duration": (float() | int()) as duration, "echo": _ as echo}:
case {
"type": "sleep",
"duration": (float() | int()) as duration,
"echo": _ as echo,
}:
await asyncio.sleep(duration)
return echo
case {"type": "pool"}:
return self.pi.mainservice.pool_json()
case {"type": "forceclear", "guildid": str() as guildid}:
return self._forceclear(int(guildid))
case _:
return await self._fall_through_api()
@ -326,7 +362,7 @@ class InstrumentationApi(OperatorApi):
raise InstrumentationApi.UnknownTarget("unknown target", targetname)
self.target, self.methodname = target_tuple.value
def _sub(self, request: dict) -> Self:
def _sub(self, request: dict) -> InstrumentationApi:
return InstrumentationApi(super()._sub(request), self.targetname)
def _api_text(self) -> str:
@ -335,16 +371,20 @@ class InstrumentationApi(OperatorApi):
async def _api(self) -> ResponseType:
match self.request:
case {"type": str() as instrumentationname} if (
instrumentation_factory := self.targets.instrumentations.get(instrumentationname)
instrumentation_factory := self.targets.instrumentations.get(
instrumentationname
)
) is not None:
try:
instrumentation: Instrumentation = await self.pi.mainservice.pmonitoring.get(
self.targets.get_factory(
self.targetname,
self.target,
self.methodname,
instrumentationname,
instrumentation_factory.value,
instrumentation: Instrumentation = (
await self.pi.mainservice.pmonitoring.get(
self.targets.get_factory(
self.targetname,
self.target,
self.methodname,
instrumentationname,
instrumentation_factory.value,
)
)
)
except KeyError as e: