293 lines
11 KiB
Python
293 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import traceback
|
|
from contextlib import AsyncExitStack
|
|
from typing import AsyncIterable, TypeVar
|
|
|
|
import discord
|
|
|
|
import v6d3music.processing.pool
|
|
from ptvp35 import DbFactory, KVJson
|
|
from v6d2ctx.context import Context, Explicit
|
|
from v6d2ctx.integration.event import Event, SendableEvents
|
|
from v6d2ctx.integration.responsetype import ResponseType
|
|
from v6d2ctx.integration.targets import Async, Targets
|
|
from v6d2ctx.lock_for import Locks
|
|
from v6d3music.config import myroot
|
|
from v6d3music.core.default_effects import DefaultEffects
|
|
from v6d3music.core.mainaudio import MainAudio
|
|
from v6d3music.core.monitoring import Monitoring, PersistentMonitoring
|
|
from v6d3music.core.queueaudio import QueueAudio
|
|
from v6d3music.core.ystate import YState
|
|
from v6d3music.core.aservicing import AServicing
|
|
from v6d3music.core.audio import Audio
|
|
from v6d3music.processing.pool import Pool, PoolEvent
|
|
from v6d3music.utils.argctx import ArgCtx
|
|
from v6d3music.utils.assert_admin import assert_admin
|
|
|
|
# Public API of this module; the _PM* pool-event adapters stay internal.
__all__ = (
    "MainService",
    "MainMode",
    "MainContext",
    "MainEvent",
)

# Generic type variable; appears unused in this module (kept for compatibility).
T = TypeVar("T")
|
|
|
|
|
|
class MainEvent(Event):
    """Base class for every event that MainService publishes to its event sink."""
|
|
|
|
|
|
class _PMEvent(MainEvent):
    """Adapter presenting a processing-pool event as a MainEvent."""

    def __init__(self, event: PoolEvent, /) -> None:
        # The wrapped pool event, kept as a public attribute.
        self.event = event

    def json(self) -> ResponseType:
        """Serialize as ``{"pool": <the wrapped event's json()>}``."""
        payload = self.event.json()
        return {"pool": payload}
|
|
|
|
|
|
class _PMSendable(SendableEvents[PoolEvent]):
    """Sink that accepts PoolEvents and forwards them, wrapped in _PMEvent, to a MainEvent sink."""

    def __init__(self, sendable: SendableEvents[MainEvent], /) -> None:
        # Downstream sink receiving the wrapped events.
        self.sendable = sendable

    def send(self, event: PoolEvent, /) -> None:
        """Wrap ``event`` and hand it to the downstream sink."""
        wrapped = _PMEvent(event)
        return self.sendable.send(wrapped)
|
|
|
|
|
|
class MainService:
    """Owns per-guild MainAudio objects, voice-client persistence and the processing pool.

    Use as an async context manager: ``__aenter__`` opens the volume/queue databases,
    the pool and monitoring, and starts the periodic save daemon; ``__aexit__`` runs a
    final save and then closes everything that ``__aenter__`` opened.
    """

    def __init__(
        self,
        targets: Targets,
        defaulteffects: DefaultEffects,
        client: discord.Client,
        events: SendableEvents[MainEvent],
    ) -> None:
        self.targets = targets
        self.defaulteffects = defaulteffects
        self.client = client
        # One MainAudio per guild; MainMode shares this dict by reference.
        self.mains: dict[discord.Guild, MainAudio] = {}
        # Serializes restore() so voice clients are restored at most once.
        self.restore_lock = asyncio.Lock()
        self.__events: SendableEvents[MainEvent] = events
        # Pool events are wrapped into MainEvents and forwarded to ``events``.
        self.__pool_events: SendableEvents[PoolEvent] = _PMSendable(self.__events)
        # In-flight loaders created by audios(), keyed by guild, so cancel_loading() can reach them.
        self.__ystates: dict[discord.Guild, YState] = {}

    def register_instrumentation(self):
        """Register the pool's ``UnitJob.run`` with the instrumentation targets as an async target."""
        self.targets.register_type(v6d3music.processing.pool.UnitJob, "run", Async)

    @staticmethod
    async def raw_vc_for_member(member: discord.Member) -> discord.VoiceClient:
        """Return a connected voice client for ``member``'s guild, joining the member's channel if needed.

        Raises:
            Explicit: the member is not in a voice channel, or connecting failed
                (any leftover guild voice client is force-disconnected first).
        """
        vc: discord.VoiceProtocol | None = member.guild.voice_client
        if vc is None or vc.channel is None or (isinstance(vc, discord.VoiceClient) and not vc.is_connected()):
            vs: discord.VoiceState | None = member.voice
            if vs is None:
                raise Explicit("not connected")
            vch: discord.abc.Connectable | None = vs.channel
            if vch is None:
                raise Explicit("not connected")
            try:
                vc = await vch.connect()
            except discord.ClientException as e:
                traceback.print_exc()
                # Drop whatever client the failed connect left behind so a retry starts clean.
                # It can also be None here; the caller must still get Explicit, not an AssertionError.
                stale = member.guild.voice_client
                if stale is not None:
                    # NOTE(review): result is discarded; kept from the original — purpose unclear, confirm.
                    await member.guild.fetch_channels()
                    await stale.disconnect(force=True)
                raise Explicit("try again later") from e
        assert isinstance(vc, discord.VoiceClient)
        return vc

    async def raw_vc_for(self, ctx: Context) -> discord.VoiceClient:
        """Return the voice client for ``ctx``'s member; raises Explicit outside a guild."""
        if ctx.member is None:
            raise Explicit("not in a guild")
        return await self.raw_vc_for_member(ctx.member)

    def mode(self, *, create: bool, force_play: bool) -> MainMode:
        """Build a MainMode bound to this service."""
        return MainMode(self, create=create, force_play=force_play)

    def context(self, ctx: Context, *, create: bool, force_play: bool) -> MainContext:
        """Build a MainContext for ``ctx`` with the given mode flags."""
        return self.mode(create=create, force_play=force_play).context(ctx)

    async def create(self, guild: discord.Guild) -> MainAudio:
        """Create a new MainAudio for ``guild`` backed by this service's pool, volume and queue stores."""
        return await MainAudio.create(self.__servicing, self.__volumes, self.__queues, guild)

    async def __aenter__(self) -> MainService:
        """Open the databases, pool and monitoring, then start the save daemon."""
        async with AsyncExitStack() as es:
            self.__locks = Locks()
            self.__volumes = await es.enter_async_context(DbFactory(myroot / "volume.db", kvfactory=KVJson()))
            self.__queues = await es.enter_async_context(DbFactory(myroot / "queue.db", kvfactory=KVJson()))
            self.__pool = await es.enter_async_context(Pool(5, self.__pool_events))
            self.__servicing = AServicing(self.__pool)
            # Resolved by restore_vcs(); the save daemon waits for it before its first save.
            self.__vcs_restored: asyncio.Future[None] = asyncio.get_running_loop().create_future()
            self.monitoring = await es.enter_async_context(Monitoring())
            self.pmonitoring = es.enter_context(PersistentMonitoring(self.monitoring))
            self.register_instrumentation()
            # Started last: if any step above raised, the stack would unwind but a task
            # created earlier would be left running with nothing to cancel it.
            self.__save_task = asyncio.create_task(self.save_daemon())
            self.__es = es.pop_all()
            return self
        # Only reachable if the exit stack suppressed an exception; keeps type checkers happy.
        raise RuntimeError

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Run the final save, then close everything opened in __aenter__."""
        async with self.__es:
            await self.final_save()
        del self.__es

    async def save_queues(self, delay: bool) -> None:
        """Save each guild's queue; when ``delay`` is true, sleep 10 ms before each one."""
        # Snapshot: ``mains`` can change while we await.
        for mainasrc in list(self.mains.values()):
            if delay:
                await asyncio.sleep(0.01)
            await mainasrc.queue.save(delay)

    async def save_vcs(self, delay: bool) -> None:
        """Store ``(guild_id, channel_id, is_paused)`` for each voice client that has a playing or paused player.

        When ``delay`` is true, sleeps 10 ms before each client. The list is written
        under the "vcs" key without committing (see save_commit()).
        """
        vcs: list[tuple[int, int, bool]] = []
        for vc in self.client.voice_clients:
            if not isinstance(vc, discord.VoiceClient):
                continue
            if delay:
                await asyncio.sleep(0.01)
            # discord.py's is_playing() is False while paused, so paused sessions must be
            # checked separately; otherwise they are never saved and the flag is always False.
            if not (vc.is_playing() or vc.is_paused()):
                continue
            if vc.guild is not None and vc.channel is not None:
                vcs.append((vc.guild.id, vc.channel.id, vc.is_paused()))
        self.__queues.set_nowait("vcs", vcs)

    async def save_commit(self) -> None:
        """Commit pending writes to the queue database."""
        await self.__queues.commit()

    async def _save_all(self, delay: bool, save_playing: bool) -> None:
        """Save queues, optionally voice clients, then commit."""
        await self.save_queues(delay)
        if save_playing:
            await self.save_vcs(delay)
        await self.save_commit()

    async def save_all(self, delay: bool, save_playing: bool) -> None:
        """Public entry point for a full save (see _save_all)."""
        await self._save_all(delay, save_playing)

    async def save_job(self):
        """After voice clients are restored, save everything once per second, forever.

        Voice clients are included only while the client is not closed. A failing
        save is logged and retried on the next tick instead of ending the loop.
        """
        # Presumably waits so the stored "vcs" entry is not overwritten before restore_vcs reads it.
        await self.__vcs_restored
        print("starting saving")
        while True:
            await asyncio.sleep(1)
            try:
                await self.save_all(True, not self.client.is_closed())
            except Exception:
                # One failed save (e.g. a transient DB error) must not stop all future saves.
                traceback.print_exc()

    async def save_daemon(self):
        """Run save_job until cancelled; cancellation ends it quietly."""
        try:
            await self.save_job()
        except asyncio.CancelledError:
            pass

    async def final_save(self):
        """Stop the save daemon and, if voice clients were restored, do one last save.

        Voice clients are not re-saved here (``save_playing=False``), so the last
        periodic snapshot of "vcs" is kept. Save errors are printed, not raised.
        """
        # Snapshot first: cancelling a task that is awaiting ``__vcs_restored`` also
        # cancels that future, which would otherwise make it look "done" below.
        restored = self.__vcs_restored.done() and not self.__vcs_restored.cancelled()
        self.__save_task.cancel()
        # Let the daemon actually finish so an in-flight periodic save cannot overlap ours.
        await asyncio.wait((self.__save_task,))
        if not restored:
            if not self.__vcs_restored.done():
                self.__vcs_restored.cancel()
            return
        try:
            await self.save_all(False, False)
            print("saved")
        except Exception:
            traceback.print_exc()

    async def _restore_vc(self, guild: discord.Guild, vccid: int, vc_is_paused: bool) -> None:
        """Reconnect to voice channel ``vccid`` in ``guild``, resume its MainAudio and re-pause if needed.

        Raises ValueError when the channel id does not match exactly one voice channel.
        """
        channels = await guild.fetch_channels()
        channel: discord.VoiceChannel
        (channel,) = [ch for ch in channels if isinstance(ch, discord.VoiceChannel) and ch.id == vccid]
        vp: discord.VoiceProtocol = await channel.connect()
        assert isinstance(vp, discord.VoiceClient)
        vc = vp
        await self.mode(create=True, force_play=True).main_for_raw_vc(vc)
        if vc_is_paused:
            vc.pause()

    def lock_for(self, guild: discord.Guild | None) -> asyncio.Lock:
        """Return the per-guild lock (the message is presumably used when ``guild`` is None — confirm in Locks)."""
        return self.__locks.lock_for(guild, "not in a guild")

    async def restore_vc(self, vcgid: int, vccid: int, vc_is_paused: bool) -> None:
        """Restore one saved voice client under the guild lock; failures are printed, never raised."""
        try:
            print(f"vc restoring {vcgid}")
            guild: discord.Guild = await self.client.fetch_guild(vcgid)
            async with self.lock_for(guild):
                await self._restore_vc(guild, vccid, vc_is_paused)
        except Exception:
            print(f"vc {vcgid} {vccid} {vc_is_paused} failed")
            traceback.print_exc()
        else:
            print(f"vc restored {vcgid} {vccid}")

    async def restore_vcs(self) -> None:
        """Concurrently restore every saved voice client, then mark restoration as finished."""
        try:
            vcs: list[tuple[int, int, bool]] = self.__queues.get("vcs", [])
            # restore_vc handles its own errors; gather also cancels the children if we are cancelled.
            await asyncio.gather(
                *(self.restore_vc(vcgid, vccid, vc_is_paused) for vcgid, vccid, vc_is_paused in vcs)
            )
        finally:
            # The future may already be cancelled (e.g. by final_save during shutdown).
            if not self.__vcs_restored.done():
                self.__vcs_restored.set_result(None)

    async def restore(self) -> None:
        """Restore saved voice clients once; later calls are no-ops."""
        async with self.restore_lock:
            if not self.__vcs_restored.done():
                await self.restore_vcs()

    async def audios(self, ctx: Context, args: list[str]) -> AsyncIterable[Audio]:
        """Yield Audio objects for ``args``, registering the loader so cancel_loading() can stop it."""
        assert ctx.guild is not None
        argctx = ArgCtx(self.defaulteffects.get(ctx.guild.id), args)
        ystate = YState(self.__servicing, self.__pool, ctx, argctx.sources)
        self.__ystates[ctx.guild] = ystate
        try:
            async for audio in ystate.iterate():
                yield audio
        finally:
            # A newer load for the same guild may have replaced our entry; leave that one alone,
            # otherwise its own cleanup would later fail with KeyError.
            if self.__ystates.get(ctx.guild) is ystate:
                del self.__ystates[ctx.guild]

    def cancel_loading(self, ctx: Context) -> None:
        """Cancel the guild's in-flight load, if any; other members must pass assert_admin."""
        assert ctx.guild is not None
        ystate = self.__ystates.get(ctx.guild)
        if ystate is None:
            return
        if ystate.ctx.member != ctx.member:
            assert_admin(ctx.member)
        ystate.cancel()

    def pool_json(self) -> ResponseType:
        """Return the processing pool's JSON status."""
        return self.__pool.json()
|
|
|
|
|
|
class MainMode:
    """Policy for obtaining a guild's MainAudio for a voice client.

    ``create`` permits creating the guild's MainAudio when none exists and starting
    playback on an idle client; ``force_play`` additionally starts playback when the
    client is paused.
    """

    def __init__(self, service: MainService, *, create: bool, force_play: bool) -> None:
        self.mainservice = service
        # Shared by reference with the service, so new entries are visible to both.
        self.mains = service.mains
        self.create = create
        self.force_play = force_play

    async def main_for_raw_vc(self, vc: discord.VoiceClient) -> MainAudio:
        """Return the MainAudio for ``vc``'s guild, calling ``vc.play`` on it when needed.

        Playback is started when ``vc`` is not already playing this MainAudio, or when
        ``create`` is set, the client is not playing, and it is either unpaused or
        ``force_play`` is set.

        Raises:
            Explicit: the guild has no MainAudio and this mode does not create one.
        """
        guild = vc.guild
        if guild in self.mains:
            main = self.mains[guild]
        elif not self.create:
            raise Explicit("not playing, use `queue pause` or `queue resume`")
        else:
            # setdefault: keep whichever MainAudio got registered first.
            main = self.mains.setdefault(guild, await self.mainservice.create(guild))
        should_play = vc.source != main or (
            self.create and not vc.is_playing() and (self.force_play or not vc.is_paused())
        )
        if should_play:
            vc.play(main)
        return main

    def context(self, ctx: Context) -> MainContext:
        """Bind this mode to a command context."""
        return MainContext(self, ctx)
|
|
|
|
|
|
class MainContext:
    """A MainMode bound to one command Context."""

    def __init__(self, mode: MainMode, ctx: Context) -> None:
        self.mainservice = mode.mainservice
        self.mode = mode
        self.ctx = ctx

    async def vc_main(self) -> tuple[discord.VoiceClient, MainAudio]:
        """Resolve the caller's voice client, then the guild's MainAudio for it."""
        voice = await self.mainservice.raw_vc_for(self.ctx)
        main = await self.mode.main_for_raw_vc(voice)
        return voice, main

    async def vc(self) -> discord.VoiceClient:
        """Return only the voice client from vc_main()."""
        pair = await self.vc_main()
        return pair[0]

    async def main(self) -> MainAudio:
        """Return only the MainAudio from vc_main()."""
        pair = await self.vc_main()
        return pair[1]

    async def queue(self) -> QueueAudio:
        """Return the queue of the guild's MainAudio."""
        main = await self.main()
        return main.queue
|