context.py + respawn + restore

This commit is contained in:
AF 2021-11-29 12:35:21 +03:00
parent cc9191b2d3
commit 84de31ce87
2 changed files with 252 additions and 143 deletions

97
v6d3music/context.py Normal file
View File

@ -0,0 +1,97 @@
import asyncio
import time
from io import StringIO
from typing import Union, Optional, Callable, Awaitable
# noinspection PyPackageRequirements
import discord
usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User]
class Context:
def __init__(self, message: discord.Message):
self.message: discord.Message = message
self.channel: discord.abc.Messageable = message.channel
self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel
self.author: usertype = message.author
self.content: str = message.content
self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None
self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild
async def reply(self, content=None, **kwargs) -> discord.Message:
return await self.message.reply(content, mention_author=False, **kwargs)
async def long(self, s: str):
resio = StringIO(s)
res = ''
for line in resio:
if len(res) + len(line) < 2000:
res += line
else:
await self.reply(res)
res = line
if res:
await self.reply(res)
ESCAPED = '`_*\'"\\'
def escape(s: str):
res = StringIO()
for c in s:
if c in ESCAPED:
c = '\\' + c
res.write(c)
return res.getvalue()
buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {}
def at(bucket: str, name: str):
def wrap(f: Callable[[Context, list[str]], Awaitable[None]]):
buckets.setdefault(bucket, {})[name] = f
return f
return wrap
class Implicit(Exception):
pass
def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]:
try:
return buckets[bucket][name]
except KeyError:
raise Implicit
benchmarks: dict[str, dict[str, float]] = {}
_t = time.perf_counter()
class Benchmark:
def __init__(self, benchmark: str):
self.benchmark = benchmark
def __enter__(self):
self.t = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
d = (time.perf_counter() - self.t)
benchmarks.setdefault(self.benchmark, {'integral': 0.0, 'max': 0.0})
benchmarks[self.benchmark]['integral'] += d
benchmarks[self.benchmark]['max'] = max(benchmarks[self.benchmark]['max'], d)
async def monitor():
while True:
await asyncio.sleep(10)
dt = time.perf_counter() - _t
print('Benchmarks:')
for benchmark, metrics in benchmarks.items():
print(benchmark, '=', metrics['integral'] / max(dt, .00001), ':', metrics['max'])

View File

@ -7,7 +7,7 @@ import subprocess
import time
from collections import deque
from io import StringIO
from typing import Callable, Awaitable, Union, Optional, AsyncIterable, Any
from typing import Optional, AsyncIterable, Any
# noinspection PyPackageRequirements
import discord
@ -19,6 +19,7 @@ from v6d1tokens.client import request_token
from v6d3music.app import get_app
from v6d3music.config import prefix
from v6d3music.context import Context, of, at, escape, Implicit, monitor, Benchmark
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@ -39,45 +40,33 @@ myroot = root / 'v6d3music'
myroot.mkdir(exist_ok=True)
volume_db = Db(myroot / 'volume.db', kvrequest_type=KVJson)
queue_db = Db(myroot / 'queue.db', kvrequest_type=KVJson)
ESCAPED = '`_*\'"\\'
vcs_restored = False
def escape(s: str):
res = StringIO()
for c in s:
if c in ESCAPED:
c = '\\' + c
res.write(c)
return res.getvalue()
usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User]
class Context:
def __init__(self, message: discord.Message):
self.message: discord.Message = message
self.channel: discord.abc.Messageable = message.channel
self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel
self.author: usertype = message.author
self.content: str = message.content
self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None
self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild
async def reply(self, content=None, **kwargs) -> discord.Message:
return await self.message.reply(content, mention_author=False, **kwargs)
async def long(self, s: str):
resio = StringIO(s)
res = ''
for line in resio:
if len(res) + len(line) < 2000:
res += line
async def restore_vcs():
global vcs_restored
vcs: list[tuple[int, int, bool]] = queue_db.get('vcs', [])
try:
for vcgid, vccid, vc_is_paused in vcs:
try:
guild: discord.Guild = await client.fetch_guild(vcgid)
async with lock_for(guild):
channels = await guild.fetch_channels()
channel: discord.VoiceChannel
channel, = [ch for ch in channels if ch.id == vccid]
vp: discord.VoiceProtocol = await channel.connect()
assert isinstance(vp, discord.VoiceClient)
vc = vp
await main_for_raw_vc(vc, create=True)
if vc_is_paused:
vc.pause()
except Exception as e:
print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e)
else:
await self.reply(res)
res = line
if res:
await self.reply(res)
print(f'vc restored {vcgid} {vccid}')
finally:
vcs_restored = True
@client.event
@ -86,29 +75,8 @@ async def on_ready():
await client.change_presence(activity=discord.Game(
name='феноменально',
))
buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {}
def at(bucket: str, name: str):
def wrap(f: Callable[[Context, list[str]], Awaitable[None]]):
buckets.setdefault(bucket, {})[name] = f
return f
return wrap
class Implicit(Exception):
pass
def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]:
try:
return buckets[bucket][name]
except KeyError:
raise Implicit
if not vcs_restored:
await restore_vcs()
async def handle_command(ctx: Context, name: str, args: list[str]) -> None:
@ -132,9 +100,7 @@ class Explicit(Exception):
self.msg = msg
def lock_for(ctx: Context) -> asyncio.Lock:
# noinspection PyTypeChecker
guild: discord.Guild = ctx.guild
def lock_for(guild: discord.Guild) -> asyncio.Lock:
if guild is None:
raise Explicit('not in a guild')
if guild in locks:
@ -223,13 +189,13 @@ class YTAudio(discord.AudioSource):
}
@classmethod
def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio':
async def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio':
return YTAudio(
respawn['url'],
respawn['origin'],
respawn['description'],
respawn['options'],
guild.get_member(respawn['rby']),
guild.get_member(respawn['rby']) or await guild.fetch_member(respawn['rby']),
respawn['already_read']
)
@ -248,17 +214,35 @@ FILL = b'\x00' * discord.opus.Encoder.FRAME_SIZE
class QueueAudio(discord.AudioSource):
def __init__(self, guild: discord.Guild):
def __init__(self, guild: discord.Guild, respawned: list[YTAudio]):
self.queue: deque[YTAudio] = deque()
self.queue.extend(respawned)
self.guild = guild
for audio_respawn in queue_db.get(self.guild.id, []):
try:
self.queue.append(YTAudio.respawn(self.guild, audio_respawn))
except Exception as e:
print('respawn failed', e)
def save(self):
queue_db.set_nowait(self.guild.id, [audio.hybernate() for audio in self.queue])
@staticmethod
async def respawned(guild: discord.Guild) -> list[YTAudio]:
respawned = []
try:
for audio_respawn in queue_db.get(guild.id, []):
try:
respawned.append(await YTAudio.respawn(guild, audio_respawn))
except Exception as e:
print('audio respawn failed', e)
raise
except Exception as e:
print('queue respawn failed', e)
return respawned
@classmethod
async def create(cls, guild: discord.Guild):
return cls(guild, await QueueAudio.respawned(guild))
async def save(self):
hybernated = []
for audio in list(self.queue):
await asyncio.sleep(0.01)
hybernated.append(audio.hybernate())
queue_db.set_nowait(self.guild.id, hybernated)
def append(self, audio: YTAudio):
self.queue.append(audio)
@ -299,7 +283,7 @@ class QueueAudio(discord.AudioSource):
def format(self) -> str:
stream = StringIO()
for i, audio in enumerate(self.queue):
for i, audio in enumerate(list(self.queue)):
stream.write(f'`[{i}]` {audio.description}\n')
return stream.getvalue()
@ -313,8 +297,8 @@ class QueueAudio(discord.AudioSource):
class MainAudio(discord.PCMVolumeTransformer):
def __init__(self, guild: discord.Guild, volume: float):
self.queue = QueueAudio(guild)
def __init__(self, queue: QueueAudio, volume: float):
self.queue = queue
super().__init__(self.queue, volume=volume)
async def set(self, volume: float, member: discord.Member):
@ -460,48 +444,66 @@ async def play(ctx: Context, args: list[str]) -> None:
`play url [- effects] ...args`
'''.strip())
case _:
async with lock_for(ctx):
queue = await queue_for(ctx)
async with lock_for(ctx.guild):
queue = await queue_for(ctx, create=True)
async for audio in yt_audios(ctx, args):
queue.append(audio)
await ctx.reply('done')
async def vc_main_for(ctx: Context) -> tuple[discord.VoiceClient, MainAudio]:
async def raw_vc_for(ctx: Context) -> discord.VoiceClient:
if ctx.guild is None:
raise Explicit("not in a guild")
vc: discord.VoiceProtocol = ctx.guild.voice_client
if vc is None:
if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected():
vs: discord.VoiceState = ctx.member.voice
if vs is None:
raise Explicit("not connected")
vch: discord.VoiceChannel = vs.channel
if vch is None:
raise Explicit("not connected")
try:
vc: discord.VoiceProtocol = await vch.connect()
except discord.ClientException:
await ctx.guild.fetch_channels()
raise Explicit("try again later")
assert isinstance(vc, discord.VoiceClient)
if ctx.guild in mainasrcs:
source = mainasrcs[ctx.guild]
else:
await ctx.reply('respawning queue')
source = mainasrcs.setdefault(ctx.guild, MainAudio(ctx.guild, volume=volume_db.get(ctx.guild.id, 0.2)))
if vc.source != source:
vc.play(source)
return vc, source
async def vc_for(ctx: Context) -> discord.VoiceClient:
vc, source = await vc_main_for(ctx)
return vc
async def main_for(ctx: Context) -> MainAudio:
vc, source = await vc_main_for(ctx)
async def main_for_raw_vc(vc: discord.VoiceClient, *, create: bool) -> MainAudio:
if vc.guild in mainasrcs:
source = mainasrcs[vc.guild]
else:
if create:
source = mainasrcs.setdefault(
vc.guild,
MainAudio(await QueueAudio.create(vc.guild), volume=volume_db.get(vc.guild.id, 0.2))
)
else:
raise Explicit("not playing")
if vc.source != source:
vc.play(source)
return source
async def queue_for(ctx: Context) -> QueueAudio:
return (await main_for(ctx)).queue
async def vc_main_for(ctx: Context, *, create: bool) -> tuple[discord.VoiceClient, MainAudio]:
vc = await raw_vc_for(ctx)
return vc, await main_for_raw_vc(vc, create=create)
async def vc_for(ctx: Context, *, create: bool) -> discord.VoiceClient:
vc, source = await vc_main_for(ctx, create=create)
return vc
async def main_for(ctx: Context, *, create: bool) -> MainAudio:
vc, source = await vc_main_for(ctx, create=create)
return source
async def queue_for(ctx: Context, *, create: bool) -> QueueAudio:
return (await main_for(ctx, create=create)).queue
@at('commands', 'skip')
@ -512,15 +514,15 @@ async def skip(ctx: Context, args: list[str]) -> None:
`skip [first] [last]`
'''.strip())
case []:
queue = await queue_for(ctx)
queue = await queue_for(ctx, create=False)
queue.skip_at(0, ctx.member)
case [pos]:
pos = int(pos)
queue = await queue_for(ctx)
queue = await queue_for(ctx, create=False)
queue.skip_at(pos, ctx.member)
case [pos0, pos1]:
pos0, pos1 = int(pos0), int(pos1)
queue = await queue_for(ctx)
queue = await queue_for(ctx, create=False)
for i in range(pos0, pos1 + 1):
if not queue.skip_at(pos0, ctx.member):
pos0 += 1
@ -532,9 +534,13 @@ async def queue_(ctx: Context, args: list[str]) -> None:
case ['help']:
await ctx.reply('current queue')
case []:
await ctx.long((await queue_for(ctx)).format().strip() or 'no queue')
await ctx.long((await queue_for(ctx, create=False)).format().strip() or 'no queue')
case ['clear']:
(await queue_for(ctx)).clear(ctx.member)
(await queue_for(ctx, create=False)).clear(ctx.member)
await ctx.reply('done')
case ['resume']:
async with lock_for(ctx.guild):
await queue_for(ctx, create=True)
@at('commands', 'volume')
@ -544,28 +550,22 @@ async def volume_(ctx: Context, args: list[str]) -> None:
await ctx.reply('`volume 0.2`')
case [volume]:
volume = float(volume)
await (await main_for(ctx)).set(volume, ctx.member)
await (await main_for(ctx, create=False)).set(volume, ctx.member)
@at('commands', 'pause')
async def pause(ctx: Context, _args: list[str]) -> None:
(await vc_for(ctx)).pause()
vc = await vc_for(ctx, create=False)
vc.pause()
@at('commands', 'resume')
async def resume(ctx: Context, _args: list[str]) -> None:
(await vc_for(ctx)).resume()
vc = await vc_for(ctx, create=False)
vc.resume()
@client.event
async def on_message(message: discord.Message) -> None:
if message.author.bot:
return
content: str = message.content
if not content.startswith(prefix):
return
content = content.removeprefix(prefix)
args = shlex.split(content)
async def handle_args(message: discord.Message, args: list[str]):
match args:
case []:
return
@ -579,44 +579,56 @@ async def on_message(message: discord.Message) -> None:
await ctx.reply(e.msg)
@client.event
async def on_message(message: discord.Message) -> None:
if message.author.bot:
return
content: str = message.content
if not content.startswith(prefix):
return
content = content.removeprefix(prefix)
args = shlex.split(content)
await handle_args(message, args)
async def save_queues():
while True:
await asyncio.sleep(1)
for mainasrc in list(mainasrcs.values()):
await asyncio.sleep(1)
mainasrc.queue.save()
await asyncio.sleep(0.01)
await mainasrc.queue.save()
benchmarks: dict[str, float] = {}
_t = time.perf_counter()
async def save_vcs():
if vcs_restored:
vcs = []
vc: discord.VoiceClient
for vc in list(client.voice_clients):
await asyncio.sleep(0.01)
if vc.is_playing():
vcs.append((vc.guild.id, vc.channel.id, vc.is_paused()))
queue_db.set_nowait('vcs', vcs)
class Benchmark:
def __init__(self, benchmark: str):
self.benchmark = benchmark
def __enter__(self):
self.t = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
benchmarks.setdefault(self.benchmark, 0.0)
benchmarks[self.benchmark] += time.perf_counter() - self.t
async def save_commit():
await queue_db.set('commit', time.time())
async def monitor():
async def save_job():
while True:
await asyncio.sleep(60)
dt = time.perf_counter() - _t
for benchmark, count in benchmarks.items():
print(benchmark, '=', count / max(dt, .00001))
await asyncio.sleep(1)
with Benchmark('SVQ'):
await save_queues()
with Benchmark('SVV'):
await save_vcs()
with Benchmark('SVC'):
await save_commit()
async def main():
async with volume_db, queue_db:
await start_app(get_app(client))
await client.login(token)
asyncio.get_event_loop().create_task(save_queues())
asyncio.get_event_loop().create_task(monitor())
loop.create_task(save_job())
loop.create_task(monitor())
await client.connect()