initial_connection patch

This commit is contained in:
AF 2023-03-01 19:04:24 +00:00
parent 07bae21312
commit 834d9c159f
5 changed files with 162 additions and 98 deletions

View File

@ -1,4 +1,4 @@
aiohttp>=3.7.4,<4 aiohttp>=3.7.4,<4
discord.py[voice]~=2.1.0 discord.py[voice]~=2.2.0
yt-dlp~=2023.2.17 yt-dlp~=2023.2.17
typing_extensions~=4.4.0 typing_extensions~=4.4.0

View File

@ -1,6 +1,8 @@
import shlex import shlex
from typing import Callable from typing import Callable
import discord
from v6d2ctx.at_of import * from v6d2ctx.at_of import *
from v6d2ctx.context import * from v6d2ctx.context import *
from v6d3music.core.default_effects import * from v6d3music.core.default_effects import *
@ -10,45 +12,45 @@ from v6d3music.utils.catch import *
from v6d3music.utils.effects_for_preset import * from v6d3music.utils.effects_for_preset import *
from v6d3music.utils.presets import * from v6d3music.utils.presets import *
import discord __all__ = ("get_of",)
__all__ = ('get_of',)
def get_of(mainservice: MainService) -> Callable[[str], command_type]: def get_of(mainservice: MainService) -> Callable[[str], command_type]:
at_of: AtOf[str, command_type] = AtOf() at_of: AtOf[str, command_type] = AtOf()
at, of = at_of() at, of = at_of()
@at('help') @at("help")
async def help_(ctx: Context, args: list[str]) -> None: async def help_(ctx: Context, args: list[str]) -> None:
match args: match args:
case []: case []:
await ctx.reply('music bot\nhttps://music.parrrate.ru/docs/usage.html') await ctx.reply("music bot\nhttps://music.parrrate.ru/docs/usage.html")
case [name]: case [name]:
await ctx.reply(f'help for {name}: `{name} help`') await ctx.reply(f"help for {name}: `{name} help`")
@at('/') @at("/")
@at('play') @at("play")
async def play(ctx: Context, args: list[str]) -> None: async def play(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ctx,
f''' args,
f"""
`play ...args` `play ...args`
`play url [- effects]/[+ preset] [[[h]]] [[m]] [s] [tor] ...args` `play url [- effects]/[+ preset] [[[h]]] [[m]] [s] [tor] ...args`
`pause` `pause`
`resume` `resume`
presets: {shlex.join(allowed_presets)} presets: {shlex.join(allowed_presets)}
''', """,
(), 'help' (),
"help",
) )
match args: match args:
case ['this', *args]: case ["this", *args]:
reference = ctx.message.reference reference = ctx.message.reference
if reference is None: if reference is None:
raise Explicit('use reply') raise Explicit("use reply")
resolved = reference.resolved resolved = reference.resolved
if not isinstance(resolved, discord.Message): if not isinstance(resolved, discord.Message):
raise Explicit('reference message is either deleted or cannot be found') raise Explicit("reference message is either deleted or cannot be found")
attachments = resolved.attachments attachments = resolved.attachments
case [*args]: case [*args]:
attachments = ctx.message.attachments attachments = ctx.message.attachments
@ -57,17 +59,20 @@ presets: {shlex.join(allowed_presets)}
async with mainservice.lock_for(ctx.guild): async with mainservice.lock_for(ctx.guild):
queue = await mainservice.context(ctx, create=True, force_play=False).queue() queue = await mainservice.context(ctx, create=True, force_play=False).queue()
if attachments: if attachments:
args = ['[[', *(attachment.url for attachment in attachments), ']]'] + args args = ["[[", *(attachment.url for attachment in attachments), "]]"] + args
async for audio in mainservice.yt_audios(ctx, args): async for audio in mainservice.yt_audios(ctx, args):
queue.append(audio) queue.append(audio)
await ctx.reply('done') await ctx.reply("done")
@at('skip') @at("skip")
async def skip(ctx: Context, args: list[str]) -> None: async def skip(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`skip [first] [last]` `skip [first] [last]`
''', 'help' """,
"help",
) )
assert ctx.member is not None assert ctx.member is not None
match args: match args:
@ -83,15 +88,18 @@ presets: {shlex.join(allowed_presets)}
queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.skip_between(pos0, pos1, ctx.member) queue.skip_between(pos0, pos1, ctx.member)
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
await ctx.reply('done') await ctx.reply("done")
@at('to') @at("to")
async def skip_to(ctx: Context, args: list[str]) -> None: async def skip_to(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`to [[h]] [m] s` `to [[h]] [m] s`
''', 'help' """,
"help",
) )
match args: match args:
case [h, m, s, *args] if h.isdecimal() and m.isdecimal() and s.isdecimal(): case [h, m, s, *args] if h.isdecimal() and m.isdecimal() and s.isdecimal():
@ -101,83 +109,89 @@ presets: {shlex.join(allowed_presets)}
case [s, *args] if s.isdecimal(): case [s, *args] if s.isdecimal():
seconds = int(s) seconds = int(s)
case _: case _:
raise Explicit('misformatted, expected time') raise Explicit("misformatted, expected time")
match args: match args:
case ['at', spos] if spos.isdecimal(): case ["at", spos] if spos.isdecimal():
pos = int(spos) pos = int(spos)
case []: case []:
pos = 0 pos = 0
case _: case _:
raise Explicit('misformatted, expected position') raise Explicit("misformatted, expected position")
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.queue[pos].set_seconds(seconds) queue.queue[pos].set_seconds(seconds)
@at('effects') @at("effects")
async def effects_(ctx: Context, args: list[str]) -> None: async def effects_(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`effects - effects` `effects - effects`
`effects + preset` `effects + preset`
''', 'help' """,
"help",
) )
match args: match args:
case ['-', effects]: case ["-", effects]:
pass pass
case ['+', preset]: case ["+", preset]:
effects = effects_for_preset(preset) effects = effects_for_preset(preset)
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.queue[0].set_effects(effects) queue.queue[0].set_effects(effects)
@at('default') @at("default")
async def default(ctx: Context, args: list[str]) -> None: async def default(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`default - effects` `default - effects`
`default + preset` `default + preset`
`default none` `default none`
''', 'help' """,
"help",
) )
assert ctx.guild is not None assert ctx.guild is not None
match args: match args:
case ['-', effects]: case ["-", effects]:
pass pass
case ['+', preset]: case ["+", preset]:
effects = effects_for_preset(preset) effects = effects_for_preset(preset)
case ['none']: case ["none"]:
effects = None effects = None
case []: case []:
await ctx.reply(f'current default effects: {mainservice.defaulteffects.get(ctx.guild.id)}') await ctx.reply(f"current default effects: {mainservice.defaulteffects.get(ctx.guild.id)}")
return return
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
assert_admin(ctx.member) assert_admin(ctx.member)
await mainservice.defaulteffects.set(ctx.guild.id, effects) await mainservice.defaulteffects.set(ctx.guild.id, effects)
await ctx.reply(f'effects set to `{effects}`') await ctx.reply(f"effects set to `{effects}`")
@at('repeat') @at("repeat")
async def repeat(ctx: Context, args: list[str]): async def repeat(ctx: Context, args: list[str]):
match args: match args:
case ['x', n_, *args] if n_.isdecimal(): case ["x", n_, *args] if n_.isdecimal():
n = int(n_) n = int(n_)
case [*args]: case [*args]:
n = 1 n = 1
case _: case _:
raise RuntimeError raise RuntimeError
match args: match args:
case ['at', p_, *args] if p_.isdecimal(): case ["at", p_, *args] if p_.isdecimal():
p = int(p_) p = int(p_)
case [*args]: case [*args]:
p = 0 p = 0
case _: case _:
raise RuntimeError raise RuntimeError
match args: match args:
case ['to', t_, *args] if t_.isdecimal(): case ["to", t_, *args] if t_.isdecimal():
t = int(t_) t = int(t_)
case ['to', 'end']: case ["to", "end"]:
t = None t = None
case [*args]: case [*args]:
t = p + 1 t = p + 1
@ -187,44 +201,47 @@ presets: {shlex.join(allowed_presets)}
case []: case []:
pass pass
case _: case _:
raise Explicit('misformatted (extra arguments)') raise Explicit("misformatted (extra arguments)")
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.repeat(n, p, t) queue.repeat(n, p, t)
@at('shuffle') @at("shuffle")
async def shuffle(ctx: Context, args: list[str]): async def shuffle(ctx: Context, args: list[str]):
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.shuffle() queue.shuffle()
@at('branch') @at("branch")
async def branch(ctx: Context, args: list[str]): async def branch(ctx: Context, args: list[str]):
match args: match args:
case ['-', effects]: case ["-", effects]:
pass pass
case ['+', preset]: case ["+", preset]:
effects = effects_for_preset(preset) effects = effects_for_preset(preset)
case ['none']: case ["none"]:
effects = '' effects = ""
case []: case []:
effects = None effects = None
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await mainservice.context(ctx, create=False, force_play=False).queue() queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.branch(effects) queue.branch(effects)
@at('//') @at("//")
@at('queue') @at("queue")
async def queue_(ctx: Context, args: list[str]) -> None: async def queue_(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`queue` `queue`
`queue clear` `queue clear`
`queue resume` `queue resume`
`queue pause` `queue pause`
''', 'help' """,
"help",
) )
assert ctx.member is not None assert ctx.member is not None
match args: match args:
@ -232,41 +249,41 @@ presets: {shlex.join(allowed_presets)}
limit = 24 limit = 24
case [lstr] if lstr.isdecimal(): case [lstr] if lstr.isdecimal():
limit = int(lstr) limit = int(lstr)
case ['tail', lstr] if lstr.isdecimal(): case ["tail", lstr] if lstr.isdecimal():
limit = -int(lstr) limit = -int(lstr)
if limit >= 0: if limit >= 0:
raise Explicit('limit of at least `1` required') raise Explicit("limit of at least `1` required")
case ['clear']: case ["clear"]:
(await mainservice.context(ctx, create=False, force_play=False).queue()).clear(ctx.member) (await mainservice.context(ctx, create=False, force_play=False).queue()).clear(ctx.member)
await ctx.reply('done') await ctx.reply("done")
return return
case ['resume']: case ["resume"]:
async with mainservice.lock_for(ctx.guild): async with mainservice.lock_for(ctx.guild):
await mainservice.context(ctx, create=True, force_play=True).vc() await mainservice.context(ctx, create=True, force_play=True).vc()
await ctx.reply('done') await ctx.reply("done")
return return
case ['pause']: case ["pause"]:
async with mainservice.lock_for(ctx.guild): async with mainservice.lock_for(ctx.guild):
vc = await mainservice.context(ctx, create=True, force_play=False).vc() vc = await mainservice.context(ctx, create=True, force_play=False).vc()
vc.pause() vc.pause()
await ctx.reply('done') await ctx.reply("done")
return return
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
await ctx.long( await ctx.long(
( (await (await mainservice.context(ctx, create=True, force_play=False).queue()).format(limit)).strip()
await ( or "no queue"
await mainservice.context(ctx, create=True, force_play=False).queue()
).format(limit)
).strip() or 'no queue'
) )
@at('swap') @at("swap")
async def swap(ctx: Context, args: list[str]) -> None: async def swap(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`swap a b` `swap a b`
''', 'help' """,
"help",
) )
assert ctx.member is not None assert ctx.member is not None
match args: match args:
@ -274,14 +291,17 @@ presets: {shlex.join(allowed_presets)}
a, b = int(a), int(b) a, b = int(a), int(b)
(await mainservice.context(ctx, create=False, force_play=False).queue()).swap(ctx.member, a, b) (await mainservice.context(ctx, create=False, force_play=False).queue()).swap(ctx.member, a, b)
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
@at('move') @at("move")
async def move(ctx: Context, args: list[str]) -> None: async def move(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`move a b` `move a b`
''', 'help' """,
"help",
) )
assert ctx.member is not None assert ctx.member is not None
match args: match args:
@ -289,14 +309,17 @@ presets: {shlex.join(allowed_presets)}
a, b = int(a), int(b) a, b = int(a), int(b)
(await mainservice.context(ctx, create=False, force_play=False).queue()).move(ctx.member, a, b) (await mainservice.context(ctx, create=False, force_play=False).queue()).move(ctx.member, a, b)
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
@at('volume') @at("volume")
async def volume_(ctx: Context, args: list[str]) -> None: async def volume_(ctx: Context, args: list[str]) -> None:
await catch( await catch(
ctx, args, ''' ctx,
args,
"""
`volume volume` `volume volume`
''', 'help' """,
"help",
) )
assert ctx.member is not None assert ctx.member is not None
match args: match args:
@ -305,27 +328,27 @@ presets: {shlex.join(allowed_presets)}
await (await mainservice.context(ctx, create=True, force_play=False).main()).set(volume, ctx.member) await (await mainservice.context(ctx, create=True, force_play=False).main()).set(volume, ctx.member)
case []: case []:
volume = (await mainservice.context(ctx, create=True, force_play=False).main()).get() volume = (await mainservice.context(ctx, create=True, force_play=False).main()).get()
await ctx.reply(f'volume is {volume}') await ctx.reply(f"volume is {volume}")
case _: case _:
raise Explicit('misformatted') raise Explicit("misformatted")
@at('pause') @at("pause")
async def pause(ctx: Context, _args: list[str]) -> None: async def pause(ctx: Context, _args: list[str]) -> None:
vc = await mainservice.context(ctx, create=False, force_play=False).vc() vc = await mainservice.context(ctx, create=False, force_play=False).vc()
vc.pause() vc.pause()
@at('resume') @at("resume")
async def resume(ctx: Context, _args: list[str]) -> None: async def resume(ctx: Context, _args: list[str]) -> None:
vc = await mainservice.context(ctx, create=False, force_play=True).vc() vc = await mainservice.context(ctx, create=False, force_play=True).vc()
vc.resume() vc.resume()
@at('leave') @at("leave")
async def leave(ctx: Context, _args: list[str]) -> None: async def leave(ctx: Context, _args: list[str]) -> None:
async with mainservice.lock_for(ctx.guild): async with mainservice.lock_for(ctx.guild):
vc, main = await mainservice.context(ctx, create=False, force_play=False).vc_main() vc, main = await mainservice.context(ctx, create=False, force_play=False).vc_main()
queue = main.queue queue = main.queue
if queue.queue: if queue.queue:
raise Explicit('queue not empty') raise Explicit("queue not empty")
await vc.disconnect() await vc.disconnect()
return of return of

View File

@ -4,13 +4,13 @@ from contextlib import AsyncExitStack
from typing import AsyncIterable, TypeVar from typing import AsyncIterable, TypeVar
import discord import discord
from v6d2ctx.integration.event import *
from v6d2ctx.integration.responsetype import *
from v6d2ctx.integration.targets import *
import v6d3music.processing.pool import v6d3music.processing.pool
from ptvp35 import * from ptvp35 import *
from v6d2ctx.context import * from v6d2ctx.context import *
from v6d2ctx.integration.event import *
from v6d2ctx.integration.responsetype import *
from v6d2ctx.integration.targets import *
from v6d2ctx.lock_for import * from v6d2ctx.lock_for import *
from v6d3music.config import myroot from v6d3music.config import myroot
from v6d3music.core.caching import * from v6d3music.core.caching import *
@ -82,6 +82,7 @@ class MainService:
try: try:
vc = await vch.connect() vc = await vch.connect()
except discord.ClientException as e: except discord.ClientException as e:
traceback.print_exc()
vc = member.guild.voice_client vc = member.guild.voice_client
assert vc is not None assert vc is not None
await member.guild.fetch_channels() await member.guild.fetch_channels()

View File

@ -1,9 +1,11 @@
import asyncio import asyncio
import contextlib import contextlib
import os import os
import struct
import sys import sys
import time import time
from traceback import print_exc from traceback import print_exc
from typing import Any
import discord import discord
@ -13,7 +15,7 @@ from v6d1tokens.client import *
from v6d2ctx.handle_content import * from v6d2ctx.handle_content import *
from v6d2ctx.integration.event import * from v6d2ctx.integration.event import *
from v6d2ctx.integration.targets import * from v6d2ctx.integration.targets import *
from v6d2ctx.pain import * from v6d2ctx.pain import ABlockMonitor, ALog, SLog
from v6d2ctx.serve import * from v6d2ctx.serve import *
from v6d3music.api import * from v6d3music.api import *
from v6d3music.app import * from v6d3music.app import *
@ -178,6 +180,45 @@ async def amain(client: discord.Client):
print("exited") print("exited")
async def initial_connection(self, data: dict[str, Any]) -> None:
state = self._connection
state.ssrc = data["ssrc"]
state.voice_port = data["port"]
state.endpoint_ip = data["ip"]
packet = bytearray(74)
struct.pack_into(">H", packet, 0, 1) # 1 = Send
struct.pack_into(">H", packet, 2, 70) # 70 = Length
struct.pack_into(">I", packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 70)
# the ip is ascii starting at the 8th byte and ending at the first null
ip_start = 8
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode("ascii")
state.port = struct.unpack_from(">H", recv, 6)[0]
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data["modes"] if mode in self._connection.supported_modes]
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
__import__("discord.gateway").gateway.DiscordVoiceWebSocket.initial_connection = initial_connection
def main() -> None: def main() -> None:
with _upgrade_abm(), _db_ee(): wst = __import__("discord.gateway").gateway.DiscordVoiceWebSocket
with (
_upgrade_abm(),
_db_ee(),
ALog(discord.VoiceClient, "connect_websocket"),
ALog(wst, "poll_event"),
ALog(wst, "received_message"),
ALog(wst, "initial_connection"),
ALog(wst, "select_protocol"),
):
serve(amain(_client), _client, loop) serve(amain(_client), _client, loop)

View File

@ -1,5 +1,4 @@
from .main import main from .main import main
if __name__ == '__main__': if __name__ == '__main__':
main() main()