context.py + respawn + restore
This commit is contained in:
parent
cc9191b2d3
commit
84de31ce87
97
v6d3music/context.py
Normal file
97
v6d3music/context.py
Normal file
@ -0,0 +1,97 @@
|
||||
import asyncio
|
||||
import time
|
||||
from io import StringIO
|
||||
from typing import Union, Optional, Callable, Awaitable
|
||||
|
||||
# noinspection PyPackageRequirements
|
||||
import discord
|
||||
|
||||
usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User]
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, message: discord.Message):
|
||||
self.message: discord.Message = message
|
||||
self.channel: discord.abc.Messageable = message.channel
|
||||
self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel
|
||||
self.author: usertype = message.author
|
||||
self.content: str = message.content
|
||||
self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None
|
||||
self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild
|
||||
|
||||
async def reply(self, content=None, **kwargs) -> discord.Message:
|
||||
return await self.message.reply(content, mention_author=False, **kwargs)
|
||||
|
||||
async def long(self, s: str):
|
||||
resio = StringIO(s)
|
||||
res = ''
|
||||
for line in resio:
|
||||
if len(res) + len(line) < 2000:
|
||||
res += line
|
||||
else:
|
||||
await self.reply(res)
|
||||
res = line
|
||||
if res:
|
||||
await self.reply(res)
|
||||
|
||||
|
||||
ESCAPED = '`_*\'"\\'
|
||||
|
||||
|
||||
def escape(s: str):
|
||||
res = StringIO()
|
||||
for c in s:
|
||||
if c in ESCAPED:
|
||||
c = '\\' + c
|
||||
res.write(c)
|
||||
return res.getvalue()
|
||||
|
||||
|
||||
buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {}
|
||||
|
||||
|
||||
def at(bucket: str, name: str):
|
||||
def wrap(f: Callable[[Context, list[str]], Awaitable[None]]):
|
||||
buckets.setdefault(bucket, {})[name] = f
|
||||
|
||||
return f
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class Implicit(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]:
|
||||
try:
|
||||
return buckets[bucket][name]
|
||||
except KeyError:
|
||||
raise Implicit
|
||||
|
||||
|
||||
benchmarks: dict[str, dict[str, float]] = {}
|
||||
_t = time.perf_counter()
|
||||
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self, benchmark: str):
|
||||
self.benchmark = benchmark
|
||||
|
||||
def __enter__(self):
|
||||
self.t = time.perf_counter()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
d = (time.perf_counter() - self.t)
|
||||
benchmarks.setdefault(self.benchmark, {'integral': 0.0, 'max': 0.0})
|
||||
benchmarks[self.benchmark]['integral'] += d
|
||||
benchmarks[self.benchmark]['max'] = max(benchmarks[self.benchmark]['max'], d)
|
||||
|
||||
|
||||
async def monitor():
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
dt = time.perf_counter() - _t
|
||||
print('Benchmarks:')
|
||||
for benchmark, metrics in benchmarks.items():
|
||||
print(benchmark, '=', metrics['integral'] / max(dt, .00001), ':', metrics['max'])
|
@ -7,7 +7,7 @@ import subprocess
|
||||
import time
|
||||
from collections import deque
|
||||
from io import StringIO
|
||||
from typing import Callable, Awaitable, Union, Optional, AsyncIterable, Any
|
||||
from typing import Optional, AsyncIterable, Any
|
||||
|
||||
# noinspection PyPackageRequirements
|
||||
import discord
|
||||
@ -19,6 +19,7 @@ from v6d1tokens.client import request_token
|
||||
|
||||
from v6d3music.app import get_app
|
||||
from v6d3music.config import prefix
|
||||
from v6d3music.context import Context, of, at, escape, Implicit, monitor, Benchmark
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
@ -39,45 +40,33 @@ myroot = root / 'v6d3music'
|
||||
myroot.mkdir(exist_ok=True)
|
||||
volume_db = Db(myroot / 'volume.db', kvrequest_type=KVJson)
|
||||
queue_db = Db(myroot / 'queue.db', kvrequest_type=KVJson)
|
||||
ESCAPED = '`_*\'"\\'
|
||||
|
||||
vcs_restored = False
|
||||
|
||||
|
||||
def escape(s: str):
|
||||
res = StringIO()
|
||||
for c in s:
|
||||
if c in ESCAPED:
|
||||
c = '\\' + c
|
||||
res.write(c)
|
||||
return res.getvalue()
|
||||
|
||||
|
||||
usertype = Union[discord.abc.User, discord.user.BaseUser, discord.Member, discord.User]
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, message: discord.Message):
|
||||
self.message: discord.Message = message
|
||||
self.channel: discord.abc.Messageable = message.channel
|
||||
self.dm_or_text: Union[discord.DMChannel, discord.TextChannel] = message.channel
|
||||
self.author: usertype = message.author
|
||||
self.content: str = message.content
|
||||
self.member: Optional[discord.Member] = message.author if isinstance(message.author, discord.Member) else None
|
||||
self.guild: Optional[discord.Guild] = None if self.member is None else self.member.guild
|
||||
|
||||
async def reply(self, content=None, **kwargs) -> discord.Message:
|
||||
return await self.message.reply(content, mention_author=False, **kwargs)
|
||||
|
||||
async def long(self, s: str):
|
||||
resio = StringIO(s)
|
||||
res = ''
|
||||
for line in resio:
|
||||
if len(res) + len(line) < 2000:
|
||||
res += line
|
||||
async def restore_vcs():
|
||||
global vcs_restored
|
||||
vcs: list[tuple[int, int, bool]] = queue_db.get('vcs', [])
|
||||
try:
|
||||
for vcgid, vccid, vc_is_paused in vcs:
|
||||
try:
|
||||
guild: discord.Guild = await client.fetch_guild(vcgid)
|
||||
async with lock_for(guild):
|
||||
channels = await guild.fetch_channels()
|
||||
channel: discord.VoiceChannel
|
||||
channel, = [ch for ch in channels if ch.id == vccid]
|
||||
vp: discord.VoiceProtocol = await channel.connect()
|
||||
assert isinstance(vp, discord.VoiceClient)
|
||||
vc = vp
|
||||
await main_for_raw_vc(vc, create=True)
|
||||
if vc_is_paused:
|
||||
vc.pause()
|
||||
except Exception as e:
|
||||
print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e)
|
||||
else:
|
||||
await self.reply(res)
|
||||
res = line
|
||||
if res:
|
||||
await self.reply(res)
|
||||
print(f'vc restored {vcgid} {vccid}')
|
||||
finally:
|
||||
vcs_restored = True
|
||||
|
||||
|
||||
@client.event
|
||||
@ -86,29 +75,8 @@ async def on_ready():
|
||||
await client.change_presence(activity=discord.Game(
|
||||
name='феноменально',
|
||||
))
|
||||
|
||||
|
||||
buckets: dict[str, dict[str, Callable[[Context, list[str]], Awaitable[None]]]] = {}
|
||||
|
||||
|
||||
def at(bucket: str, name: str):
|
||||
def wrap(f: Callable[[Context, list[str]], Awaitable[None]]):
|
||||
buckets.setdefault(bucket, {})[name] = f
|
||||
|
||||
return f
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class Implicit(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def of(bucket: str, name: str) -> Callable[[Context, list[str]], Awaitable[None]]:
|
||||
try:
|
||||
return buckets[bucket][name]
|
||||
except KeyError:
|
||||
raise Implicit
|
||||
if not vcs_restored:
|
||||
await restore_vcs()
|
||||
|
||||
|
||||
async def handle_command(ctx: Context, name: str, args: list[str]) -> None:
|
||||
@ -132,9 +100,7 @@ class Explicit(Exception):
|
||||
self.msg = msg
|
||||
|
||||
|
||||
def lock_for(ctx: Context) -> asyncio.Lock:
|
||||
# noinspection PyTypeChecker
|
||||
guild: discord.Guild = ctx.guild
|
||||
def lock_for(guild: discord.Guild) -> asyncio.Lock:
|
||||
if guild is None:
|
||||
raise Explicit('not in a guild')
|
||||
if guild in locks:
|
||||
@ -223,13 +189,13 @@ class YTAudio(discord.AudioSource):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio':
|
||||
async def respawn(cls, guild: discord.Guild, respawn) -> 'YTAudio':
|
||||
return YTAudio(
|
||||
respawn['url'],
|
||||
respawn['origin'],
|
||||
respawn['description'],
|
||||
respawn['options'],
|
||||
guild.get_member(respawn['rby']),
|
||||
guild.get_member(respawn['rby']) or await guild.fetch_member(respawn['rby']),
|
||||
respawn['already_read']
|
||||
)
|
||||
|
||||
@ -248,17 +214,35 @@ FILL = b'\x00' * discord.opus.Encoder.FRAME_SIZE
|
||||
|
||||
|
||||
class QueueAudio(discord.AudioSource):
|
||||
def __init__(self, guild: discord.Guild):
|
||||
def __init__(self, guild: discord.Guild, respawned: list[YTAudio]):
|
||||
self.queue: deque[YTAudio] = deque()
|
||||
self.queue.extend(respawned)
|
||||
self.guild = guild
|
||||
for audio_respawn in queue_db.get(self.guild.id, []):
|
||||
try:
|
||||
self.queue.append(YTAudio.respawn(self.guild, audio_respawn))
|
||||
except Exception as e:
|
||||
print('respawn failed', e)
|
||||
|
||||
def save(self):
|
||||
queue_db.set_nowait(self.guild.id, [audio.hybernate() for audio in self.queue])
|
||||
@staticmethod
|
||||
async def respawned(guild: discord.Guild) -> list[YTAudio]:
|
||||
respawned = []
|
||||
try:
|
||||
for audio_respawn in queue_db.get(guild.id, []):
|
||||
try:
|
||||
respawned.append(await YTAudio.respawn(guild, audio_respawn))
|
||||
except Exception as e:
|
||||
print('audio respawn failed', e)
|
||||
raise
|
||||
except Exception as e:
|
||||
print('queue respawn failed', e)
|
||||
return respawned
|
||||
|
||||
@classmethod
|
||||
async def create(cls, guild: discord.Guild):
|
||||
return cls(guild, await QueueAudio.respawned(guild))
|
||||
|
||||
async def save(self):
|
||||
hybernated = []
|
||||
for audio in list(self.queue):
|
||||
await asyncio.sleep(0.01)
|
||||
hybernated.append(audio.hybernate())
|
||||
queue_db.set_nowait(self.guild.id, hybernated)
|
||||
|
||||
def append(self, audio: YTAudio):
|
||||
self.queue.append(audio)
|
||||
@ -299,7 +283,7 @@ class QueueAudio(discord.AudioSource):
|
||||
|
||||
def format(self) -> str:
|
||||
stream = StringIO()
|
||||
for i, audio in enumerate(self.queue):
|
||||
for i, audio in enumerate(list(self.queue)):
|
||||
stream.write(f'`[{i}]` {audio.description}\n')
|
||||
return stream.getvalue()
|
||||
|
||||
@ -313,8 +297,8 @@ class QueueAudio(discord.AudioSource):
|
||||
|
||||
|
||||
class MainAudio(discord.PCMVolumeTransformer):
|
||||
def __init__(self, guild: discord.Guild, volume: float):
|
||||
self.queue = QueueAudio(guild)
|
||||
def __init__(self, queue: QueueAudio, volume: float):
|
||||
self.queue = queue
|
||||
super().__init__(self.queue, volume=volume)
|
||||
|
||||
async def set(self, volume: float, member: discord.Member):
|
||||
@ -460,48 +444,66 @@ async def play(ctx: Context, args: list[str]) -> None:
|
||||
`play url [- effects] ...args`
|
||||
'''.strip())
|
||||
case _:
|
||||
async with lock_for(ctx):
|
||||
queue = await queue_for(ctx)
|
||||
async with lock_for(ctx.guild):
|
||||
queue = await queue_for(ctx, create=True)
|
||||
async for audio in yt_audios(ctx, args):
|
||||
queue.append(audio)
|
||||
await ctx.reply('done')
|
||||
|
||||
|
||||
async def vc_main_for(ctx: Context) -> tuple[discord.VoiceClient, MainAudio]:
|
||||
async def raw_vc_for(ctx: Context) -> discord.VoiceClient:
|
||||
if ctx.guild is None:
|
||||
raise Explicit("not in a guild")
|
||||
vc: discord.VoiceProtocol = ctx.guild.voice_client
|
||||
if vc is None:
|
||||
if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected():
|
||||
vs: discord.VoiceState = ctx.member.voice
|
||||
if vs is None:
|
||||
raise Explicit("not connected")
|
||||
vch: discord.VoiceChannel = vs.channel
|
||||
if vch is None:
|
||||
raise Explicit("not connected")
|
||||
vc: discord.VoiceProtocol = await vch.connect()
|
||||
try:
|
||||
vc: discord.VoiceProtocol = await vch.connect()
|
||||
except discord.ClientException:
|
||||
await ctx.guild.fetch_channels()
|
||||
raise Explicit("try again later")
|
||||
assert isinstance(vc, discord.VoiceClient)
|
||||
if ctx.guild in mainasrcs:
|
||||
source = mainasrcs[ctx.guild]
|
||||
else:
|
||||
await ctx.reply('respawning queue')
|
||||
source = mainasrcs.setdefault(ctx.guild, MainAudio(ctx.guild, volume=volume_db.get(ctx.guild.id, 0.2)))
|
||||
if vc.source != source:
|
||||
vc.play(source)
|
||||
return vc, source
|
||||
|
||||
|
||||
async def vc_for(ctx: Context) -> discord.VoiceClient:
|
||||
vc, source = await vc_main_for(ctx)
|
||||
return vc
|
||||
|
||||
|
||||
async def main_for(ctx: Context) -> MainAudio:
|
||||
vc, source = await vc_main_for(ctx)
|
||||
async def main_for_raw_vc(vc: discord.VoiceClient, *, create: bool) -> MainAudio:
|
||||
if vc.guild in mainasrcs:
|
||||
source = mainasrcs[vc.guild]
|
||||
else:
|
||||
if create:
|
||||
source = mainasrcs.setdefault(
|
||||
vc.guild,
|
||||
MainAudio(await QueueAudio.create(vc.guild), volume=volume_db.get(vc.guild.id, 0.2))
|
||||
)
|
||||
else:
|
||||
raise Explicit("not playing")
|
||||
if vc.source != source:
|
||||
vc.play(source)
|
||||
return source
|
||||
|
||||
|
||||
async def queue_for(ctx: Context) -> QueueAudio:
|
||||
return (await main_for(ctx)).queue
|
||||
async def vc_main_for(ctx: Context, *, create: bool) -> tuple[discord.VoiceClient, MainAudio]:
|
||||
vc = await raw_vc_for(ctx)
|
||||
return vc, await main_for_raw_vc(vc, create=create)
|
||||
|
||||
|
||||
async def vc_for(ctx: Context, *, create: bool) -> discord.VoiceClient:
|
||||
vc, source = await vc_main_for(ctx, create=create)
|
||||
return vc
|
||||
|
||||
|
||||
async def main_for(ctx: Context, *, create: bool) -> MainAudio:
|
||||
vc, source = await vc_main_for(ctx, create=create)
|
||||
return source
|
||||
|
||||
|
||||
async def queue_for(ctx: Context, *, create: bool) -> QueueAudio:
|
||||
return (await main_for(ctx, create=create)).queue
|
||||
|
||||
|
||||
@at('commands', 'skip')
|
||||
@ -512,15 +514,15 @@ async def skip(ctx: Context, args: list[str]) -> None:
|
||||
`skip [first] [last]`
|
||||
'''.strip())
|
||||
case []:
|
||||
queue = await queue_for(ctx)
|
||||
queue = await queue_for(ctx, create=False)
|
||||
queue.skip_at(0, ctx.member)
|
||||
case [pos]:
|
||||
pos = int(pos)
|
||||
queue = await queue_for(ctx)
|
||||
queue = await queue_for(ctx, create=False)
|
||||
queue.skip_at(pos, ctx.member)
|
||||
case [pos0, pos1]:
|
||||
pos0, pos1 = int(pos0), int(pos1)
|
||||
queue = await queue_for(ctx)
|
||||
queue = await queue_for(ctx, create=False)
|
||||
for i in range(pos0, pos1 + 1):
|
||||
if not queue.skip_at(pos0, ctx.member):
|
||||
pos0 += 1
|
||||
@ -532,9 +534,13 @@ async def queue_(ctx: Context, args: list[str]) -> None:
|
||||
case ['help']:
|
||||
await ctx.reply('current queue')
|
||||
case []:
|
||||
await ctx.long((await queue_for(ctx)).format().strip() or 'no queue')
|
||||
await ctx.long((await queue_for(ctx, create=False)).format().strip() or 'no queue')
|
||||
case ['clear']:
|
||||
(await queue_for(ctx)).clear(ctx.member)
|
||||
(await queue_for(ctx, create=False)).clear(ctx.member)
|
||||
await ctx.reply('done')
|
||||
case ['resume']:
|
||||
async with lock_for(ctx.guild):
|
||||
await queue_for(ctx, create=True)
|
||||
|
||||
|
||||
@at('commands', 'volume')
|
||||
@ -544,28 +550,22 @@ async def volume_(ctx: Context, args: list[str]) -> None:
|
||||
await ctx.reply('`volume 0.2`')
|
||||
case [volume]:
|
||||
volume = float(volume)
|
||||
await (await main_for(ctx)).set(volume, ctx.member)
|
||||
await (await main_for(ctx, create=False)).set(volume, ctx.member)
|
||||
|
||||
|
||||
@at('commands', 'pause')
|
||||
async def pause(ctx: Context, _args: list[str]) -> None:
|
||||
(await vc_for(ctx)).pause()
|
||||
vc = await vc_for(ctx, create=False)
|
||||
vc.pause()
|
||||
|
||||
|
||||
@at('commands', 'resume')
|
||||
async def resume(ctx: Context, _args: list[str]) -> None:
|
||||
(await vc_for(ctx)).resume()
|
||||
vc = await vc_for(ctx, create=False)
|
||||
vc.resume()
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_message(message: discord.Message) -> None:
|
||||
if message.author.bot:
|
||||
return
|
||||
content: str = message.content
|
||||
if not content.startswith(prefix):
|
||||
return
|
||||
content = content.removeprefix(prefix)
|
||||
args = shlex.split(content)
|
||||
async def handle_args(message: discord.Message, args: list[str]):
|
||||
match args:
|
||||
case []:
|
||||
return
|
||||
@ -579,44 +579,56 @@ async def on_message(message: discord.Message) -> None:
|
||||
await ctx.reply(e.msg)
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_message(message: discord.Message) -> None:
|
||||
if message.author.bot:
|
||||
return
|
||||
content: str = message.content
|
||||
if not content.startswith(prefix):
|
||||
return
|
||||
content = content.removeprefix(prefix)
|
||||
args = shlex.split(content)
|
||||
await handle_args(message, args)
|
||||
|
||||
|
||||
async def save_queues():
|
||||
for mainasrc in list(mainasrcs.values()):
|
||||
await asyncio.sleep(0.01)
|
||||
await mainasrc.queue.save()
|
||||
|
||||
|
||||
async def save_vcs():
|
||||
if vcs_restored:
|
||||
vcs = []
|
||||
vc: discord.VoiceClient
|
||||
for vc in list(client.voice_clients):
|
||||
await asyncio.sleep(0.01)
|
||||
if vc.is_playing():
|
||||
vcs.append((vc.guild.id, vc.channel.id, vc.is_paused()))
|
||||
queue_db.set_nowait('vcs', vcs)
|
||||
|
||||
|
||||
async def save_commit():
|
||||
await queue_db.set('commit', time.time())
|
||||
|
||||
|
||||
async def save_job():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
for mainasrc in list(mainasrcs.values()):
|
||||
await asyncio.sleep(1)
|
||||
mainasrc.queue.save()
|
||||
|
||||
|
||||
benchmarks: dict[str, float] = {}
|
||||
_t = time.perf_counter()
|
||||
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self, benchmark: str):
|
||||
self.benchmark = benchmark
|
||||
|
||||
def __enter__(self):
|
||||
self.t = time.perf_counter()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
benchmarks.setdefault(self.benchmark, 0.0)
|
||||
benchmarks[self.benchmark] += time.perf_counter() - self.t
|
||||
|
||||
|
||||
async def monitor():
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
dt = time.perf_counter() - _t
|
||||
for benchmark, count in benchmarks.items():
|
||||
print(benchmark, '=', count / max(dt, .00001))
|
||||
with Benchmark('SVQ'):
|
||||
await save_queues()
|
||||
with Benchmark('SVV'):
|
||||
await save_vcs()
|
||||
with Benchmark('SVC'):
|
||||
await save_commit()
|
||||
|
||||
|
||||
async def main():
|
||||
async with volume_db, queue_db:
|
||||
await start_app(get_app(client))
|
||||
await client.login(token)
|
||||
asyncio.get_event_loop().create_task(save_queues())
|
||||
asyncio.get_event_loop().create_task(monitor())
|
||||
loop.create_task(save_job())
|
||||
loop.create_task(monitor())
|
||||
await client.connect()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user