From 84de31ce87818d1768fad072fd650caa6b7b4964 Mon Sep 17 00:00:00 2001 From: timotheyca Date: Mon, 29 Nov 2021 12:35:21 +0300 Subject: [PATCH] context.py + respawn + restore --- v6d3music/context.py | 97 ++++++++++++++ v6d3music/run-bot.py | 298 ++++++++++++++++++++++--------------------- 2 files changed, 252 insertions(+), 143 deletions(-) create mode 100644 v6d3music/context.py diff --git a/v6d3music/context.py b/v6d3music/context.py new file mode 100644 index 0000000..24120f1 --- /dev/null +++ b/v6d3music/context.py @@ -0,0 +1,97 @@ +import asyncio +import time +from io import StringIO +from typing import Union, Optional, Callable, Awaitable + +# noinspection PyPackageRequirements +import discord + +usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User] + + +class Context: + def __init__(self, message: discord.Message): + self.message: discord.Message = message + self.channel: discord.abc.Messageable = message.channel + self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel + self.author: usertype = message.author + self.content: str = message.content + self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None + self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild + + async def reply(self, content=None, **kwargs) -> discord.Message: + return await self.message.reply(content, mention_author=False, **kwargs) + + async def long(self, s: str): + resio = StringIO(s) + res = '' + for line in resio: + if len(res) + len(line) < 2000: + res += line + else: + await self.reply(res) + res = line + if res: + await self.reply(res) + + +ESCAPED = '`_*\'"\\' + + +def escape(s: str): + res = StringIO() + for c in s: + if c in ESCAPED: + c = '\\' + c + res.write(c) + return res.getvalue() + + +buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {} + + +def at(bucket: str, name: str): + def wrap(f: Callable[[Context, list[str]], Awaitable[None]]): + buckets.setdefault(bucket, {})[name] = f + + return f + + return wrap + + +class Implicit(Exception): + pass + + +def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]: + try: + return buckets[bucket][name] + except KeyError: + raise Implicit + + +benchmarks: dict[str, dict[str, float]] = {} +_t = time.perf_counter() + + +class Benchmark: + def __init__(self, benchmark: str): + self.benchmark = benchmark + + def __enter__(self): + self.t = time.perf_counter() + + def __exit__(self, exc_type, exc_val, exc_tb): + d = (time.perf_counter() - self.t) + benchmarks.setdefault(self.benchmark, {'integral': 0.0, 'max': 0.0}) + benchmarks[self.benchmark]['integral'] += d + benchmarks[self.benchmark]['max'] = max(benchmarks[self.benchmark]['max'], d) + + +async def monitor(): + while True: + await asyncio.sleep(10) + dt = time.perf_counter() - _t + print('Benchmarks:') + for benchmark, metrics in benchmarks.items(): + print(benchmark, '=', metrics['integral'] / max(dt, .00001), ':', metrics['max']) diff --git a/v6d3music/run-bot.py b/v6d3music/run-bot.py index 2a83c6f..33221d7 100644 --- a/v6d3music/run-bot.py +++ b/v6d3music/run-bot.py @@ -7,7 +7,7 @@ import subprocess import time from collections import deque from io import StringIO -from typing import Callable, Awaitable, Union, Optional, AsyncIterable, Any +from typing import Optional, AsyncIterable, Any # noinspection PyPackageRequirements import discord @@ -19,6 +19,7 @@ from v6d1tokens.client import request_token from v6d3music.app import get_app from v6d3music.config import prefix +from v6d3music.context import Context, of, at, escape, Implicit, monitor, Benchmark loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -39,45 +40,33 @@ myroot = root / 'v6d3music' myroot.mkdir(exist_ok=True) volume_db = Db(myroot / 'volume.db', kvrequest_type=KVJson) queue_db = Db(myroot / 'queue.db', kvrequest_type=KVJson) -ESCAPED = '`_*\'"\\' + +vcs_restored = False -def escape(s: str): - res = StringIO() - for c in s: - if c in ESCAPED: - c = '\\' + c - res.write(c) - return res.getvalue() - - -usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User] - - -class Context: - def __init__(self, message: discord.Message): - self.message: discord.Message = message - self.channel: discord.abc.Messageable = message.channel - self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel - self.author: usertype = message.author - self.content: str = message.content - self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None - self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild - - async def reply(self, content=None, **kwargs) -> discord.Message: - return await self.message.reply(content, mention_author=False, **kwargs) - - async def long(self, s: str): - resio = StringIO(s) - res = '' - for line in resio: - if len(res) + len(line) < 2000: - res += line +async def restore_vcs(): + global vcs_restored + vcs: list[tuple[int, int, bool]] = queue_db.get('vcs', []) + try: + for vcgid, vccid, vc_is_paused in vcs: + try: + guild: discord.Guild = await client.fetch_guild(vcgid) + async with lock_for(guild): + channels = await guild.fetch_channels() + channel: discord.VoiceChannel + channel, = [ch for ch in channels if ch.id == vccid] + vp: discord.VoiceProtocol = await channel.connect() + assert isinstance(vp, discord.VoiceClient) + vc = vp + await main_for_raw_vc(vc, create=True) + if vc_is_paused: + vc.pause() + except Exception as e: + print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e) else: - await self.reply(res) - res = line - if res: - await self.reply(res) + print(f'vc restored {vcgid} {vccid}') + finally: + vcs_restored = True @client.event @@ -86,29 +75,8 @@ async def on_ready(): await client.change_presence(activity=discord.Game( name='феноменально', )) - - -buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {} - - -def at(bucket: str, name: str): - def wrap(f: Callable[[Context, list[str]], Awaitable[None]]): - buckets.setdefault(bucket, {})[name] = f - - return f - - return wrap - - -class Implicit(Exception): - pass - - -def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]: - try: - return buckets[bucket][name] - except KeyError: - raise Implicit + if not vcs_restored: + await restore_vcs() async def handle_command(ctx: Context, name: str, args: list[str]) -> None: @@ -132,9 +100,7 @@ class Explicit(Exception): self.msg = msg -def lock_for(ctx: Context) -> asyncio.Lock: - # noinspection PyTypeChecker - guild: discord.Guild = ctx.guild +def lock_for(guild: discord.Guild) -> asyncio.Lock: if guild is None: raise Explicit('not in a guild') if guild in locks: @@ -223,13 +189,13 @@ class YTAudio(discord.AudioSource): } @classmethod - def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio': + async def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio': return YTAudio( respawn['url'], respawn['origin'], respawn['description'], respawn['options'], - guild.get_member(respawn['rby']), + guild.get_member(respawn['rby']) or await guild.fetch_member(respawn['rby']), respawn['already_read'] ) @@ -248,17 +214,35 @@ FILL = b'\x00' * discord.opus.Encoder.FRAME_SIZE class QueueAudio(discord.AudioSource): - def __init__(self, guild: discord.Guild): + def __init__(self, guild: discord.Guild, respawned: list[YTAudio]): self.queue: deque[YTAudio] = deque() + self.queue.extend(respawned) self.guild = guild - for audio_respawn in queue_db.get(self.guild.id, []): - try: - self.queue.append(YTAudio.respawn(self.guild, audio_respawn)) - except Exception as e: - print('respawn failed', e) - def save(self): - queue_db.set_nowait(self.guild.id, [audio.hybernate() for audio in self.queue]) + @staticmethod + async def respawned(guild: discord.Guild) -> list[YTAudio]: + respawned = [] + try: + for audio_respawn in queue_db.get(guild.id, []): + try: + respawned.append(await YTAudio.respawn(guild, audio_respawn)) + except Exception as e: + print('audio respawn failed', e) + raise + except Exception as e: + print('queue respawn failed', e) + return respawned + + @classmethod + async def create(cls, guild: discord.Guild): + return cls(guild, await QueueAudio.respawned(guild)) + + async def save(self): + hybernated = [] + for audio in list(self.queue): + await asyncio.sleep(0.01) + hybernated.append(audio.hybernate()) + queue_db.set_nowait(self.guild.id, hybernated) def append(self, audio: YTAudio): self.queue.append(audio) @@ -299,7 +283,7 @@ class QueueAudio(discord.AudioSource): def format(self) -> str: stream = StringIO() - for i, audio in enumerate(self.queue): + for i, audio in enumerate(list(self.queue)): stream.write(f'`[{i}]` {audio.description}\n') return stream.getvalue() @@ -313,8 +297,8 @@ class QueueAudio(discord.AudioSource): class MainAudio(discord.PCMVolumeTransformer): - def __init__(self, guild: discord.Guild, volume: float): - self.queue = QueueAudio(guild) + def __init__(self, queue: QueueAudio, volume: float): + self.queue = queue super().__init__(self.queue, volume=volume) async def set(self, volume: float, member: discord.Member): @@ -460,48 +444,66 @@ async def play(ctx: Context, args: list[str]) -> None: `play url [- effects] ...args` '''.strip()) case _: - async with lock_for(ctx): - queue = await queue_for(ctx) + async with lock_for(ctx.guild): + queue = await queue_for(ctx, create=True) async for audio in yt_audios(ctx, args): queue.append(audio) await ctx.reply('done') -async def vc_main_for(ctx: Context) -> tuple[discord.VoiceClient, MainAudio]: +async def raw_vc_for(ctx: Context) -> discord.VoiceClient: if ctx.guild is None: raise Explicit("not in a guild") vc: discord.VoiceProtocol = ctx.guild.voice_client - if vc is None: + if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected(): vs: discord.VoiceState = ctx.member.voice if vs is None: raise Explicit("not connected") vch: discord.VoiceChannel = vs.channel if vch is None: raise Explicit("not connected") - vc: discord.VoiceProtocol = await vch.connect() + try: + vc: discord.VoiceProtocol = await vch.connect() + except discord.ClientException: + await ctx.guild.fetch_channels() + raise Explicit("try again later") assert isinstance(vc, discord.VoiceClient) - if ctx.guild in mainasrcs: - source = mainasrcs[ctx.guild] - else: - await ctx.reply('respawning queue') - source = mainasrcs.setdefault(ctx.guild, MainAudio(ctx.guild, volume=volume_db.get(ctx.guild.id, 0.2))) - if vc.source != source: - vc.play(source) - return vc, source - - -async def vc_for(ctx: Context) -> discord.VoiceClient: - vc, source = await vc_main_for(ctx) return vc -async def main_for(ctx: Context) -> MainAudio: - vc, source = await vc_main_for(ctx) +async def main_for_raw_vc(vc: discord.VoiceClient, *, create: bool) -> MainAudio: + if vc.guild in mainasrcs: + source = mainasrcs[vc.guild] + else: + if create: + source = mainasrcs.setdefault( + vc.guild, + MainAudio(await QueueAudio.create(vc.guild), volume=volume_db.get(vc.guild.id, 0.2)) + ) + else: + raise Explicit("not playing") + if vc.source != source: + vc.play(source) return source -async def queue_for(ctx: Context) -> QueueAudio: - return (await main_for(ctx)).queue +async def vc_main_for(ctx: Context, *, create: bool) -> tuple[discord.VoiceClient, MainAudio]: + vc = await raw_vc_for(ctx) + return vc, await main_for_raw_vc(vc, create=create) + + +async def vc_for(ctx: Context, *, create: bool) -> discord.VoiceClient: + vc, source = await vc_main_for(ctx, create=create) + return vc + + +async def main_for(ctx: Context, *, create: bool) -> MainAudio: + vc, source = await vc_main_for(ctx, create=create) + return source + + +async def queue_for(ctx: Context, *, create: bool) -> QueueAudio: + return (await main_for(ctx, create=create)).queue @at('commands', 'skip') @@ -512,15 +514,15 @@ async def skip(ctx: Context, args: list[str]) -> None: `skip [first] [last]` '''.strip()) case []: - queue = await queue_for(ctx) + queue = await queue_for(ctx, create=False) queue.skip_at(0, ctx.member) case [pos]: pos = int(pos) - queue = await queue_for(ctx) + queue = await queue_for(ctx, create=False) queue.skip_at(pos, ctx.member) case [pos0, pos1]: pos0, pos1 = int(pos0), int(pos1) - queue = await queue_for(ctx) + queue = await queue_for(ctx, create=False) for i in range(pos0, pos1 + 1): if not queue.skip_at(pos0, ctx.member): pos0 += 1 @@ -532,9 +534,13 @@ async def queue_(ctx: Context, args: list[str]) -> None: case ['help']: await ctx.reply('current queue') case []: - await ctx.long((await queue_for(ctx)).format().strip() or 'no queue') + await ctx.long((await queue_for(ctx, create=False)).format().strip() or 'no queue') case ['clear']: - (await queue_for(ctx)).clear(ctx.member) + (await queue_for(ctx, create=False)).clear(ctx.member) + await ctx.reply('done') + case ['resume']: + async with lock_for(ctx.guild): + await queue_for(ctx, create=True) @at('commands', 'volume') @@ -544,28 +550,22 @@ async def volume_(ctx: Context, args: list[str]) -> None: await ctx.reply('`volume 0.2`') case [volume]: volume = float(volume) - await (await main_for(ctx)).set(volume, ctx.member) + await (await main_for(ctx, create=False)).set(volume, ctx.member) @at('commands', 'pause') async def pause(ctx: Context, _args: list[str]) -> None: - (await vc_for(ctx)).pause() + vc = await vc_for(ctx, create=False) + vc.pause() @at('commands', 'resume') async def resume(ctx: Context, _args: list[str]) -> None: - (await vc_for(ctx)).resume() + vc = await vc_for(ctx, create=False) + vc.resume() -@client.event -async def on_message(message: discord.Message) -> None: - if message.author.bot: - return - content: str = message.content - if not content.startswith(prefix): - return - content = content.removeprefix(prefix) - args = shlex.split(content) +async def handle_args(message: discord.Message, args: list[str]): match args: case []: return @@ -579,44 +579,56 @@ async def on_message(message: discord.Message) -> None: await ctx.reply(e.msg) +@client.event +async def on_message(message: discord.Message) -> None: + if message.author.bot: + return + content: str = message.content + if not content.startswith(prefix): + return + content = content.removeprefix(prefix) + args = shlex.split(content) + await handle_args(message, args) + + async def save_queues(): + for mainasrc in list(mainasrcs.values()): + await asyncio.sleep(0.01) + await mainasrc.queue.save() + + +async def save_vcs(): + if vcs_restored: + vcs = [] + vc: discord.VoiceClient + for vc in list(client.voice_clients): + await asyncio.sleep(0.01) + if vc.is_playing(): + vcs.append((vc.guild.id, vc.channel.id, vc.is_paused())) + queue_db.set_nowait('vcs', vcs) + + +async def save_commit(): + await queue_db.set('commit', time.time()) + + +async def save_job(): while True: await asyncio.sleep(1) - for mainasrc in list(mainasrcs.values()): - await asyncio.sleep(1) - mainasrc.queue.save() - - -benchmarks: dict[str, float] = {} -_t = time.perf_counter() - - -class Benchmark: - def __init__(self, benchmark: str): - self.benchmark = benchmark - - def __enter__(self): - self.t = time.perf_counter() - - def __exit__(self, exc_type, exc_val, exc_tb): - benchmarks.setdefault(self.benchmark, 0.0) - benchmarks[self.benchmark] += time.perf_counter() - self.t - - -async def monitor(): - while True: - await asyncio.sleep(60) - dt = time.perf_counter() - _t - for benchmark, count in benchmarks.items(): - print(benchmark, '=', count / max(dt, .00001)) + with Benchmark('SVQ'): + await save_queues() + with Benchmark('SVV'): + await save_vcs() + with Benchmark('SVC'): + await save_commit() async def main(): async with volume_db, queue_db: await start_app(get_app(client)) await client.login(token) - asyncio.get_event_loop().create_task(save_queues()) - asyncio.get_event_loop().create_task(monitor()) + loop.create_task(save_job()) + loop.create_task(monitor()) await client.connect()