diff --git a/v6d3music/api.py b/v6d3music/api.py index 67d4bf2..a2b37e6 100644 --- a/v6d3music/api.py +++ b/v6d3music/api.py @@ -23,34 +23,27 @@ class Api: def __init__(self, explicit: Explicit) -> None: super().__init__(*explicit.args) - def __init__(self, client: discord.Client) -> None: + def __init__(self, client: discord.Client, roles: dict[str, str]) -> None: self.client = client + self.roles = roles + + def is_operator(self, user_id: int) -> bool: + return '(operator)' in self.roles.get(f'roles{user_id}', '') async def api(self, request: dict, user_id: int) -> ResponseType: - return await UserApi(self.client, request, user_id).api() + return await UserApi(self, request, user_id).api() -class UserApi(Api): +class UserApi: class UnknownMember(Api.MisusedApi): pass - def __init__(self, client: discord.Client, request: dict, user_id: int) -> None: - super().__init__(client) + def __init__(self, api: Api, request: dict, user_id: int) -> None: + self.pi = api + self.client = api.client self.request = request self.user_id = user_id - async def _guild_api(self, guild_id: int) -> 'GuildApi': - guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(guild_id) - if guild is None: - raise UserApi.UnknownMember('unknown guild') - member = guild.get_member(self.user_id) or await guild.fetch_member(self.user_id) - if member is None: - raise UserApi.UnknownMember('unknown member of a guild') - return GuildApi(self.client, self.request, member) - - def sub(self, request: dict) -> 'UserApi': - return UserApi(self.client, request, self.user_id) - async def subs(self, requests: list[dict] | dict[str, dict]) -> ResponseType: match requests: case list(): @@ -68,11 +61,31 @@ class UserApi(Api): case _: raise Api.MisusedApi('that should not happen') + def sub(self, request: dict) -> 'UserApi': + return UserApi(self.pi, request, self.user_id) + + async def _guild_api(self, guild_id: int) -> 'GuildApi': + guild = self.client.get_guild(guild_id) or await self.client.fetch_guild(guild_id) + if guild is None: + raise UserApi.UnknownMember('unknown guild') + member = guild.get_member(self.user_id) or await guild.fetch_member(self.user_id) + if member is None: + raise UserApi.UnknownMember('unknown member of a guild') + return GuildApi(self.pi, self.request, member) + + async def _operator_api(self) -> 'OperatorApi': + if not self.pi.is_operator(self.user_id): + raise UserApi.UnknownMember('not an operator') + return OperatorApi(self.pi, self.request, self.user_id) + async def _api(self) -> ResponseType: match self.request: case {'guild': str() as guild_id_str} if guild_id_str.isdecimal() and len(guild_id_str) < 100: self.request.pop('guild') return await (await self._guild_api(int(guild_id_str))).api() + case {'operator': _}: + self.request.pop('operator') + return await (await self._operator_api()).api() case {'type': 'ping', 't': (float() | int()) as t}: return time.time() - t case {'type': 'guilds'}: @@ -107,8 +120,8 @@ class GuildApi(UserApi): class VoiceNotConnected(Api.MisusedApi): pass - def __init__(self, client: discord.Client, request: dict, member: discord.Member) -> None: - super().__init__(client, request, member.id) + def __init__(self, api: Api, request: dict, member: discord.Member) -> None: + super().__init__(api, request, member.id) self.member = member self.guild = member.guild @@ -123,10 +136,10 @@ class GuildApi(UserApi): raise GuildApi.VoiceNotConnected('bot client user not initialised') if self.client.user.id not in channel.voice_states: raise GuildApi.VoiceNotConnected('bot not connected') - return VoiceApi(self.client, self.request, self.member, channel) + return VoiceApi(self.pi, self.request, self.member, channel) def sub(self, request: dict) -> 'GuildApi': - return GuildApi(self.client, request, self.member) + return GuildApi(self.pi, request, self.member) async def _api(self) -> ResponseType: match self.request: @@ -143,18 +156,18 @@ class GuildApi(UserApi): class VoiceApi(GuildApi): def __init__( - self, client: discord.Client, request: dict, member: discord.Member, channel: discord.VoiceChannel | discord.StageChannel + self, api: Api, request: dict, member: discord.Member, channel: discord.VoiceChannel | discord.StageChannel ) -> None: - super().__init__(client, request, member) + super().__init__(api, request, member) self.channel = channel async def _main_api(self) -> 'MainApi': vc = await raw_vc_for_member(self.member) main = await main_for_raw_vc(vc, create=False, force_play=False) - return MainApi(self.client, self.request, self.member, self.channel, vc, main) + return MainApi(self.pi, self.request, self.member, self.channel, vc, main) def sub(self, request: dict) -> 'VoiceApi': - return VoiceApi(self.client, request, self.member, self.channel) + return VoiceApi(self.pi, request, self.member, self.channel) async def _api(self) -> ResponseType: match self.request: @@ -171,15 +184,15 @@ class VoiceApi(GuildApi): class MainApi(VoiceApi): def __init__( - self, client: discord.Client, request: dict, member: discord.Member, channel: discord.VoiceChannel | discord.StageChannel, + self, api: Api, request: dict, member: discord.Member, channel: discord.VoiceChannel | discord.StageChannel, vc: discord.VoiceClient, main: MainAudio ) -> None: - super().__init__(client, request, member, channel) + super().__init__(api, request, member, channel) self.vc = vc self.main = main def sub(self, request: dict) -> 'MainApi': - return MainApi(self.client, request, self.member, self.channel, self.vc, self.main) + return MainApi(self.pi, request, self.member, self.channel, self.vc, self.main) async def _api(self) -> ResponseType: match self.request: @@ -197,3 +210,32 @@ class MainApi(VoiceApi): return await self.subs(requests) case _: raise Api.UnknownApi('unknown main api') + + +class OperatorApi(UserApi): + def sub(self, request: dict) -> 'OperatorApi': + return OperatorApi(self.pi, request, self.user_id) + + async def _guild_visible(self, guild: discord.Guild) -> bool: + return True + + async def _api(self) -> ResponseType: + match self.request: + case {'type': 'guilds'}: + guilds = [] + for guild in self.client.guilds: + if self._guild_visible(guild): + guilds.append( + { + 'id': str(guild.id), + 'member_count': guild.member_count, + 'name': guild.name, + } + ) + return guilds + case {'type': '?'}: + return 'this is operator api' + case {'type': '*', 'requests': list() | dict() as requests}: + return await self.subs(requests) + case _: + raise Api.UnknownApi('unknown operator api') diff --git a/v6d3music/app.py b/v6d3music/app.py index c8bc5d9..74d6102 100644 --- a/v6d3music/app.py +++ b/v6d3music/app.py @@ -1,5 +1,6 @@ import asyncio import functools +import os import urllib.parse from pathlib import Path from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar @@ -281,7 +282,11 @@ class MusicAppFactory(AppFactory): @classmethod async def start(cls, client: discord.Client): try: - factory = cls(await request_token('music-client', 'token'), client, Api(client)) + factory = cls( + await request_token('music-client', 'token'), + client, + Api(client, {key: value for key, value in os.environ.items() if key.startswith('roles')}) + ) except aiohttp.ClientConnectorError: print('no web app (likely due to no token)') else: diff --git a/v6d3music/html/main.js b/v6d3music/html/main.js index d4a8fc7..c42bccf 100644 --- a/v6d3music/html/main.js +++ b/v6d3music/html/main.js @@ -196,3 +196,8 @@ const pageHome = async () => { baseEl("div", await aQueueWidget()) ); }; +aApi({ + type: "guilds", + operator: null, + catches: { "not an operator": null, "*": null }, +}).then(console.log); diff --git a/v6d3music/run-bot.py b/v6d3music/run-bot.py index 365c620..bd9bd9b 100644 --- a/v6d3music/run-bot.py +++ b/v6d3music/run-bot.py @@ -102,9 +102,21 @@ async def on_ready(): await restore_vcs() +banned_guilds = set(map(int, map(str.strip, os.getenv('banned_guilds', '').split(':')))) + + +def guild_allowed(guild: discord.Guild | None) -> bool: + return guild is not None and guild.id not in banned_guilds + + +def message_allowed(message: discord.Message) -> bool: + return guild_allowed(message.guild) + + @client.event async def on_message(message: discord.Message) -> None: - await handle_content(message, message.content, prefix, client) + if message_allowed(message): + await handle_content(message, message.content, prefix, client) async def save_queues(delay: bool): diff --git a/v6d3music/utils/aextract.py b/v6d3music/utils/aextract.py index a6fa580..6a17283 100644 --- a/v6d3music/utils/aextract.py +++ b/v6d3music/utils/aextract.py @@ -7,12 +7,11 @@ from v6d3music.utils.extract import extract async def aextract(params: dict, url: str, **kwargs): - with Benchmark('AEX'): - with ProcessPoolExecutor() as pool: - return await asyncio.get_running_loop().run_in_executor( - pool, - extract, - params, - url, - kwargs - ) + with ProcessPoolExecutor() as pool: + return await asyncio.get_running_loop().run_in_executor( + pool, + extract, + params, + url, + kwargs + )