context.py + respawn + restore

This commit is contained in:
AF 2021-11-29 12:35:21 +03:00
parent cc9191b2d3
commit 84de31ce87
2 changed files with 252 additions and 143 deletions

97
v6d3music/context.py Normal file
View File

@ -0,0 +1,97 @@
import asyncio
import time
from io import StringIO
from typing import Union, Optional, Callable, Awaitable
# noinspection PyPackageRequirements
import discord
usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User]
class Context:
def __init__(self, message: discord.Message):
self.message: discord.Message = message
self.channel: discord.abc.Messageable = message.channel
self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel
self.author: usertype = message.author
self.content: str = message.content
self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None
self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild
async def reply(self, content=None, **kwargs) -> discord.Message:
return await self.message.reply(content, mention_author=False, **kwargs)
async def long(self, s: str):
resio = StringIO(s)
res = ''
for line in resio:
if len(res) + len(line) < 2000:
res += line
else:
await self.reply(res)
res = line
if res:
await self.reply(res)
ESCAPED = '`_*\'"\\'
def escape(s: str):
res = StringIO()
for c in s:
if c in ESCAPED:
c = '\\' + c
res.write(c)
return res.getvalue()
buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {}
def at(bucket: str, name: str):
def wrap(f: Callable[[Context, list[str]], Awaitable[None]]):
buckets.setdefault(bucket, {})[name] = f
return f
return wrap
class Implicit(Exception):
pass
def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]:
try:
return buckets[bucket][name]
except KeyError:
raise Implicit
benchmarks: dict[str, dict[str, float]] = {}
_t = time.perf_counter()
class Benchmark:
def __init__(self, benchmark: str):
self.benchmark = benchmark
def __enter__(self):
self.t = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
d = (time.perf_counter() - self.t)
benchmarks.setdefault(self.benchmark, {'integral': 0.0, 'max': 0.0})
benchmarks[self.benchmark]['integral'] += d
benchmarks[self.benchmark]['max'] = max(benchmarks[self.benchmark]['max'], d)
async def monitor():
while True:
await asyncio.sleep(10)
dt = time.perf_counter() - _t
print('Benchmarks:')
for benchmark, metrics in benchmarks.items():
print(benchmark, '=', metrics['integral'] / max(dt, .00001), ':', metrics['max'])

View File

@ -7,7 +7,7 @@ import subprocess
import time import time
from collections import deque from collections import deque
from io import StringIO from io import StringIO
from typing import Callable, Awaitable, Union, Optional, AsyncIterable, Any from typing import Optional, AsyncIterable, Any
# noinspection PyPackageRequirements # noinspection PyPackageRequirements
import discord import discord
@ -19,6 +19,7 @@ from v6d1tokens.client import request_token
from v6d3music.app import get_app from v6d3music.app import get_app
from v6d3music.config import prefix from v6d3music.config import prefix
from v6d3music.context import Context, of, at, escape, Implicit, monitor, Benchmark
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
@ -39,45 +40,33 @@ myroot = root / 'v6d3music'
myroot.mkdir(exist_ok=True) myroot.mkdir(exist_ok=True)
volume_db = Db(myroot / 'volume.db', kvrequest_type=KVJson) volume_db = Db(myroot / 'volume.db', kvrequest_type=KVJson)
queue_db = Db(myroot / 'queue.db', kvrequest_type=KVJson) queue_db = Db(myroot / 'queue.db', kvrequest_type=KVJson)
ESCAPED = '`_*\'"\\'
vcs_restored = False
def escape(s: str): async def restore_vcs():
res = StringIO() global vcs_restored
for c in s: vcs: list[tuple[int, int, bool]] = queue_db.get('vcs', [])
if c in ESCAPED: try:
c = '\\' + c for vcgid, vccid, vc_is_paused in vcs:
res.write(c) try:
return res.getvalue() guild: discord.Guild = await client.fetch_guild(vcgid)
async with lock_for(guild):
channels = await guild.fetch_channels()
usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User] channel: discord.VoiceChannel
channel, = [ch for ch in channels if ch.id == vccid]
vp: discord.VoiceProtocol = await channel.connect()
class Context: assert isinstance(vp, discord.VoiceClient)
def __init__(self, message: discord.Message): vc = vp
self.message: discord.Message = message await main_for_raw_vc(vc, create=True)
self.channel: discord.abc.Messageable = message.channel if vc_is_paused:
self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel vc.pause()
self.author: usertype = message.author except Exception as e:
self.content: str = message.content print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e)
self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None
self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild
async def reply(self, content=None, **kwargs) -> discord.Message:
return await self.message.reply(content, mention_author=False, **kwargs)
async def long(self, s: str):
resio = StringIO(s)
res = ''
for line in resio:
if len(res) + len(line) < 2000:
res += line
else: else:
await self.reply(res) print(f'vc restored {vcgid} {vccid}')
res = line finally:
if res: vcs_restored = True
await self.reply(res)
@client.event @client.event
@ -86,29 +75,8 @@ async def on_ready():
await client.change_presence(activity=discord.Game( await client.change_presence(activity=discord.Game(
name='феноменально', name='феноменально',
)) ))
if not vcs_restored:
await restore_vcs()
buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {}
def at(bucket: str, name: str):
def wrap(f: Callable[[Context, list[str]], Awaitable[None]]):
buckets.setdefault(bucket, {})[name] = f
return f
return wrap
class Implicit(Exception):
pass
def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]:
try:
return buckets[bucket][name]
except KeyError:
raise Implicit
async def handle_command(ctx: Context, name: str, args: list[str]) -> None: async def handle_command(ctx: Context, name: str, args: list[str]) -> None:
@ -132,9 +100,7 @@ class Explicit(Exception):
self.msg = msg self.msg = msg
def lock_for(ctx: Context) -> asyncio.Lock: def lock_for(guild: discord.Guild) -> asyncio.Lock:
# noinspection PyTypeChecker
guild: discord.Guild = ctx.guild
if guild is None: if guild is None:
raise Explicit('not in a guild') raise Explicit('not in a guild')
if guild in locks: if guild in locks:
@ -223,13 +189,13 @@ class YTAudio(discord.AudioSource):
} }
@classmethod @classmethod
def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio': async def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio':
return YTAudio( return YTAudio(
respawn['url'], respawn['url'],
respawn['origin'], respawn['origin'],
respawn['description'], respawn['description'],
respawn['options'], respawn['options'],
guild.get_member(respawn['rby']), guild.get_member(respawn['rby']) or await guild.fetch_member(respawn['rby']),
respawn['already_read'] respawn['already_read']
) )
@ -248,17 +214,35 @@ FILL = b'\x00' * discord.opus.Encoder.FRAME_SIZE
class QueueAudio(discord.AudioSource): class QueueAudio(discord.AudioSource):
def __init__(self, guild: discord.Guild): def __init__(self, guild: discord.Guild, respawned: list[YTAudio]):
self.queue: deque[YTAudio] = deque() self.queue: deque[YTAudio] = deque()
self.queue.extend(respawned)
self.guild = guild self.guild = guild
for audio_respawn in queue_db.get(self.guild.id, []):
try:
self.queue.append(YTAudio.respawn(self.guild, audio_respawn))
except Exception as e:
print('respawn failed', e)
def save(self): @staticmethod
queue_db.set_nowait(self.guild.id, [audio.hybernate() for audio in self.queue]) async def respawned(guild: discord.Guild) -> list[YTAudio]:
respawned = []
try:
for audio_respawn in queue_db.get(guild.id, []):
try:
respawned.append(await YTAudio.respawn(guild, audio_respawn))
except Exception as e:
print('audio respawn failed', e)
raise
except Exception as e:
print('queue respawn failed', e)
return respawned
@classmethod
async def create(cls, guild: discord.Guild):
return cls(guild, await QueueAudio.respawned(guild))
async def save(self):
hybernated = []
for audio in list(self.queue):
await asyncio.sleep(0.01)
hybernated.append(audio.hybernate())
queue_db.set_nowait(self.guild.id, hybernated)
def append(self, audio: YTAudio): def append(self, audio: YTAudio):
self.queue.append(audio) self.queue.append(audio)
@ -299,7 +283,7 @@ class QueueAudio(discord.AudioSource):
def format(self) -> str: def format(self) -> str:
stream = StringIO() stream = StringIO()
for i, audio in enumerate(self.queue): for i, audio in enumerate(list(self.queue)):
stream.write(f'`[{i}]` {audio.description}\n') stream.write(f'`[{i}]` {audio.description}\n')
return stream.getvalue() return stream.getvalue()
@ -313,8 +297,8 @@ class QueueAudio(discord.AudioSource):
class MainAudio(discord.PCMVolumeTransformer): class MainAudio(discord.PCMVolumeTransformer):
def __init__(self, guild: discord.Guild, volume: float): def __init__(self, queue: QueueAudio, volume: float):
self.queue = QueueAudio(guild) self.queue = queue
super().__init__(self.queue, volume=volume) super().__init__(self.queue, volume=volume)
async def set(self, volume: float, member: discord.Member): async def set(self, volume: float, member: discord.Member):
@ -460,48 +444,66 @@ async def play(ctx: Context, args: list[str]) -> None:
`play url [- effects] ...args` `play url [- effects] ...args`
'''.strip()) '''.strip())
case _: case _:
async with lock_for(ctx): async with lock_for(ctx.guild):
queue = await queue_for(ctx) queue = await queue_for(ctx, create=True)
async for audio in yt_audios(ctx, args): async for audio in yt_audios(ctx, args):
queue.append(audio) queue.append(audio)
await ctx.reply('done') await ctx.reply('done')
async def vc_main_for(ctx: Context) -> tuple[discord.VoiceClient, MainAudio]: async def raw_vc_for(ctx: Context) -> discord.VoiceClient:
if ctx.guild is None: if ctx.guild is None:
raise Explicit("not in a guild") raise Explicit("not in a guild")
vc: discord.VoiceProtocol = ctx.guild.voice_client vc: discord.VoiceProtocol = ctx.guild.voice_client
if vc is None: if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected():
vs: discord.VoiceState = ctx.member.voice vs: discord.VoiceState = ctx.member.voice
if vs is None: if vs is None:
raise Explicit("not connected") raise Explicit("not connected")
vch: discord.VoiceChannel = vs.channel vch: discord.VoiceChannel = vs.channel
if vch is None: if vch is None:
raise Explicit("not connected") raise Explicit("not connected")
vc: discord.VoiceProtocol = await vch.connect() try:
vc: discord.VoiceProtocol = await vch.connect()
except discord.ClientException:
await ctx.guild.fetch_channels()
raise Explicit("try again later")
assert isinstance(vc, discord.VoiceClient) assert isinstance(vc, discord.VoiceClient)
if ctx.guild in mainasrcs:
source = mainasrcs[ctx.guild]
else:
await ctx.reply('respawning queue')
source = mainasrcs.setdefault(ctx.guild, MainAudio(ctx.guild, volume=volume_db.get(ctx.guild.id, 0.2)))
if vc.source != source:
vc.play(source)
return vc, source
async def vc_for(ctx: Context) -> discord.VoiceClient:
vc, source = await vc_main_for(ctx)
return vc return vc
async def main_for(ctx: Context) -> MainAudio: async def main_for_raw_vc(vc: discord.VoiceClient, *, create: bool) -> MainAudio:
vc, source = await vc_main_for(ctx) if vc.guild in mainasrcs:
source = mainasrcs[vc.guild]
else:
if create:
source = mainasrcs.setdefault(
vc.guild,
MainAudio(await QueueAudio.create(vc.guild), volume=volume_db.get(vc.guild.id, 0.2))
)
else:
raise Explicit("not playing")
if vc.source != source:
vc.play(source)
return source return source
async def queue_for(ctx: Context) -> QueueAudio: async def vc_main_for(ctx: Context, *, create: bool) -> tuple[discord.VoiceClient, MainAudio]:
return (await main_for(ctx)).queue vc = await raw_vc_for(ctx)
return vc, await main_for_raw_vc(vc, create=create)
async def vc_for(ctx: Context, *, create: bool) -> discord.VoiceClient:
vc, source = await vc_main_for(ctx, create=create)
return vc
async def main_for(ctx: Context, *, create: bool) -> MainAudio:
vc, source = await vc_main_for(ctx, create=create)
return source
async def queue_for(ctx: Context, *, create: bool) -> QueueAudio:
return (await main_for(ctx, create=create)).queue
@at('commands', 'skip') @at('commands', 'skip')
@ -512,15 +514,15 @@ async def skip(ctx: Context, args: list[str]) -> None:
`skip [first] [last]` `skip [first] [last]`
'''.strip()) '''.strip())
case []: case []:
queue = await queue_for(ctx) queue = await queue_for(ctx, create=False)
queue.skip_at(0, ctx.member) queue.skip_at(0, ctx.member)
case [pos]: case [pos]:
pos = int(pos) pos = int(pos)
queue = await queue_for(ctx) queue = await queue_for(ctx, create=False)
queue.skip_at(pos, ctx.member) queue.skip_at(pos, ctx.member)
case [pos0, pos1]: case [pos0, pos1]:
pos0, pos1 = int(pos0), int(pos1) pos0, pos1 = int(pos0), int(pos1)
queue = await queue_for(ctx) queue = await queue_for(ctx, create=False)
for i in range(pos0, pos1 + 1): for i in range(pos0, pos1 + 1):
if not queue.skip_at(pos0, ctx.member): if not queue.skip_at(pos0, ctx.member):
pos0 += 1 pos0 += 1
@ -532,9 +534,13 @@ async def queue_(ctx: Context, args: list[str]) -> None:
case ['help']: case ['help']:
await ctx.reply('current queue') await ctx.reply('current queue')
case []: case []:
await ctx.long((await queue_for(ctx)).format().strip() or 'no queue') await ctx.long((await queue_for(ctx, create=False)).format().strip() or 'no queue')
case ['clear']: case ['clear']:
(await queue_for(ctx)).clear(ctx.member) (await queue_for(ctx, create=False)).clear(ctx.member)
await ctx.reply('done')
case ['resume']:
async with lock_for(ctx.guild):
await queue_for(ctx, create=True)
@at('commands', 'volume') @at('commands', 'volume')
@ -544,28 +550,22 @@ async def volume_(ctx: Context, args: list[str]) -> None:
await ctx.reply('`volume 0.2`') await ctx.reply('`volume 0.2`')
case [volume]: case [volume]:
volume = float(volume) volume = float(volume)
await (await main_for(ctx)).set(volume, ctx.member) await (await main_for(ctx, create=False)).set(volume, ctx.member)
@at('commands', 'pause') @at('commands', 'pause')
async def pause(ctx: Context, _args: list[str]) -> None: async def pause(ctx: Context, _args: list[str]) -> None:
(await vc_for(ctx)).pause() vc = await vc_for(ctx, create=False)
vc.pause()
@at('commands', 'resume') @at('commands', 'resume')
async def resume(ctx: Context, _args: list[str]) -> None: async def resume(ctx: Context, _args: list[str]) -> None:
(await vc_for(ctx)).resume() vc = await vc_for(ctx, create=False)
vc.resume()
@client.event async def handle_args(message: discord.Message, args: list[str]):
async def on_message(message: discord.Message) -> None:
if message.author.bot:
return
content: str = message.content
if not content.startswith(prefix):
return
content = content.removeprefix(prefix)
args = shlex.split(content)
match args: match args:
case []: case []:
return return
@ -579,44 +579,56 @@ async def on_message(message: discord.Message) -> None:
await ctx.reply(e.msg) await ctx.reply(e.msg)
@client.event
async def on_message(message: discord.Message) -> None:
if message.author.bot:
return
content: str = message.content
if not content.startswith(prefix):
return
content = content.removeprefix(prefix)
args = shlex.split(content)
await handle_args(message, args)
async def save_queues(): async def save_queues():
for mainasrc in list(mainasrcs.values()):
await asyncio.sleep(0.01)
await mainasrc.queue.save()
async def save_vcs():
if vcs_restored:
vcs = []
vc: discord.VoiceClient
for vc in list(client.voice_clients):
await asyncio.sleep(0.01)
if vc.is_playing():
vcs.append((vc.guild.id, vc.channel.id, vc.is_paused()))
queue_db.set_nowait('vcs', vcs)
async def save_commit():
await queue_db.set('commit', time.time())
async def save_job():
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
for mainasrc in list(mainasrcs.values()): with Benchmark('SVQ'):
await asyncio.sleep(1) await save_queues()
mainasrc.queue.save() with Benchmark('SVV'):
await save_vcs()
with Benchmark('SVC'):
benchmarks: dict[str, float] = {} await save_commit()
_t = time.perf_counter()
class Benchmark:
def __init__(self, benchmark: str):
self.benchmark = benchmark
def __enter__(self):
self.t = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
benchmarks.setdefault(self.benchmark, 0.0)
benchmarks[self.benchmark] += time.perf_counter() - self.t
async def monitor():
while True:
await asyncio.sleep(60)
dt = time.perf_counter() - _t
for benchmark, count in benchmarks.items():
print(benchmark, '=', count / max(dt, .00001))
async def main(): async def main():
async with volume_db, queue_db: async with volume_db, queue_db:
await start_app(get_app(client)) await start_app(get_app(client))
await client.login(token) await client.login(token)
asyncio.get_event_loop().create_task(save_queues()) loop.create_task(save_job())
asyncio.get_event_loop().create_task(monitor()) loop.create_task(monitor())
await client.connect() await client.connect()