From 77c0f16da32f41dc0c984962032b7441e0feb9f9 Mon Sep 17 00:00:00 2001 From: timofey Date: Sat, 24 Dec 2022 16:37:05 +0000 Subject: [PATCH] service hierarchy --- v6d3music/api.py | 14 +- v6d3music/app.py | 61 ++- v6d3music/commands.py | 508 ++++++++++----------- v6d3music/core/cache_url.py | 44 -- v6d3music/core/caching.py | 67 +++ v6d3music/core/create_ytaudio.py | 6 +- v6d3music/core/create_ytaudios.py | 21 - v6d3music/core/default_effects.py | 33 ++ v6d3music/core/entries_effects_for_args.py | 34 -- v6d3music/core/mainasrc.py | 69 --- v6d3music/core/mainaudio.py | 17 +- v6d3music/core/mainservice.py | 215 +++++++++ v6d3music/core/queueaudio.py | 29 +- v6d3music/core/real_url.py | 20 +- v6d3music/core/ystate.py | 5 +- v6d3music/core/yt_audios.py | 11 +- v6d3music/core/ytaudio.py | 10 +- v6d3music/run-bot.py | 191 +++----- v6d3music/utils/tor_prefix.py | 23 +- 19 files changed, 725 insertions(+), 653 deletions(-) delete mode 100644 v6d3music/core/cache_url.py create mode 100644 v6d3music/core/caching.py delete mode 100644 v6d3music/core/create_ytaudios.py create mode 100644 v6d3music/core/default_effects.py delete mode 100644 v6d3music/core/entries_effects_for_args.py delete mode 100644 v6d3music/core/mainasrc.py create mode 100644 v6d3music/core/mainservice.py diff --git a/v6d3music/api.py b/v6d3music/api.py index 495b4c7..16122aa 100644 --- a/v6d3music/api.py +++ b/v6d3music/api.py @@ -3,7 +3,7 @@ import time from typing import TypeAlias import discord -from v6d3music.core.mainasrc import main_for_raw_vc, raw_vc_for_member +from v6d3music.core.mainservice import MainService from v6d3music.core.mainaudio import MainAudio from v6d2ctx.context import Explicit @@ -27,8 +27,9 @@ class Api: def json(self) -> dict: return super().json() | {'explicit': None} - def __init__(self, client: discord.Client, roles: dict[str, str]) -> None: - self.client = client + def __init__(self, mainservice: MainService, roles: dict[str, str]) -> None: + self.mainservice = mainservice + self.client = mainservice.client self.roles = roles def is_operator(self, user_id: int) -> bool: @@ -166,10 +167,11 @@ class VoiceApi(GuildApi): ) -> None: super().__init__(api, api.member) self.channel = channel + self.mainservice = self.pi.mainservice async def _main_api(self) -> 'MainApi': - vc = await raw_vc_for_member(self.member) - main = await main_for_raw_vc(vc, create=False, force_play=False) + vc = await self.mainservice.raw_vc_for_member(self.member) + main = await self.mainservice.descriptor(create=False, force_play=False).main_for_raw_vc(vc) return MainApi(self, vc, main) def sub(self, request: dict) -> 'VoiceApi': @@ -210,7 +212,7 @@ class MainApi(VoiceApi): case {'type': 'queueformat'}: return await self.main.queue.format() case {'type': 'queuejson'}: - return await self.main.queue.pubjson(self.member) + return await self.main.queue.pubjson(self.member, self.request.get('limit', 1000)) case {'type': '?'}: return 'this is main api' case {'type': '*', 'requests': list() | dict() as requests}: diff --git a/v6d3music/app.py b/v6d3music/app.py index 74d6102..0ecdee2 100644 --- a/v6d3music/app.py +++ b/v6d3music/app.py @@ -2,6 +2,7 @@ import asyncio import functools import os import urllib.parse +from contextlib import AsyncExitStack from pathlib import Path from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar @@ -10,15 +11,14 @@ import discord from aiohttp import web from v6d3music.api import Api from v6d3music.config import auth_redirect, myroot +from v6d3music.core.mainservice import MainService from v6d3music.utils.bytes_hash import bytes_hash -from ptvp35 import Db, KVJson +from ptvp35 import * from v6d0auth.appfactory import AppFactory from v6d0auth.run_app import start_app from v6d1tokens.client import request_token -session_db = Db(myroot / 'session.db', kvfactory=KVJson()) - T = TypeVar('T') TKey = TypeVar('TKey', bound=Hashable) @@ -64,13 +64,15 @@ class MusicAppFactory(AppFactory): self, secret: str, client: discord.Client, - api: Api + api: Api, + db: DbConnection ): self.secret = secret self.redirect = auth_redirect self.loop = asyncio.get_running_loop() self.client = client self._api = api + self.db = db self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary( self._token_client ) @@ -190,11 +192,10 @@ class MusicAppFactory(AppFactory): cid = self.user_id(user) return cid - @classmethod - def session_data(cls, session: str | None) -> dict: + def session_data(self, session: str | None) -> dict: if session is None: return {} - data = session_db.get(session, {}) + data = self.db.get(session, {}) if not isinstance(data, dict): return {} return data @@ -224,7 +225,7 @@ class MusicAppFactory(AppFactory): data = self.session_data(session) data['code'] = code data['token'] = await self.code_token(code) - await session_db.set(session, data) + await self.db.set(session, data) return response else: return web.FileResponse(self._path('auth.html')) @@ -279,15 +280,47 @@ class MusicAppFactory(AppFactory): except Api.MisusedApi as e: return web.json_response(e.json(), status=404) - @classmethod - async def start(cls, client: discord.Client): + +class AppContext: + def __init__(self, mainservice: MainService) -> None: + self.mainservice = mainservice + + async def start(self) -> tuple[web.Application, asyncio.Task[None]] | None: try: - factory = cls( + factory = MusicAppFactory( await request_token('music-client', 'token'), - client, - Api(client, {key: value for key, value in os.environ.items() if key.startswith('roles')}) + self.mainservice.client, + Api( + self.mainservice, + {key: value for key, value in os.environ.items() if key.startswith('roles')}, + ), + self.__db ) except aiohttp.ClientConnectorError: print('no web app (likely due to no token)') else: - await start_app(factory.app()) + app = factory.app() + task = asyncio.create_task(start_app(app)) + return app, task + + async def __aenter__(self) -> 'AppContext': + async with AsyncExitStack() as es: + self.__db = await es.enter_async_context(DbFactory(myroot / 'session.db', kvfactory=KVJson())) + self.__es = es.pop_all() + self.__task: asyncio.Task[ + tuple[web.Application, asyncio.Task[None]] | None + ] = asyncio.create_task(self.start()) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + async with self.__es: + if self.__task.done(): + result = await self.__task + if result is not None: + app, task = result + await task + await app.shutdown() + await app.cleanup() + else: + self.__task.cancel() + del self.__es diff --git a/v6d3music/commands.py b/v6d3music/commands.py index 56fa6e4..3394819 100644 --- a/v6d3music/commands.py +++ b/v6d3music/commands.py @@ -1,8 +1,8 @@ import shlex from typing import Callable -from v6d3music.core.entries_effects_for_args import * -from v6d3music.core.mainasrc import main_for, queue_for, vc_for +from v6d3music.core.default_effects import * +from v6d3music.core.mainservice import MainService from v6d3music.core.yt_audios import yt_audios from v6d3music.utils.assert_admin import assert_admin from v6d3music.utils.catch import catch @@ -14,276 +14,270 @@ from v6d2ctx.at_of import AtOf from v6d2ctx.context import Context, Explicit, command_type from v6d2ctx.lock_for import lock_for -at_of: AtOf[str, command_type] = AtOf() -at, of = at_of() +def get_of(mainservice: MainService, defaulteffects: DefaultEffects) -> Callable[[str], command_type]: + at_of: AtOf[str, command_type] = AtOf() + at, of = at_of() -@at('help') -async def help_(ctx: Context, args: list[str]) -> None: - match args: - case []: - await ctx.reply('music bot') - case [name]: - await ctx.reply(f'help for {name}: `{name} help`') + @at('help') + async def help_(ctx: Context, args: list[str]) -> None: + match args: + case []: + await ctx.reply('music bot') + case [name]: + await ctx.reply(f'help for {name}: `{name} help`') + @at('/') + @at('play') + async def play(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, + f''' + `play ...args` + `play url [- effects]/[+ preset] [[[h]]] [[m]] [s] [tor] ...args` + `pause` + `resume` + presets: {shlex.join(allowed_presets)} + ''', + (), 'help' + ) + async with lock_for(ctx.guild, 'not in a guild'): + queue = await mainservice.context(ctx, create=True, force_play=False).queue() + if ctx.message.attachments: + if len(ctx.message.attachments) > 1: + raise Explicit('no more than one attachment') + args = [ctx.message.attachments[0].url] + args + async for audio in yt_audios(mainservice.caching, defaulteffects, ctx, args): + queue.append(audio) + await ctx.reply('done') -@at('/') -@at('play') -async def play(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, - f''' -`play ...args` -`play url [- effects]/[+ preset] [[[h]]] [[m]] [s] [tor] ...args` -`pause` -`resume` -presets: {shlex.join(allowed_presets)} -''', - (), 'help' - ) - async with lock_for(ctx.guild, 'not in a guild'): - queue = await queue_for(ctx, create=True, force_play=False) - if ctx.message.attachments: - if len(ctx.message.attachments) > 1: - raise Explicit('no more than one attachment') - args = [ctx.message.attachments[0].url] + args - async for audio in yt_audios(ctx, args): - queue.append(audio) + @at('skip') + async def skip(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `skip [first] [last]` + ''', 'help' + ) + assert ctx.member is not None + match args: + case []: + queue = await mainservice.context(ctx, create=False, force_play=False).queue() + queue.skip_at(0, ctx.member) + case [pos] if pos.isdecimal(): + pos = int(pos) + queue = await mainservice.context(ctx, create=False, force_play=False).queue() + queue.skip_at(pos, ctx.member) + case [pos0, pos1] if pos0.isdecimal() and pos1.isdecimal(): + pos0, pos1 = int(pos0), int(pos1) + queue = await mainservice.context(ctx, create=False, force_play=False).queue() + for i in range(pos0, pos1 + 1): + if not queue.skip_at(pos0, ctx.member): + pos0 += 1 + case _: + raise Explicit('misformatted') await ctx.reply('done') + @at('to') + async def skip_to(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `to [[h]] [m] s` + ''', 'help' + ) + match args: + case [h, m, s] if h.isdecimal() and m.isdecimal() and s.isdecimal(): + seconds = 3600 * int(h) + 60 * int(m) + int(s) + case [m, s] if m.isdecimal() and s.isdecimal(): + seconds = 60 * int(m) + int(s) + case [s] if s.isdecimal(): + seconds = int(s) + case _: + raise Explicit('misformatted') + queue = await mainservice.context(ctx, create=False, force_play=False).queue() + queue.queue[0].set_seconds(seconds) -@at('skip') -async def skip(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`skip [first] [last]` -''', 'help' - ) - assert ctx.member is not None - match args: - case []: - queue = await queue_for(ctx, create=False, force_play=False) - queue.skip_at(0, ctx.member) - case [pos] if pos.isdecimal(): - pos = int(pos) - queue = await queue_for(ctx, create=False, force_play=False) - queue.skip_at(pos, ctx.member) - case [pos0, pos1] if pos0.isdecimal() and pos1.isdecimal(): - pos0, pos1 = int(pos0), int(pos1) - queue = await queue_for(ctx, create=False, force_play=False) - for i in range(pos0, pos1 + 1): - if not queue.skip_at(pos0, ctx.member): - pos0 += 1 - case _: - raise Explicit('misformatted') - await ctx.reply('done') + @at('effects') + async def effects_(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `effects - effects` + `effects + preset` + ''', 'help' + ) + match args: + case ['-', effects]: + pass + case ['+', preset]: + effects = effects_for_preset(preset) + case _: + raise Explicit('misformatted') + assert_admin(ctx.member) + queue = await mainservice.context(ctx, create=False, force_play=False).queue() + yta = queue.queue[0] + seconds = yta.source_seconds() + yta.options = options_for_effects(effects) + yta.set_seconds(seconds) + @at('default') + async def default(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `default - effects` + `default + preset` + `default none` + ''', 'help' + ) + assert ctx.guild is not None + match args: + case ['-', effects]: + pass + case ['+', preset]: + effects = effects_for_preset(preset) + case ['none']: + effects = None + case []: + await ctx.reply(f'current default effects: {defaulteffects.get(ctx.guild.id)}') + return + case _: + raise Explicit('misformatted') + assert_admin(ctx.member) + await defaulteffects.set(ctx.guild.id, effects) + await ctx.reply(f'effects set to `{effects}`') -@at('to') -async def skip_to(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`to [[h]] [m] s` -''', 'help' - ) - match args: - case [h, m, s] if h.isdecimal() and m.isdecimal() and s.isdecimal(): - seconds = 3600 * int(h) + 60 * int(m) + int(s) - case [m, s] if m.isdecimal() and s.isdecimal(): - seconds = 60 * int(m) + int(s) - case [s] if s.isdecimal(): - seconds = int(s) - case _: - raise Explicit('misformatted') - queue = await queue_for(ctx, create=False, force_play=False) - queue.queue[0].set_seconds(seconds) + @at('repeat') + async def repeat(ctx: Context, args: list[str]): + match args: + case []: + n = 1 + case [n_] if n_.isdecimal(): + n = int(n_) + case _: + raise Explicit('misformatted') + assert_admin(ctx.member) + queue = await mainservice.context(ctx, create=False, force_play=False).queue() + if not queue.queue: + raise Explicit('empty queue') + if n > 99: + raise Explicit('too long') + audio = queue.queue[0] + for _ in range(n): + queue.queue.insert(1, audio.copy()) + @at('branch') + async def branch(ctx: Context, args: list[str]): + match args: + case ['-', effects]: + pass + case ['+', preset]: + effects = effects_for_preset(preset) + case ['none']: + effects = '' + case []: + effects = None + case _: + raise Explicit('misformatted') + assert_admin(ctx.member) + queue = await mainservice.context(ctx, create=False, force_play=False).queue() + if not queue.queue: + raise Explicit('empty queue') + audio = queue.queue[0].branch() + if effects is not None: + seconds = audio.source_seconds() + audio.options = options_for_effects(effects or None) + audio.set_seconds(seconds) + else: + audio.set_source() + queue.queue.insert(1, audio) -@at('effects') -async def effects_(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`effects - effects` -`effects + preset` -''', 'help' - ) - match args: - case ['-', effects]: - pass - case ['+', preset]: - effects = effects_for_preset(preset) - case _: - raise Explicit('misformatted') - assert_admin(ctx.member) - queue = await queue_for(ctx, create=False, force_play=False) - yta = queue.queue[0] - seconds = yta.source_seconds() - yta.options = options_for_effects(effects) - yta.set_seconds(seconds) - - -@at('default') -async def default(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`default - effects` -`default + preset` -`default none` -''', 'help' - ) - assert ctx.guild is not None - match args: - case ['-', effects]: - pass - case ['+', preset]: - effects = effects_for_preset(preset) - case ['none']: - effects = None - case []: - await ctx.reply(f'current default effects: {default_effects(ctx.guild.id)}') - return - case _: - raise Explicit('misformatted') - assert_admin(ctx.member) - await set_default_effects(ctx.guild.id, effects) - await ctx.reply(f'effects set to `{effects}`') - - -@at('repeat') -async def repeat(ctx: Context, args: list[str]): - match args: - case []: - n = 1 - case [n_] if n_.isdecimal(): - n = int(n_) - case _: - raise Explicit('misformatted') - assert_admin(ctx.member) - queue = await queue_for(ctx, create=False, force_play=False) - if not queue.queue: - raise Explicit('empty queue') - if n > 99: - raise Explicit('too long') - audio = queue.queue[0] - for _ in range(n): - queue.queue.insert(1, audio.copy()) - - -@at('branch') -async def branch(ctx: Context, args: list[str]): - match args: - case ['-', effects]: - pass - case ['+', preset]: - effects = effects_for_preset(preset) - case ['none']: - effects = '' - case []: - effects = None - case _: - raise Explicit('misformatted') - assert_admin(ctx.member) - queue = await queue_for(ctx, create=False, force_play=False) - if not queue.queue: - raise Explicit('empty queue') - audio = queue.queue[0].branch() - if effects is not None: - seconds = audio.source_seconds() - audio.options = options_for_effects(effects or None) - audio.set_seconds(seconds) - else: - audio.set_source() - queue.queue.insert(1, audio) - - -@at('//') -@at('queue') -async def queue_(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`queue` -`queue clear` -`queue resume` -`queue pause` -''', 'help' - ) - assert ctx.member is not None - match args: - case []: - await ctx.long( - (await (await queue_for(ctx, create=True, force_play=False)).format()).strip() or 'no queue' - ) - case ['clear']: - (await queue_for(ctx, create=False, force_play=False)).clear(ctx.member) - await ctx.reply('done') - case ['resume']: - async with lock_for(ctx.guild, 'not in a guild'): - await queue_for(ctx, create=True, force_play=True) + @at('//') + @at('queue') + async def queue_(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `queue` + `queue clear` + `queue resume` + `queue pause` + ''', 'help' + ) + assert ctx.member is not None + match args: + case []: + await ctx.long( + ( + await ( + await mainservice.context(ctx, create=True, force_play=False).queue() + ).format() + ).strip() or 'no queue' + ) + case ['clear']: + (await mainservice.context(ctx, create=False, force_play=False).queue()).clear(ctx.member) await ctx.reply('done') - case ['pause']: - async with lock_for(ctx.guild, 'not in a guild'): - vc = await vc_for(ctx, create=True, force_play=False) - vc.pause() - await ctx.reply('done') - case _: - raise Explicit('misformatted') + case ['resume']: + async with lock_for(ctx.guild, 'not in a guild'): + await mainservice.context(ctx, create=True, force_play=True).vc() + await ctx.reply('done') + case ['pause']: + async with lock_for(ctx.guild, 'not in a guild'): + vc = await mainservice.context(ctx, create=True, force_play=False).vc() + vc.pause() + await ctx.reply('done') + case _: + raise Explicit('misformatted') + @at('swap') + async def swap(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `swap a b` + ''', 'help' + ) + assert ctx.member is not None + match args: + case [a, b] if a.isdecimal() and b.isdecimal(): + a, b = int(a), int(b) + (await mainservice.context(ctx, create=False, force_play=False).queue()).swap(ctx.member, a, b) + case _: + raise Explicit('misformatted') -@at('swap') -async def swap(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`swap a b` -''', 'help' - ) - assert ctx.member is not None - match args: - case [a, b] if a.isdecimal() and b.isdecimal(): - a, b = int(a), int(b) - (await queue_for(ctx, create=False, force_play=False)).swap(ctx.member, a, b) - case _: - raise Explicit('misformatted') + @at('move') + async def move(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `move a b` + ''', 'help' + ) + assert ctx.member is not None + match args: + case [a, b] if a.isdecimal() and b.isdecimal(): + a, b = int(a), int(b) + (await mainservice.context(ctx, create=False, force_play=False).queue()).move(ctx.member, a, b) + case _: + raise Explicit('misformatted') + @at('volume') + async def volume_(ctx: Context, args: list[str]) -> None: + await catch( + ctx, args, ''' + `volume volume` + ''', 'help' + ) + assert ctx.member is not None + match args: + case [volume]: + volume = float(volume) + await (await mainservice.context(ctx, create=True, force_play=False).main()).set(volume, ctx.member) + case _: + raise Explicit('misformatted') -@at('move') -async def move(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`move a b` -''', 'help' - ) - assert ctx.member is not None - match args: - case [a, b] if a.isdecimal() and b.isdecimal(): - a, b = int(a), int(b) - (await queue_for(ctx, create=False, force_play=False)).move(ctx.member, a, b) - case _: - raise Explicit('misformatted') + @at('pause') + async def pause(ctx: Context, _args: list[str]) -> None: + vc = await mainservice.context(ctx, create=False, force_play=False).vc() + vc.pause() + @at('resume') + async def resume(ctx: Context, _args: list[str]) -> None: + vc = await mainservice.context(ctx, create=False, force_play=True).vc() + vc.resume() -@at('volume') -async def volume_(ctx: Context, args: list[str]) -> None: - await catch( - ctx, args, ''' -`volume volume` -''', 'help' - ) - assert ctx.member is not None - match args: - case [volume]: - volume = float(volume) - await (await main_for(ctx, create=True, force_play=False)).set(volume, ctx.member) - case _: - raise Explicit('misformatted') - - -@at('pause') -async def pause(ctx: Context, _args: list[str]) -> None: - vc = await vc_for(ctx, create=False, force_play=False) - vc.pause() - - -@at('resume') -async def resume(ctx: Context, _args: list[str]) -> None: - vc = await vc_for(ctx, create=False, force_play=True) - vc.resume() + return of diff --git a/v6d3music/core/cache_url.py b/v6d3music/core/cache_url.py deleted file mode 100644 index 4faedd7..0000000 --- a/v6d3music/core/cache_url.py +++ /dev/null @@ -1,44 +0,0 @@ -import asyncio - -from ptvp35 import Db, KVJson -from v6d2ctx.lock_for import lock_for - -from v6d3music.config import myroot -from v6d3music.utils.tor_prefix import tor_prefix - -cache_root = myroot / 'cache' -cache_root.mkdir(exist_ok=True) -cache_db = Db(myroot / 'cache.db', kvfactory=KVJson()) - - -async def cache_url(hurl: str, rurl: str, override: bool, tor: bool) -> None: - async with lock_for(('cache', hurl), 'cache failed'): - if not override and cache_db.get(f'url:{hurl}', None) is not None: - return - cachable: bool = cache_db.get(f'cachable:{hurl}', False) - if cachable: - print('caching', hurl) - path = cache_root / f'{hurl}.opus' - tmp_path = cache_root / f'{hurl}.tmp.opus' - args = [] - if tor: - args.extend(tor_prefix()) - args.extend( - [ - 'ffmpeg', '-hide_banner', '-loglevel', 'warning', - '-reconnect', '1', '-reconnect_at_eof', '0', - '-reconnect_streamed', '1', '-reconnect_delay_max', '10', '-copy_unknown', - '-y', '-i', rurl, '-b:a', '128k', str(tmp_path) - ] - ) - ap = await asyncio.create_subprocess_exec(*args) - code = await ap.wait() - if code: - print(f'caching {hurl} failed with {code}') - return - await asyncio.to_thread(tmp_path.rename, path) - await cache_db.set(f'url:{hurl}', str(path)) - print('cached', hurl) - # await cache_db.set(f'cachable:{hurl}', False) - else: - await cache_db.set(f'cachable:{hurl}', True) diff --git a/v6d3music/core/caching.py b/v6d3music/core/caching.py new file mode 100644 index 0000000..d2e2b76 --- /dev/null +++ b/v6d3music/core/caching.py @@ -0,0 +1,67 @@ +import asyncio +from contextlib import AsyncExitStack + +from v6d3music.config import myroot +from v6d3music.utils.tor_prefix import tor_prefix + +from ptvp35 import * +from v6d2ctx.lock_for import lock_for + +__all__ = ('Caching',) + +cache_root = myroot / 'cache' +cache_root.mkdir(exist_ok=True) + + +class Caching: + async def cache_url(self, hurl: str, rurl: str, override: bool, tor: bool) -> None: + async with lock_for(('cache', hurl), 'cache failed'): + if not override and self.__db.get(f'url:{hurl}', None) is not None: + return + cachable: bool = self.__db.get(f'cachable:{hurl}', False) + if cachable: + print('caching', hurl) + path = cache_root / f'{hurl}.opus' + tmp_path = cache_root / f'{hurl}.tmp.opus' + args = [] + if tor: + args.extend(tor_prefix()) + args.extend( + [ + 'ffmpeg', '-hide_banner', '-loglevel', 'warning', + '-reconnect', '1', '-reconnect_at_eof', '0', + '-reconnect_streamed', '1', '-reconnect_delay_max', '10', '-copy_unknown', + '-y', '-i', rurl, '-b:a', '128k', str(tmp_path) + ] + ) + ap = await asyncio.create_subprocess_exec(*args) + code = await ap.wait() + if code: + print(f'caching {hurl} failed with {code}') + return + await asyncio.to_thread(tmp_path.rename, path) + await self.__db.set(f'url:{hurl}', str(path)) + print('cached', hurl) + # await cache_db.set(f'cachable:{hurl}', False) + else: + await self.__db.set(f'cachable:{hurl}', True) + + def get(self, hurl: str) -> str | None: + return self.__db.get(f'url:{hurl}', None) + + async def __aenter__(self) -> 'Caching': + es = AsyncExitStack() + async with es: + self.__db = await es.enter_async_context(DbFactory(myroot / 'cache.db', kvfactory=KVJson())) + self.__tasks = set() + self.__es = es.pop_all() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + async with self.__es: + del self.__es + + def schedule_cache(self, hurl: str, rurl: str, override: bool, tor: bool): + task = asyncio.create_task(self.cache_url(hurl, rurl, override, tor)) + self.__tasks.add(task) + task.add_done_callback(self.__tasks.discard) diff --git a/v6d3music/core/create_ytaudio.py b/v6d3music/core/create_ytaudio.py index c583ccb..5f54f94 100644 --- a/v6d3music/core/create_ytaudio.py +++ b/v6d3music/core/create_ytaudio.py @@ -4,6 +4,7 @@ from typing import Any, Optional from v6d2ctx.context import Context, Explicit, escape from v6d3music.core.real_url import real_url +from v6d3music.core.caching import Caching from v6d3music.core.ytaudio import YTAudio from v6d3music.utils.assert_admin import assert_admin from v6d3music.utils.options_for_effects import options_for_effects @@ -12,7 +13,7 @@ from v6d3music.utils.argctx import InfoCtx async def create_ytaudio( - ctx: Context, it: InfoCtx + caching: Caching, ctx: Context, it: InfoCtx ) -> YTAudio: assert ctx.member is not None if it.effects: @@ -24,7 +25,8 @@ async def create_ytaudio( else: options = None return YTAudio( - await real_url(it.info['url'], False, it.tor), + caching, + await real_url(caching, it.info['url'], False, it.tor), it.info['url'], f'{escape(it.info.get("title", "unknown"))} `Rby` {ctx.member}', options, diff --git a/v6d3music/core/create_ytaudios.py b/v6d3music/core/create_ytaudios.py deleted file mode 100644 index b5b2f3d..0000000 --- a/v6d3music/core/create_ytaudios.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio -from typing import AsyncIterable - -from v6d2ctx.context import Context - -from v6d3music.core.create_ytaudio import create_ytaudio -from v6d3music.core.ytaudio import YTAudio -from v6d3music.utils.argctx import InfoCtx - - -async def create_ytaudios(ctx: Context, infos: list[InfoCtx]) -> AsyncIterable[YTAudio]: - for audio in await asyncio.gather( - *[ - create_ytaudio(ctx, it) - for - it - in - infos - ] - ): - yield audio diff --git a/v6d3music/core/default_effects.py b/v6d3music/core/default_effects.py new file mode 100644 index 0000000..59ffed5 --- /dev/null +++ b/v6d3music/core/default_effects.py @@ -0,0 +1,33 @@ +from contextlib import AsyncExitStack + +from v6d3music.config import myroot +from v6d3music.utils.presets import allowed_effects + +from ptvp35 import * +from v6d2ctx.context import Explicit + +__all__ = ('DefaultEffects',) + + +class DefaultEffects: + def get(self, gid: int) -> str | None: + effects = self.__db.get(gid, None) + if effects in allowed_effects: + return effects + else: + return None + + async def set(self, gid: int, effects: str | None) -> None: + if effects is not None and effects not in allowed_effects: + raise Explicit('these effects are not allowed') + await self.__db.set(gid, effects) + + async def __aenter__(self) -> 'DefaultEffects': + async with AsyncExitStack() as es: + self.__db = await es.enter_async_context(DbFactory(myroot / 'effects.db', kvfactory=KVJson())) + self.__es = es.pop_all() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + async with self.__es: + del self.__es diff --git a/v6d3music/core/entries_effects_for_args.py b/v6d3music/core/entries_effects_for_args.py deleted file mode 100644 index 9da95af..0000000 --- a/v6d3music/core/entries_effects_for_args.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import AsyncIterable - -from ptvp35 import Db, KVJson -from v6d2ctx.context import Explicit - -from v6d3music.utils.argctx import ArgCtx, InfoCtx -from v6d3music.config import myroot -from v6d3music.utils.presets import allowed_effects - - -__all__ = ('effects_db', 'default_effects', 'set_default_effects', 'entries_effects_for_args',) - - -effects_db = Db(myroot / 'effects.db', kvfactory=KVJson()) - - -def default_effects(gid: int) -> str | None: - effects = effects_db.get(gid, None) - if effects in allowed_effects: - return effects - else: - return None - - -async def set_default_effects(gid: int, effects: str | None) -> None: - if effects is not None and effects not in allowed_effects: - raise Explicit('these effects are not allowed') - await effects_db.set(gid, effects) - - -async def entries_effects_for_args(args: list[str], gid: int) -> AsyncIterable[InfoCtx]: - for ctx in ArgCtx(default_effects(gid), args).sources: - async for it in ctx.entries(): - yield it diff --git a/v6d3music/core/mainasrc.py b/v6d3music/core/mainasrc.py deleted file mode 100644 index 69a8586..0000000 --- a/v6d3music/core/mainasrc.py +++ /dev/null @@ -1,69 +0,0 @@ -import discord -from v6d2ctx.context import Context, Explicit - -from v6d3music.core.mainaudio import MainAudio -from v6d3music.core.queueaudio import QueueAudio - -mainasrcs: dict[discord.Guild, MainAudio] = {} - - -async def raw_vc_for_member(member: discord.Member) -> discord.VoiceClient: - vc: discord.VoiceProtocol | None = member.guild.voice_client - if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected(): - vs: discord.VoiceState | None = member.voice - if vs is None: - raise Explicit('not connected') - vch: discord.VoiceChannel | None = vs.channel # type: ignore - if vch is None: - raise Explicit('not connected') - try: - vc = await vch.connect() - except discord.ClientException: - vc = member.guild.voice_client - assert vc is not None - await member.guild.fetch_channels() - await vc.disconnect(force=True) - raise Explicit('try again later') - assert isinstance(vc, discord.VoiceClient) - return vc - - -async def raw_vc_for(ctx: Context) -> discord.VoiceClient: - if ctx.member is None: - raise Explicit('not in a guild') - return await raw_vc_for_member(ctx.member) - - -async def main_for_raw_vc(vc: discord.VoiceClient, *, create: bool, force_play: bool) -> MainAudio: - if vc.guild in mainasrcs: - source = mainasrcs[vc.guild] - else: - if create: - source = mainasrcs.setdefault( - vc.guild, - await MainAudio.create(vc.guild) - ) - else: - raise Explicit('not playing, use `queue pause` or `queue resume`') - if vc.source != source or create and not vc.is_playing() and (force_play or not vc.is_paused()): - vc.play(source) - return source - - -async def vc_main_for(ctx: Context, *, create: bool, force_play: bool) -> tuple[discord.VoiceClient, MainAudio]: - vc = await raw_vc_for(ctx) - return vc, await main_for_raw_vc(vc, create=create, force_play=force_play) - - -async def vc_for(ctx: Context, *, create: bool, force_play: bool) -> discord.VoiceClient: - vc, source = await vc_main_for(ctx, create=create, force_play=force_play) - return vc - - -async def main_for(ctx: Context, *, create: bool, force_play: bool) -> MainAudio: - vc, source = await vc_main_for(ctx, create=create, force_play=force_play) - return source - - -async def queue_for(ctx: Context, *, create: bool, force_play: bool) -> QueueAudio: - return (await main_for(ctx, create=create, force_play=force_play)).queue diff --git a/v6d3music/core/mainaudio.py b/v6d3music/core/mainaudio.py index 86debdd..c11fa91 100644 --- a/v6d3music/core/mainaudio.py +++ b/v6d3music/core/mainaudio.py @@ -1,16 +1,15 @@ import discord -from ptvp35 import Db, KVJson -from v6d2ctx.context import Explicit - -from v6d3music.config import myroot +from v6d3music.core.caching import Caching from v6d3music.core.queueaudio import QueueAudio from v6d3music.utils.assert_admin import assert_admin -volume_db = Db(myroot / 'volume.db', kvfactory=KVJson()) +from ptvp35 import * +from v6d2ctx.context import Explicit class MainAudio(discord.PCMVolumeTransformer): - def __init__(self, queue: QueueAudio, volume: float): + def __init__(self, db: DbConnection, queue: QueueAudio, volume: float): + self.db = db self.queue = queue super().__init__(self.queue, volume=volume) @@ -21,8 +20,8 @@ class MainAudio(discord.PCMVolumeTransformer): if volume > 1: raise Explicit('volume too big') self.volume = volume - await volume_db.set(member.guild.id, volume) + await self.db.set(member.guild.id, volume) @classmethod - async def create(cls, guild: discord.Guild) -> 'MainAudio': - return cls(await QueueAudio.create(guild), volume=volume_db.get(guild.id, 0.2)) + async def create(cls, caching: Caching, db: DbConnection, queues: DbConnection, guild: discord.Guild) -> 'MainAudio': + return cls(db, await QueueAudio.create(caching, queues, guild), volume=db.get(guild.id, 0.2)) diff --git a/v6d3music/core/mainservice.py b/v6d3music/core/mainservice.py new file mode 100644 index 0000000..79239fd --- /dev/null +++ b/v6d3music/core/mainservice.py @@ -0,0 +1,215 @@ +import asyncio +import traceback +from contextlib import AsyncExitStack + +import discord +from v6d3music.config import myroot +from v6d3music.core.caching import Caching +from v6d3music.core.mainaudio import MainAudio +from v6d3music.core.queueaudio import QueueAudio + +from ptvp35 import * +from v6d2ctx.context import Context, Explicit +from v6d2ctx.lock_for import lock_for + + +class MainService: + def __init__(self, client: discord.Client) -> None: + self.client = client + self.mains: dict[discord.Guild, MainAudio] = {} + + @staticmethod + async def raw_vc_for_member(member: discord.Member) -> discord.VoiceClient: + vc: discord.VoiceProtocol | None = member.guild.voice_client + if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected(): + vs: discord.VoiceState | None = member.voice + if vs is None: + raise Explicit('not connected') + vch: discord.abc.Connectable | None = vs.channel + if vch is None: + raise Explicit('not connected') + try: + vc = await vch.connect() + except discord.ClientException: + vc = member.guild.voice_client + assert vc is not None + await member.guild.fetch_channels() + await vc.disconnect(force=True) + raise Explicit('try again later') + assert isinstance(vc, discord.VoiceClient) + return vc + + async def raw_vc_for(self, ctx: Context) -> discord.VoiceClient: + if ctx.member is None: + raise Explicit('not in a guild') + return await self.raw_vc_for_member(ctx.member) + + def descriptor(self, *, create: bool, force_play: bool) -> 'MainDescriptor': + return MainDescriptor(self, create=create, force_play=force_play) + + def context(self, ctx: Context, *, create: bool, force_play: bool) -> 'MainContext': + return self.descriptor(create=create, force_play=force_play).context(ctx) + + async def create(self, guild: discord.Guild) -> MainAudio: + return await MainAudio.create(self.caching, self.__volumes, self.queues, guild) + + async def __aenter__(self) -> 'MainService': + async with AsyncExitStack() as es: + self.__volumes = await es.enter_async_context(DbFactory(myroot / 'volume.db', kvfactory=KVJson())) + self.queues = await es.enter_async_context(DbFactory(myroot / 'queue.db', kvfactory=KVJson())) + self.caching = await es.enter_async_context(Caching()) + self.__vcs_restored: asyncio.Future[None] = asyncio.Future() + self.__es = es.pop_all() + self.__save_task = asyncio.create_task(self.save_daemon()) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + async with self.__es: + await self.final_save() + del self.__es + + async def save_queues(self, delay: bool) -> None: + for mainasrc in list(self.mains.values()): + if delay: + await asyncio.sleep(0.01) + await mainasrc.queue.save(delay) + + async def save_vcs(self, delay: bool) -> None: + vcs = [] + vc: discord.VoiceClient + for vc in (vcc for vcc in self.client.voice_clients if isinstance(vcc, discord.VoiceClient)): + if delay: + await asyncio.sleep(0.01) + if vc.is_playing(): + if vc.guild is not None and vc.channel is not None: + vcs.append((vc.guild.id, vc.channel.id, vc.is_paused())) + self.queues.set_nowait('vcs', vcs) + + async def save_commit(self) -> None: + await self.queues.commit() + + async def _save_all(self, delay: bool, save_playing: bool) -> None: + await self.save_queues(delay) + if save_playing: + await self.save_vcs(delay) + await self.save_commit() + + async def save_all(self, delay: bool, save_playing: bool) -> None: + await self._save_all(delay, save_playing) + + async def save_job(self): + await self.__vcs_restored + print('starting saving') + while True: + await asyncio.sleep(1) + await self.save_all(True, not self.client.is_closed()) + + async def save_daemon(self): + try: + await self.save_job() + except asyncio.CancelledError: + pass + + async def final_save(self): + self.__save_task.cancel() + if not self.__vcs_restored.done(): + self.__vcs_restored.cancel() + else: + try: + await self.save_all(False, False) + print('saved') + except Exception: + traceback.print_exc() + + async def _restore_vc(self, guild: discord.Guild, vccid: int, vc_is_paused: bool) -> None: + channels = await guild.fetch_channels() + channel: discord.VoiceChannel + channel, = [ + ch for ch in + ( + chc for chc in channels + if + isinstance(chc, discord.VoiceChannel) + ) + if ch.id == vccid + ] + vp: discord.VoiceProtocol = await channel.connect() + assert isinstance(vp, discord.VoiceClient) + vc = vp + await self.descriptor(create=True, force_play=True).main_for_raw_vc(vc) + if vc_is_paused: + vc.pause() + + async def restore_vc(self, vcgid: int, vccid: int, vc_is_paused: bool) -> None: + try: + print(f'vc restoring {vcgid}') + guild: discord.Guild = await self.client.fetch_guild(vcgid) + async with lock_for(guild, 'not in a guild'): + await self._restore_vc(guild, vccid, vc_is_paused) + except Exception as e: + print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e) + else: + print(f'vc restored {vcgid} {vccid}') + + async def restore_vcs(self) -> None: + vcs: list[tuple[int, int, bool]] = self.queues.get('vcs', []) + try: + tasks = [] + for vcgid, vccid, vc_is_paused in vcs: + tasks.append(asyncio.create_task(self.restore_vc(vcgid, vccid, vc_is_paused))) + for task in tasks: + await task + finally: + self.__vcs_restored.set_result(None) + + async def restore(self) -> None: + async with lock_for('vcs_restored', '...'): + if not self.__vcs_restored.done(): + await self.restore_vcs() + + +class MainDescriptor: + def __init__(self, service: MainService, *, create: bool, force_play: bool) -> None: + self.mainservice = service + self.mains = service.mains + self.create = create + self.force_play = force_play + + async def main_for_raw_vc(self, vc: discord.VoiceClient) -> MainAudio: + if vc.guild in self.mains: + source = self.mains[vc.guild] + elif self.create: + source = self.mains.setdefault( + vc.guild, + await self.mainservice.create(vc.guild) + ) + else: + raise Explicit('not playing, use `queue pause` or `queue resume`') + if vc.source != source or self.create and not vc.is_playing() and (self.force_play or not vc.is_paused()): + vc.play(source) + return source + + def context(self, ctx: Context) -> 'MainContext': + return MainContext(self, ctx) + + +class MainContext: + def __init__(self, descriptor: MainDescriptor, ctx: Context) -> None: + self.mainservice = descriptor.mainservice + self.descriptor = descriptor + self.ctx = ctx + + async def vc_main(self) -> tuple[discord.VoiceClient, MainAudio]: + vc = await self.mainservice.raw_vc_for(self.ctx) + return vc, await self.descriptor.main_for_raw_vc(vc) + + async def vc(self) -> discord.VoiceClient: + vc, _ = await self.vc_main() + return vc + + async def main(self) -> MainAudio: + _, source = await self.vc_main() + return source + + async def queue(self) -> QueueAudio: + return (await self.main()).queue diff --git a/v6d3music/core/queueaudio.py b/v6d3music/core/queueaudio.py index e2ccb72..26697dd 100644 --- a/v6d3music/core/queueaudio.py +++ b/v6d3music/core/queueaudio.py @@ -3,19 +3,19 @@ from collections import deque from io import StringIO import discord -from ptvp35 import Db, KVJson - -from v6d3music.config import myroot +from v6d3music.core.caching import Caching from v6d3music.core.ytaudio import YTAudio from v6d3music.utils.assert_admin import assert_admin from v6d3music.utils.fill import FILL -queue_db = Db(myroot / 'queue.db', kvfactory=KVJson()) +from ptvp35 import * + PRE_SET_LENGTH = 24 class QueueAudio(discord.AudioSource): - def __init__(self, guild: discord.Guild, respawned: list[YTAudio]): + def __init__(self, db: DbConnection, guild: discord.Guild, respawned: list[YTAudio]): + self.db = db self.queue: deque[YTAudio] = deque() for audio in respawned: self.append(audio) @@ -30,12 +30,12 @@ class QueueAudio(discord.AudioSource): return @staticmethod - async def respawned(guild: discord.Guild) -> list[YTAudio]: + async def respawned(caching: Caching, db: DbConnection, guild: discord.Guild) -> list[YTAudio]: respawned = [] try: - for audio_respawn in queue_db.get(guild.id, []): + for audio_respawn in db.get(guild.id, []): try: - respawned.append(await YTAudio.respawn(guild, audio_respawn)) + respawned.append(await YTAudio.respawn(caching, guild, audio_respawn)) except Exception as e: print('audio respawn failed', e) raise @@ -44,16 +44,16 @@ class QueueAudio(discord.AudioSource): return respawned @classmethod - async def create(cls, guild: discord.Guild): - return cls(guild, await cls.respawned(guild)) + async def create(cls, caching: Caching, db: DbConnection, guild: discord.Guild) -> 'QueueAudio': + return cls(db, guild, await cls.respawned(caching, db, guild)) - async def save(self, delay: bool): + async def save(self, delay: bool) -> None: hybernated = [] for audio in list(self.queue): if delay: await asyncio.sleep(0.01) hybernated.append(audio.hybernate()) - queue_db.set_nowait(self.guild.id, hybernated) + self.db.set_nowait(self.guild.id, hybernated) def append(self, audio: YTAudio): if len(self.queue) < PRE_SET_LENGTH: @@ -139,6 +139,7 @@ class QueueAudio(discord.AudioSource): except ValueError: pass - async def pubjson(self, member: discord.Member) -> list: + async def pubjson(self, member: discord.Member, limit: int) -> list: + import random audios = list(self.queue) - return [await audio.pubjson(member) for audio in audios] + return [await audio.pubjson(member) for audio, _ in zip(audios, range(limit))] diff --git a/v6d3music/core/real_url.py b/v6d3music/core/real_url.py index f680f24..687d5e4 100644 --- a/v6d3music/core/real_url.py +++ b/v6d3music/core/real_url.py @@ -1,27 +1,17 @@ import asyncio import os -from typing import Optional -from adaas.cachedb import RemoteCache -from v6d3music.core.cache_url import cache_db, cache_url +from v6d3music.core.caching import Caching from v6d3music.utils.bytes_hash import bytes_hash from v6d3music.utils.tor_prefix import tor_prefix +from adaas.cachedb import RemoteCache adaas_available = bool(os.getenv('adaasurl')) if adaas_available: print('running real_url through adaas') -_tasks = set() - - -def _schedule_cache(hurl: str, rurl: str, override: bool, tor: bool): - task = asyncio.create_task(cache_url(hurl, rurl, override, tor)) - _tasks.add(task) - task.add_done_callback(_tasks.discard) - - async def _resolve_url(url: str, tor: bool) -> str: args = [] if tor: @@ -39,15 +29,15 @@ async def _resolve_url(url: str, tor: bool) -> str: return (await ap.stdout.readline()).decode()[:-1] -async def real_url(url: str, override: bool, tor: bool) -> str: +async def real_url(caching: Caching, url: str, override: bool, tor: bool) -> str: if adaas_available and not tor: return await RemoteCache().real_url(url, override, tor) hurl: str = bytes_hash(url.encode()) if not override: - curl: Optional[str] = cache_db.get(f'url:{hurl}', None) + curl: str | None = caching.get(hurl) if curl is not None: print('using cached', hurl) return curl rurl: str = await _resolve_url(url, tor) - _schedule_cache(hurl, rurl, override, tor) + caching.schedule_cache(hurl, rurl, override, tor) return rurl diff --git a/v6d3music/core/ystate.py b/v6d3music/core/ystate.py index c4c36d5..5bdbd6c 100644 --- a/v6d3music/core/ystate.py +++ b/v6d3music/core/ystate.py @@ -14,7 +14,8 @@ __all__ = ('YState',) class YState: - def __init__(self, pool: Pool, ctx: Context, sources: Iterable[UrlCtx]) -> None: + def __init__(self, caching: Caching, pool: Pool, ctx: Context, sources: Iterable[UrlCtx]) -> None: + self.caching = caching self.pool = pool self.ctx = ctx self.sources: deque[UrlCtx] = deque(sources) @@ -52,7 +53,7 @@ class YState: async def result(self, entry: InfoCtx) -> YTAudio | None: try: - return await create_ytaudio(self.ctx, entry) + return await create_ytaudio(self.caching, self.ctx, entry) except Exception: if not entry.ignore: raise diff --git a/v6d3music/core/yt_audios.py b/v6d3music/core/yt_audios.py index 30405e0..221ab33 100644 --- a/v6d3music/core/yt_audios.py +++ b/v6d3music/core/yt_audios.py @@ -1,17 +1,18 @@ from typing import AsyncIterable -from v6d3music.core.entries_effects_for_args import * +from v6d3music.core.default_effects import * from v6d3music.core.ystate import * -from v6d3music.core.ytaudio import YTAudio +from v6d3music.core.ytaudio import * +from v6d3music.core.caching import Caching from v6d3music.processing.pool import * from v6d3music.utils.argctx import * from v6d2ctx.context import Context -async def yt_audios(ctx: Context, args: list[str]) -> AsyncIterable[YTAudio]: +async def yt_audios(caching: Caching, defaulteffects: DefaultEffects, ctx: Context, args: list[str]) -> AsyncIterable[YTAudio]: assert ctx.guild is not None - argctx = ArgCtx(default_effects(ctx.guild.id), args) + argctx = ArgCtx(defaulteffects.get(ctx.guild.id), args) async with Pool(5) as pool: - async for audio in YState(pool, ctx, argctx.sources).iterate(): + async for audio in YState(caching, pool, ctx, argctx.sources).iterate(): yield audio diff --git a/v6d3music/core/ytaudio.py b/v6d3music/core/ytaudio.py index 33312e9..07fb36b 100644 --- a/v6d3music/core/ytaudio.py +++ b/v6d3music/core/ytaudio.py @@ -5,6 +5,7 @@ from typing import Optional import discord from v6d3music.core.ffmpegnormalaudio import FFmpegNormalAudio from v6d3music.core.real_url import real_url +from v6d3music.core.caching import Caching from v6d3music.utils.fill import FILL from v6d3music.utils.sparq import sparq from v6d3music.utils.tor_prefix import tor_prefix @@ -22,6 +23,7 @@ class YTAudio(discord.AudioSource): def __init__( self, + caching: Caching, url: str, origin: str, description: str, @@ -33,6 +35,7 @@ class YTAudio(discord.AudioSource): *, stop_at: int | None = None ): + self.caching = caching self.url = url self.origin = origin self.description = description @@ -181,7 +184,7 @@ class YTAudio(discord.AudioSource): } @classmethod - async def respawn(cls, guild: discord.Guild, respawn: dict) -> 'YTAudio': + async def respawn(cls, caching: Caching, guild: discord.Guild, respawn: dict) -> 'YTAudio': member_id: int | None = respawn['rby'] if member_id is None: member = None @@ -194,6 +197,7 @@ class YTAudio(discord.AudioSource): except discord.NotFound: member = None audio = YTAudio( + caching, respawn['url'], respawn['origin'], respawn['description'], @@ -209,7 +213,7 @@ class YTAudio(discord.AudioSource): async def regenerate(self): try: print(f'regenerating {self.origin}') - self.url = await real_url(self.origin, True, self.tor) + self.url = await real_url(self.caching, self.origin, True, self.tor) if hasattr(self, 'source'): self.source.cleanup() self.set_source() @@ -228,6 +232,7 @@ class YTAudio(discord.AudioSource): def copy(self) -> 'YTAudio': return YTAudio( + self.caching, self.url, self.origin, self.description, @@ -242,6 +247,7 @@ class YTAudio(discord.AudioSource): raise Explicit('already branched') self.stop_at = stop_at = self.already_read + 50 audio = YTAudio( + self.caching, self.url, self.origin, self.description, diff --git a/v6d3music/run-bot.py b/v6d3music/run-bot.py index 92553f4..54bd7e6 100644 --- a/v6d3music/run-bot.py +++ b/v6d3music/run-bot.py @@ -3,22 +3,19 @@ import contextlib import os import sys import time -import traceback import discord -from v6d3music.app import MusicAppFactory, session_db -from v6d3music.commands import of +from v6d3music.app import AppContext +from v6d3music.commands import get_of from v6d3music.config import prefix -from v6d3music.core.cache_url import cache_db -from v6d3music.core.entries_effects_for_args import effects_db -from v6d3music.core.mainasrc import main_for_raw_vc, mainasrcs -from v6d3music.core.mainaudio import volume_db -from v6d3music.core.queueaudio import queue_db +from v6d3music.core.caching import * +from v6d3music.core.default_effects import * +from v6d3music.core.mainservice import MainService +from ptvp35 import * from rainbowadn.instrument import Instrumentation from v6d1tokens.client import request_token from v6d2ctx.handle_content import handle_content -from v6d2ctx.lock_for import lock_for from v6d2ctx.pain import ABlockMonitor from v6d2ctx.serve import serve @@ -27,14 +24,7 @@ asyncio.set_event_loop(loop) class MusicClient(discord.Client): - async def close(self) -> None: - save_task.cancel() - await super().close() - try: - await save_all(False, False) - print('saved') - except Exception: - traceback.print_exc() + pass client = MusicClient( @@ -49,69 +39,9 @@ client = MusicClient( reactions=True, message_content=True, ), - loop=loop + loop=loop, ) -vcs_restored = False - - -async def _restore_vc(guild: discord.Guild, vccid: int, vc_is_paused: bool) -> None: - channels = await guild.fetch_channels() - channel: discord.VoiceChannel - channel, = [ - ch for ch in - ( - chc for chc in channels - if - isinstance(chc, discord.VoiceChannel) - ) - if ch.id == vccid - ] - vp: discord.VoiceProtocol = await channel.connect() - assert isinstance(vp, discord.VoiceClient) - vc = vp - await main_for_raw_vc(vc, create=True, force_play=True) - if vc_is_paused: - vc.pause() - - -async def restore_vc(vcgid: int, vccid: int, vc_is_paused: bool) -> None: - try: - print(f'vc restoring {vcgid}') - guild: discord.Guild = await client.fetch_guild(vcgid) - async with lock_for(guild, 'not in a guild'): - await _restore_vc(guild, vccid, vc_is_paused) - except Exception as e: - print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e) - else: - print(f'vc restored {vcgid} {vccid}') - - -async def restore_vcs(): - global vcs_restored - vcs: list[tuple[int, int, bool]] = queue_db.get('vcs', []) - try: - tasks = [] - for vcgid, vccid, vc_is_paused in vcs: - tasks.append(asyncio.create_task(restore_vc(vcgid, vccid, vc_is_paused))) - for task in tasks: - await task - finally: - vcs_restored = True - - -@client.event -async def on_ready(): - print('ready') - await client.change_presence( - activity=discord.Game( - name='феноменально', - ) - ) - async with lock_for('vcs_restored', '...'): - if not vcs_restored: - await restore_vcs() - banned_guilds = set(map(int, map(str.strip, os.getenv('banned_guilds', '').split(':')))) @@ -124,64 +54,23 @@ def message_allowed(message: discord.Message) -> bool: return guild_allowed(message.guild) -@client.event -async def on_message(message: discord.Message) -> None: - if message_allowed(message): - await handle_content(of, message, message.content, prefix, client) +def register_handlers(mainservice: MainService, defaulteffects: DefaultEffects): + of = get_of(mainservice, defaulteffects) + @client.event + async def on_message(message: discord.Message) -> None: + if message_allowed(message): + await handle_content(of, message, message.content, prefix, client) -async def save_queues(delay: bool): - for mainasrc in list(mainasrcs.values()): - if delay: - await asyncio.sleep(0.01) - await mainasrc.queue.save(delay) - - -async def save_vcs(delay: bool): - if vcs_restored: - vcs = [] - vc: discord.VoiceClient - for vc in (vcc for vcc in client.voice_clients if isinstance(vcc, discord.VoiceClient)): - if delay: - await asyncio.sleep(0.01) - if vc.is_playing(): - if vc.guild is not None and vc.channel is not None: - vcs.append((vc.guild.id, vc.channel.id, vc.is_paused())) - queue_db.set_nowait('vcs', vcs) - - -async def save_commit(): - await queue_db.commit() - - -async def save_all(delay: bool, save_playing: bool): - await save_queues(delay) - if save_playing: - await save_vcs(delay) - await save_commit() - - -async def save_job(): - while True: - await asyncio.sleep(1) - await save_all(True, True) - - -async def save_daemon(): - try: - await save_job() - except asyncio.CancelledError: - pass - - -async def start_app(): - await MusicAppFactory.start(client) - - -async def setup_tasks(): - global save_task - save_task = loop.create_task(save_daemon()) - loop.create_task(start_app()) + @client.event + async def on_ready(): + print('ready') + await client.change_presence( + activity=discord.Game( + name='феноменально', + ) + ) + await mainservice.restore() class UpgradeABMInit(Instrumentation): @@ -223,8 +112,37 @@ def _upgrade_abm() -> contextlib.ExitStack: raise RuntimeError +class PathPrint(Instrumentation): + def __init__(self, methodname: str, pref: str): + super().__init__(DbConnection, methodname) + self.pref = pref + + async def instrument(self, method, db: DbConnection, *args, **kwargs): + result = await method(db, *args, **kwargs) + try: + print(self.pref, db._DbConnection__path) # type: ignore + except Exception: + from traceback import print_exc + print_exc() + return result + + +def _db_ee() -> contextlib.ExitStack: + with contextlib.ExitStack() as es: + es.enter_context(PathPrint('_initialize', 'open :')) + es.enter_context(PathPrint('aclose', 'close:')) + return es.pop_all() + raise RuntimeError + + async def main(): - async with volume_db, queue_db, cache_db, session_db, effects_db, ABlockMonitor(delta=0.5): + async with ( + DefaultEffects() as defaulteffects, + MainService(client) as mainservice, + AppContext(mainservice), + ABlockMonitor(delta=0.5) + ): + register_handlers(mainservice, defaulteffects) if 'guerilla' in sys.argv: from pathlib import Path tokenpath = Path('.token.txt') @@ -238,7 +156,6 @@ async def main(): else: token = await request_token('music', 'token') await client.login(token) - loop.create_task(setup_tasks()) if os.getenv('v6tor', None) is None: print('no tor') await client.connect() @@ -246,5 +163,5 @@ async def main(): if __name__ == '__main__': - with _upgrade_abm(): + with _upgrade_abm(), _db_ee(): serve(main(), client, loop) diff --git a/v6d3music/utils/tor_prefix.py b/v6d3music/utils/tor_prefix.py index 89529af..bd587e7 100644 --- a/v6d3music/utils/tor_prefix.py +++ b/v6d3music/utils/tor_prefix.py @@ -1,24 +1,3 @@ __all__ = ('tor_prefix',) -import os - -if (address := os.getenv('v6tor', None)) is not None: - print('tor through torsocks') - try: - import socket - address = socket.gethostbyname_ex(address)[2][0] - except Exception: - print('failed tor resolution') - _tor_prefix = None - else: - print('tor successfully resolved') - _tor_prefix = ['torsocks', '--address', address] -else: - print('tor unavailable') - _tor_prefix = None - - -def tor_prefix(): - if _tor_prefix is None: - raise ValueError('tor unavailable') - return _tor_prefix +from adaas.tor_prefix import tor_prefix