diff --git a/base.requirements.txt b/base.requirements.txt index f5cd17c..1048483 100644 --- a/base.requirements.txt +++ b/base.requirements.txt @@ -1,4 +1,4 @@ aiohttp>=3.7.4,<4 -discord.py[voice]~=2.1.0 +discord.py[voice]~=2.2.0 yt-dlp~=2023.2.17 typing_extensions~=4.4.0 diff --git a/v6d3music/commands.py b/v6d3music/commands.py index fabcc40..0a27ebf 100644 --- a/v6d3music/commands.py +++ b/v6d3music/commands.py @@ -1,6 +1,8 @@ import shlex from typing import Callable +import discord + from v6d2ctx.at_of import * from v6d2ctx.context import * from v6d3music.core.default_effects import * @@ -10,45 +12,45 @@ from v6d3music.utils.catch import * from v6d3music.utils.effects_for_preset import * from v6d3music.utils.presets import * -import discord - -__all__ = ('get_of',) +__all__ = ("get_of",) def get_of(mainservice: MainService) -> Callable[[str], command_type]: at_of: AtOf[str, command_type] = AtOf() at, of = at_of() - @at('help') + @at("help") async def help_(ctx: Context, args: list[str]) -> None: match args: case []: - await ctx.reply('music bot\nhttps://music.parrrate.ru/docs/usage.html') + await ctx.reply("music bot\nhttps://music.parrrate.ru/docs/usage.html") case [name]: - await ctx.reply(f'help for {name}: `{name} help`') + await ctx.reply(f"help for {name}: `{name} help`") - @at('/') - @at('play') + @at("/") + @at("play") async def play(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, - f''' + ctx, + args, + f""" `play ...args` `play url [- effects]/[+ preset] [[[h]]] [[m]] [s] [tor] ...args` `pause` `resume` presets: {shlex.join(allowed_presets)} - ''', - (), 'help' + """, + (), + "help", ) match args: - case ['this', *args]: + case ["this", *args]: reference = ctx.message.reference if reference is None: - raise Explicit('use reply') + raise Explicit("use reply") resolved = reference.resolved if not isinstance(resolved, discord.Message): - raise Explicit('reference message is either deleted or cannot be found') + raise Explicit("reference message is either deleted or cannot be found") attachments = resolved.attachments case [*args]: attachments = ctx.message.attachments @@ -57,17 +59,20 @@ presets: {shlex.join(allowed_presets)} async with mainservice.lock_for(ctx.guild): queue = await mainservice.context(ctx, create=True, force_play=False).queue() if attachments: - args = ['[[', *(attachment.url for attachment in attachments), ']]'] + args + args = ["[[", *(attachment.url for attachment in attachments), "]]"] + args async for audio in mainservice.yt_audios(ctx, args): queue.append(audio) - await ctx.reply('done') + await ctx.reply("done") - @at('skip') + @at("skip") async def skip(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `skip [first] [last]` -''', 'help' +""", + "help", ) assert ctx.member is not None match args: @@ -83,15 +88,18 @@ presets: {shlex.join(allowed_presets)} queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue.skip_between(pos0, pos1, ctx.member) case _: - raise Explicit('misformatted') - await ctx.reply('done') + raise Explicit("misformatted") + await ctx.reply("done") - @at('to') + @at("to") async def skip_to(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `to [[h]] [m] s` -''', 'help' +""", + "help", ) match args: case [h, m, s, *args] if h.isdecimal() and m.isdecimal() and s.isdecimal(): @@ -101,83 +109,89 @@ presets: {shlex.join(allowed_presets)} case [s, *args] if s.isdecimal(): seconds = int(s) case _: - raise Explicit('misformatted, expected time') + raise Explicit("misformatted, expected time") match args: - case ['at', spos] if spos.isdecimal(): + case ["at", spos] if spos.isdecimal(): pos = int(spos) case []: pos = 0 case _: - raise Explicit('misformatted, expected position') + raise Explicit("misformatted, expected position") assert_admin(ctx.member) queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue.queue[pos].set_seconds(seconds) - @at('effects') + @at("effects") async def effects_(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `effects - effects` `effects + preset` -''', 'help' +""", + "help", ) match args: - case ['-', effects]: + case ["-", effects]: pass - case ['+', preset]: + case ["+", preset]: effects = effects_for_preset(preset) case _: - raise Explicit('misformatted') + raise Explicit("misformatted") assert_admin(ctx.member) queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue.queue[0].set_effects(effects) - @at('default') + @at("default") async def default(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `default - effects` `default + preset` `default none` - ''', 'help' + """, + "help", ) assert ctx.guild is not None match args: - case ['-', effects]: + case ["-", effects]: pass - case ['+', preset]: + case ["+", preset]: effects = effects_for_preset(preset) - case ['none']: + case ["none"]: effects = None case []: - await ctx.reply(f'current default effects: {mainservice.defaulteffects.get(ctx.guild.id)}') + await ctx.reply(f"current default effects: {mainservice.defaulteffects.get(ctx.guild.id)}") return case _: - raise Explicit('misformatted') + raise Explicit("misformatted") assert_admin(ctx.member) await mainservice.defaulteffects.set(ctx.guild.id, effects) - await ctx.reply(f'effects set to `{effects}`') + await ctx.reply(f"effects set to `{effects}`") - @at('repeat') + @at("repeat") async def repeat(ctx: Context, args: list[str]): match args: - case ['x', n_, *args] if n_.isdecimal(): + case ["x", n_, *args] if n_.isdecimal(): n = int(n_) case [*args]: n = 1 case _: raise RuntimeError match args: - case ['at', p_, *args] if p_.isdecimal(): + case ["at", p_, *args] if p_.isdecimal(): p = int(p_) case [*args]: p = 0 case _: raise RuntimeError match args: - case ['to', t_, *args] if t_.isdecimal(): + case ["to", t_, *args] if t_.isdecimal(): t = int(t_) - case ['to', 'end']: + case ["to", "end"]: t = None case [*args]: t = p + 1 @@ -187,44 +201,47 @@ presets: {shlex.join(allowed_presets)} case []: pass case _: - raise Explicit('misformatted (extra arguments)') + raise Explicit("misformatted (extra arguments)") assert_admin(ctx.member) queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue.repeat(n, p, t) - @at('shuffle') + @at("shuffle") async def shuffle(ctx: Context, args: list[str]): assert_admin(ctx.member) queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue.shuffle() - @at('branch') + @at("branch") async def branch(ctx: Context, args: list[str]): match args: - case ['-', effects]: + case ["-", effects]: pass - case ['+', preset]: + case ["+", preset]: effects = effects_for_preset(preset) - case ['none']: - effects = '' + case ["none"]: + effects = "" case []: effects = None case _: - raise Explicit('misformatted') + raise Explicit("misformatted") assert_admin(ctx.member) queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue.branch(effects) - @at('//') - @at('queue') + @at("//") + @at("queue") async def queue_(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `queue` `queue clear` `queue resume` `queue pause` - ''', 'help' + """, + "help", ) assert ctx.member is not None match args: @@ -232,41 +249,41 @@ presets: {shlex.join(allowed_presets)} limit = 24 case [lstr] if lstr.isdecimal(): limit = int(lstr) - case ['tail', lstr] if lstr.isdecimal(): + case ["tail", lstr] if lstr.isdecimal(): limit = -int(lstr) if limit >= 0: - raise Explicit('limit of at least `1` required') - case ['clear']: + raise Explicit("limit of at least `1` required") + case ["clear"]: (await mainservice.context(ctx, create=False, force_play=False).queue()).clear(ctx.member) - await ctx.reply('done') + await ctx.reply("done") return - case ['resume']: + case ["resume"]: async with mainservice.lock_for(ctx.guild): await mainservice.context(ctx, create=True, force_play=True).vc() - await ctx.reply('done') + await ctx.reply("done") return - case ['pause']: + case ["pause"]: async with mainservice.lock_for(ctx.guild): vc = await mainservice.context(ctx, create=True, force_play=False).vc() vc.pause() - await ctx.reply('done') + await ctx.reply("done") return case _: - raise Explicit('misformatted') + raise Explicit("misformatted") await ctx.long( - ( - await ( - await mainservice.context(ctx, create=True, force_play=False).queue() - ).format(limit) - ).strip() or 'no queue' + (await (await mainservice.context(ctx, create=True, force_play=False).queue()).format(limit)).strip() + or "no queue" ) - @at('swap') + @at("swap") async def swap(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `swap a b` -''', 'help' +""", + "help", ) assert ctx.member is not None match args: @@ -274,14 +291,17 @@ presets: {shlex.join(allowed_presets)} a, b = int(a), int(b) (await mainservice.context(ctx, create=False, force_play=False).queue()).swap(ctx.member, a, b) case _: - raise Explicit('misformatted') + raise Explicit("misformatted") - @at('move') + @at("move") async def move(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `move a b` -''', 'help' +""", + "help", ) assert ctx.member is not None match args: @@ -289,14 +309,17 @@ presets: {shlex.join(allowed_presets)} a, b = int(a), int(b) (await mainservice.context(ctx, create=False, force_play=False).queue()).move(ctx.member, a, b) case _: - raise Explicit('misformatted') + raise Explicit("misformatted") - @at('volume') + @at("volume") async def volume_(ctx: Context, args: list[str]) -> None: await catch( - ctx, args, ''' + ctx, + args, + """ `volume volume` -''', 'help' +""", + "help", ) assert ctx.member is not None match args: @@ -305,27 +328,27 @@ presets: {shlex.join(allowed_presets)} await (await mainservice.context(ctx, create=True, force_play=False).main()).set(volume, ctx.member) case []: volume = (await mainservice.context(ctx, create=True, force_play=False).main()).get() - await ctx.reply(f'volume is {volume}') + await ctx.reply(f"volume is {volume}") case _: - raise Explicit('misformatted') + raise Explicit("misformatted") - @at('pause') + @at("pause") async def pause(ctx: Context, _args: list[str]) -> None: vc = await mainservice.context(ctx, create=False, force_play=False).vc() vc.pause() - @at('resume') + @at("resume") async def resume(ctx: Context, _args: list[str]) -> None: vc = await mainservice.context(ctx, create=False, force_play=True).vc() vc.resume() - @at('leave') + @at("leave") async def leave(ctx: Context, _args: list[str]) -> None: async with mainservice.lock_for(ctx.guild): vc, main = await mainservice.context(ctx, create=False, force_play=False).vc_main() queue = main.queue if queue.queue: - raise Explicit('queue not empty') + raise Explicit("queue not empty") await vc.disconnect() return of diff --git a/v6d3music/core/mainservice.py b/v6d3music/core/mainservice.py index a11574c..99a6d79 100644 --- a/v6d3music/core/mainservice.py +++ b/v6d3music/core/mainservice.py @@ -4,13 +4,13 @@ from contextlib import AsyncExitStack from typing import AsyncIterable, TypeVar import discord -from v6d2ctx.integration.event import * -from v6d2ctx.integration.responsetype import * -from v6d2ctx.integration.targets import * import v6d3music.processing.pool from ptvp35 import * from v6d2ctx.context import * +from v6d2ctx.integration.event import * +from v6d2ctx.integration.responsetype import * +from v6d2ctx.integration.targets import * from v6d2ctx.lock_for import * from v6d3music.config import myroot from v6d3music.core.caching import * @@ -82,6 +82,7 @@ class MainService: try: vc = await vch.connect() except discord.ClientException as e: + traceback.print_exc() vc = member.guild.voice_client assert vc is not None await member.guild.fetch_channels() diff --git a/v6d3music/main.py b/v6d3music/main.py index 5a91ce4..b60405f 100644 --- a/v6d3music/main.py +++ b/v6d3music/main.py @@ -1,9 +1,11 @@ import asyncio import contextlib import os +import struct import sys import time from traceback import print_exc +from typing import Any import discord @@ -13,7 +15,7 @@ from v6d1tokens.client import * from v6d2ctx.handle_content import * from v6d2ctx.integration.event import * from v6d2ctx.integration.targets import * -from v6d2ctx.pain import * +from v6d2ctx.pain import ABlockMonitor, ALog, SLog from v6d2ctx.serve import * from v6d3music.api import * from v6d3music.app import * @@ -178,6 +180,45 @@ async def amain(client: discord.Client): print("exited") +async def initial_connection(self, data: dict[str, Any]) -> None: + state = self._connection + state.ssrc = data["ssrc"] + state.voice_port = data["port"] + state.endpoint_ip = data["ip"] + + packet = bytearray(74) + struct.pack_into(">H", packet, 0, 1) # 1 = Send + struct.pack_into(">H", packet, 2, 70) # 70 = Length + struct.pack_into(">I", packet, 4, state.ssrc) + state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) + recv = await self.loop.sock_recv(state.socket, 70) + + # the ip is ascii starting at the 8th byte and ending at the first null + ip_start = 8 + ip_end = recv.index(0, ip_start) + state.ip = recv[ip_start:ip_end].decode("ascii") + + state.port = struct.unpack_from(">H", recv, 6)[0] + + # there *should* always be at least one supported mode (xsalsa20_poly1305) + modes = [mode for mode in data["modes"] if mode in self._connection.supported_modes] + + mode = modes[0] + await self.select_protocol(state.ip, state.port, mode) + + +__import__("discord.gateway").gateway.DiscordVoiceWebSocket.initial_connection = initial_connection + + def main() -> None: - with _upgrade_abm(), _db_ee(): + wst = __import__("discord.gateway").gateway.DiscordVoiceWebSocket + with ( + _upgrade_abm(), + _db_ee(), + ALog(discord.VoiceClient, "connect_websocket"), + ALog(wst, "poll_event"), + ALog(wst, "received_message"), + ALog(wst, "initial_connection"), + ALog(wst, "select_protocol"), + ): serve(amain(_client), _client, loop) diff --git a/v6d3music/run-bot.py b/v6d3music/run-bot.py index f5d590c..838f98c 100644 --- a/v6d3music/run-bot.py +++ b/v6d3music/run-bot.py @@ -1,5 +1,4 @@ from .main import main - if __name__ == '__main__': main()