service hierarchy

This commit is contained in:
AF 2022-12-24 16:37:05 +00:00
parent 45fee6d593
commit 77c0f16da3
19 changed files with 725 additions and 653 deletions

View File

@ -3,7 +3,7 @@ import time
from typing import TypeAlias from typing import TypeAlias
import discord import discord
from v6d3music.core.mainasrc import main_for_raw_vc, raw_vc_for_member from v6d3music.core.mainservice import MainService
from v6d3music.core.mainaudio import MainAudio from v6d3music.core.mainaudio import MainAudio
from v6d2ctx.context import Explicit from v6d2ctx.context import Explicit
@ -27,8 +27,9 @@ class Api:
def json(self) -> dict: def json(self) -> dict:
return super().json() | {'explicit': None} return super().json() | {'explicit': None}
def __init__(self, client: discord.Client, roles: dict[str, str]) -> None: def __init__(self, mainservice: MainService, roles: dict[str, str]) -> None:
self.client = client self.mainservice = mainservice
self.client = mainservice.client
self.roles = roles self.roles = roles
def is_operator(self, user_id: int) -> bool: def is_operator(self, user_id: int) -> bool:
@ -166,10 +167,11 @@ class VoiceApi(GuildApi):
) -> None: ) -> None:
super().__init__(api, api.member) super().__init__(api, api.member)
self.channel = channel self.channel = channel
self.mainservice = self.pi.mainservice
async def _main_api(self) -> 'MainApi': async def _main_api(self) -> 'MainApi':
vc = await raw_vc_for_member(self.member) vc = await self.mainservice.raw_vc_for_member(self.member)
main = await main_for_raw_vc(vc, create=False, force_play=False) main = await self.mainservice.descriptor(create=False, force_play=False).main_for_raw_vc(vc)
return MainApi(self, vc, main) return MainApi(self, vc, main)
def sub(self, request: dict) -> 'VoiceApi': def sub(self, request: dict) -> 'VoiceApi':
@ -210,7 +212,7 @@ class MainApi(VoiceApi):
case {'type': 'queueformat'}: case {'type': 'queueformat'}:
return await self.main.queue.format() return await self.main.queue.format()
case {'type': 'queuejson'}: case {'type': 'queuejson'}:
return await self.main.queue.pubjson(self.member) return await self.main.queue.pubjson(self.member, self.request.get('limit', 1000))
case {'type': '?'}: case {'type': '?'}:
return 'this is main api' return 'this is main api'
case {'type': '*', 'requests': list() | dict() as requests}: case {'type': '*', 'requests': list() | dict() as requests}:

View File

@ -2,6 +2,7 @@ import asyncio
import functools import functools
import os import os
import urllib.parse import urllib.parse
from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
@ -10,15 +11,14 @@ import discord
from aiohttp import web from aiohttp import web
from v6d3music.api import Api from v6d3music.api import Api
from v6d3music.config import auth_redirect, myroot from v6d3music.config import auth_redirect, myroot
from v6d3music.core.mainservice import MainService
from v6d3music.utils.bytes_hash import bytes_hash from v6d3music.utils.bytes_hash import bytes_hash
from ptvp35 import Db, KVJson from ptvp35 import *
from v6d0auth.appfactory import AppFactory from v6d0auth.appfactory import AppFactory
from v6d0auth.run_app import start_app from v6d0auth.run_app import start_app
from v6d1tokens.client import request_token from v6d1tokens.client import request_token
session_db = Db(myroot / 'session.db', kvfactory=KVJson())
T = TypeVar('T') T = TypeVar('T')
TKey = TypeVar('TKey', bound=Hashable) TKey = TypeVar('TKey', bound=Hashable)
@ -64,13 +64,15 @@ class MusicAppFactory(AppFactory):
self, self,
secret: str, secret: str,
client: discord.Client, client: discord.Client,
api: Api api: Api,
db: DbConnection
): ):
self.secret = secret self.secret = secret
self.redirect = auth_redirect self.redirect = auth_redirect
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.client = client self.client = client
self._api = api self._api = api
self.db = db
self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary( self._token_clients: CachedDictionary[str, dict | None] = CachedDictionary(
self._token_client self._token_client
) )
@ -190,11 +192,10 @@ class MusicAppFactory(AppFactory):
cid = self.user_id(user) cid = self.user_id(user)
return cid return cid
@classmethod def session_data(self, session: str | None) -> dict:
def session_data(cls, session: str | None) -> dict:
if session is None: if session is None:
return {} return {}
data = session_db.get(session, {}) data = self.db.get(session, {})
if not isinstance(data, dict): if not isinstance(data, dict):
return {} return {}
return data return data
@ -224,7 +225,7 @@ class MusicAppFactory(AppFactory):
data = self.session_data(session) data = self.session_data(session)
data['code'] = code data['code'] = code
data['token'] = await self.code_token(code) data['token'] = await self.code_token(code)
await session_db.set(session, data) await self.db.set(session, data)
return response return response
else: else:
return web.FileResponse(self._path('auth.html')) return web.FileResponse(self._path('auth.html'))
@ -279,15 +280,47 @@ class MusicAppFactory(AppFactory):
except Api.MisusedApi as e: except Api.MisusedApi as e:
return web.json_response(e.json(), status=404) return web.json_response(e.json(), status=404)
@classmethod
async def start(cls, client: discord.Client): class AppContext:
def __init__(self, mainservice: MainService) -> None:
self.mainservice = mainservice
async def start(self) -> tuple[web.Application, asyncio.Task[None]] | None:
try: try:
factory = cls( factory = MusicAppFactory(
await request_token('music-client', 'token'), await request_token('music-client', 'token'),
client, self.mainservice.client,
Api(client, {key: value for key, value in os.environ.items() if key.startswith('roles')}) Api(
self.mainservice,
{key: value for key, value in os.environ.items() if key.startswith('roles')},
),
self.__db
) )
except aiohttp.ClientConnectorError: except aiohttp.ClientConnectorError:
print('no web app (likely due to no token)') print('no web app (likely due to no token)')
else: else:
await start_app(factory.app()) app = factory.app()
task = asyncio.create_task(start_app(app))
return app, task
async def __aenter__(self) -> 'AppContext':
async with AsyncExitStack() as es:
self.__db = await es.enter_async_context(DbFactory(myroot / 'session.db', kvfactory=KVJson()))
self.__es = es.pop_all()
self.__task: asyncio.Task[
tuple[web.Application, asyncio.Task[None]] | None
] = asyncio.create_task(self.start())
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self.__es:
if self.__task.done():
result = await self.__task
if result is not None:
app, task = result
await task
await app.shutdown()
await app.cleanup()
else:
self.__task.cancel()
del self.__es

View File

@ -1,8 +1,8 @@
import shlex import shlex
from typing import Callable from typing import Callable
from v6d3music.core.entries_effects_for_args import * from v6d3music.core.default_effects import *
from v6d3music.core.mainasrc import main_for, queue_for, vc_for from v6d3music.core.mainservice import MainService
from v6d3music.core.yt_audios import yt_audios from v6d3music.core.yt_audios import yt_audios
from v6d3music.utils.assert_admin import assert_admin from v6d3music.utils.assert_admin import assert_admin
from v6d3music.utils.catch import catch from v6d3music.utils.catch import catch
@ -14,10 +14,11 @@ from v6d2ctx.at_of import AtOf
from v6d2ctx.context import Context, Explicit, command_type from v6d2ctx.context import Context, Explicit, command_type
from v6d2ctx.lock_for import lock_for from v6d2ctx.lock_for import lock_for
def get_of(mainservice: MainService, defaulteffects: DefaultEffects) -> Callable[[str], command_type]:
at_of: AtOf[str, command_type] = AtOf() at_of: AtOf[str, command_type] = AtOf()
at, of = at_of() at, of = at_of()
@at('help') @at('help')
async def help_(ctx: Context, args: list[str]) -> None: async def help_(ctx: Context, args: list[str]) -> None:
match args: match args:
@ -26,7 +27,6 @@ async def help_(ctx: Context, args: list[str]) -> None:
case [name]: case [name]:
await ctx.reply(f'help for {name}: `{name} help`') await ctx.reply(f'help for {name}: `{name} help`')
@at('/') @at('/')
@at('play') @at('play')
async def play(ctx: Context, args: list[str]) -> None: async def play(ctx: Context, args: list[str]) -> None:
@ -42,16 +42,15 @@ presets: {shlex.join(allowed_presets)}
(), 'help' (), 'help'
) )
async with lock_for(ctx.guild, 'not in a guild'): async with lock_for(ctx.guild, 'not in a guild'):
queue = await queue_for(ctx, create=True, force_play=False) queue = await mainservice.context(ctx, create=True, force_play=False).queue()
if ctx.message.attachments: if ctx.message.attachments:
if len(ctx.message.attachments) > 1: if len(ctx.message.attachments) > 1:
raise Explicit('no more than one attachment') raise Explicit('no more than one attachment')
args = [ctx.message.attachments[0].url] + args args = [ctx.message.attachments[0].url] + args
async for audio in yt_audios(ctx, args): async for audio in yt_audios(mainservice.caching, defaulteffects, ctx, args):
queue.append(audio) queue.append(audio)
await ctx.reply('done') await ctx.reply('done')
@at('skip') @at('skip')
async def skip(ctx: Context, args: list[str]) -> None: async def skip(ctx: Context, args: list[str]) -> None:
await catch( await catch(
@ -62,15 +61,15 @@ async def skip(ctx: Context, args: list[str]) -> None:
assert ctx.member is not None assert ctx.member is not None
match args: match args:
case []: case []:
queue = await queue_for(ctx, create=False, force_play=False) queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.skip_at(0, ctx.member) queue.skip_at(0, ctx.member)
case [pos] if pos.isdecimal(): case [pos] if pos.isdecimal():
pos = int(pos) pos = int(pos)
queue = await queue_for(ctx, create=False, force_play=False) queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.skip_at(pos, ctx.member) queue.skip_at(pos, ctx.member)
case [pos0, pos1] if pos0.isdecimal() and pos1.isdecimal(): case [pos0, pos1] if pos0.isdecimal() and pos1.isdecimal():
pos0, pos1 = int(pos0), int(pos1) pos0, pos1 = int(pos0), int(pos1)
queue = await queue_for(ctx, create=False, force_play=False) queue = await mainservice.context(ctx, create=False, force_play=False).queue()
for i in range(pos0, pos1 + 1): for i in range(pos0, pos1 + 1):
if not queue.skip_at(pos0, ctx.member): if not queue.skip_at(pos0, ctx.member):
pos0 += 1 pos0 += 1
@ -78,7 +77,6 @@ async def skip(ctx: Context, args: list[str]) -> None:
raise Explicit('misformatted') raise Explicit('misformatted')
await ctx.reply('done') await ctx.reply('done')
@at('to') @at('to')
async def skip_to(ctx: Context, args: list[str]) -> None: async def skip_to(ctx: Context, args: list[str]) -> None:
await catch( await catch(
@ -95,10 +93,9 @@ async def skip_to(ctx: Context, args: list[str]) -> None:
seconds = int(s) seconds = int(s)
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
queue = await queue_for(ctx, create=False, force_play=False) queue = await mainservice.context(ctx, create=False, force_play=False).queue()
queue.queue[0].set_seconds(seconds) queue.queue[0].set_seconds(seconds)
@at('effects') @at('effects')
async def effects_(ctx: Context, args: list[str]) -> None: async def effects_(ctx: Context, args: list[str]) -> None:
await catch( await catch(
@ -115,13 +112,12 @@ async def effects_(ctx: Context, args: list[str]) -> None:
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await queue_for(ctx, create=False, force_play=False) queue = await mainservice.context(ctx, create=False, force_play=False).queue()
yta = queue.queue[0] yta = queue.queue[0]
seconds = yta.source_seconds() seconds = yta.source_seconds()
yta.options = options_for_effects(effects) yta.options = options_for_effects(effects)
yta.set_seconds(seconds) yta.set_seconds(seconds)
@at('default') @at('default')
async def default(ctx: Context, args: list[str]) -> None: async def default(ctx: Context, args: list[str]) -> None:
await catch( await catch(
@ -140,15 +136,14 @@ async def default(ctx: Context, args: list[str]) -> None:
case ['none']: case ['none']:
effects = None effects = None
case []: case []:
await ctx.reply(f'current default effects: {default_effects(ctx.guild.id)}') await ctx.reply(f'current default effects: {defaulteffects.get(ctx.guild.id)}')
return return
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
assert_admin(ctx.member) assert_admin(ctx.member)
await set_default_effects(ctx.guild.id, effects) await defaulteffects.set(ctx.guild.id, effects)
await ctx.reply(f'effects set to `{effects}`') await ctx.reply(f'effects set to `{effects}`')
@at('repeat') @at('repeat')
async def repeat(ctx: Context, args: list[str]): async def repeat(ctx: Context, args: list[str]):
match args: match args:
@ -159,7 +154,7 @@ async def repeat(ctx: Context, args: list[str]):
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await queue_for(ctx, create=False, force_play=False) queue = await mainservice.context(ctx, create=False, force_play=False).queue()
if not queue.queue: if not queue.queue:
raise Explicit('empty queue') raise Explicit('empty queue')
if n > 99: if n > 99:
@ -168,7 +163,6 @@ async def repeat(ctx: Context, args: list[str]):
for _ in range(n): for _ in range(n):
queue.queue.insert(1, audio.copy()) queue.queue.insert(1, audio.copy())
@at('branch') @at('branch')
async def branch(ctx: Context, args: list[str]): async def branch(ctx: Context, args: list[str]):
match args: match args:
@ -183,7 +177,7 @@ async def branch(ctx: Context, args: list[str]):
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
assert_admin(ctx.member) assert_admin(ctx.member)
queue = await queue_for(ctx, create=False, force_play=False) queue = await mainservice.context(ctx, create=False, force_play=False).queue()
if not queue.queue: if not queue.queue:
raise Explicit('empty queue') raise Explicit('empty queue')
audio = queue.queue[0].branch() audio = queue.queue[0].branch()
@ -195,7 +189,6 @@ async def branch(ctx: Context, args: list[str]):
audio.set_source() audio.set_source()
queue.queue.insert(1, audio) queue.queue.insert(1, audio)
@at('//') @at('//')
@at('queue') @at('queue')
async def queue_(ctx: Context, args: list[str]) -> None: async def queue_(ctx: Context, args: list[str]) -> None:
@ -211,24 +204,27 @@ async def queue_(ctx: Context, args: list[str]) -> None:
match args: match args:
case []: case []:
await ctx.long( await ctx.long(
(await (await queue_for(ctx, create=True, force_play=False)).format()).strip() or 'no queue' (
await (
await mainservice.context(ctx, create=True, force_play=False).queue()
).format()
).strip() or 'no queue'
) )
case ['clear']: case ['clear']:
(await queue_for(ctx, create=False, force_play=False)).clear(ctx.member) (await mainservice.context(ctx, create=False, force_play=False).queue()).clear(ctx.member)
await ctx.reply('done') await ctx.reply('done')
case ['resume']: case ['resume']:
async with lock_for(ctx.guild, 'not in a guild'): async with lock_for(ctx.guild, 'not in a guild'):
await queue_for(ctx, create=True, force_play=True) await mainservice.context(ctx, create=True, force_play=True).vc()
await ctx.reply('done') await ctx.reply('done')
case ['pause']: case ['pause']:
async with lock_for(ctx.guild, 'not in a guild'): async with lock_for(ctx.guild, 'not in a guild'):
vc = await vc_for(ctx, create=True, force_play=False) vc = await mainservice.context(ctx, create=True, force_play=False).vc()
vc.pause() vc.pause()
await ctx.reply('done') await ctx.reply('done')
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
@at('swap') @at('swap')
async def swap(ctx: Context, args: list[str]) -> None: async def swap(ctx: Context, args: list[str]) -> None:
await catch( await catch(
@ -240,11 +236,10 @@ async def swap(ctx: Context, args: list[str]) -> None:
match args: match args:
case [a, b] if a.isdecimal() and b.isdecimal(): case [a, b] if a.isdecimal() and b.isdecimal():
a, b = int(a), int(b) a, b = int(a), int(b)
(await queue_for(ctx, create=False, force_play=False)).swap(ctx.member, a, b) (await mainservice.context(ctx, create=False, force_play=False).queue()).swap(ctx.member, a, b)
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
@at('move') @at('move')
async def move(ctx: Context, args: list[str]) -> None: async def move(ctx: Context, args: list[str]) -> None:
await catch( await catch(
@ -256,11 +251,10 @@ async def move(ctx: Context, args: list[str]) -> None:
match args: match args:
case [a, b] if a.isdecimal() and b.isdecimal(): case [a, b] if a.isdecimal() and b.isdecimal():
a, b = int(a), int(b) a, b = int(a), int(b)
(await queue_for(ctx, create=False, force_play=False)).move(ctx.member, a, b) (await mainservice.context(ctx, create=False, force_play=False).queue()).move(ctx.member, a, b)
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
@at('volume') @at('volume')
async def volume_(ctx: Context, args: list[str]) -> None: async def volume_(ctx: Context, args: list[str]) -> None:
await catch( await catch(
@ -272,18 +266,18 @@ async def volume_(ctx: Context, args: list[str]) -> None:
match args: match args:
case [volume]: case [volume]:
volume = float(volume) volume = float(volume)
await (await main_for(ctx, create=True, force_play=False)).set(volume, ctx.member) await (await mainservice.context(ctx, create=True, force_play=False).main()).set(volume, ctx.member)
case _: case _:
raise Explicit('misformatted') raise Explicit('misformatted')
@at('pause') @at('pause')
async def pause(ctx: Context, _args: list[str]) -> None: async def pause(ctx: Context, _args: list[str]) -> None:
vc = await vc_for(ctx, create=False, force_play=False) vc = await mainservice.context(ctx, create=False, force_play=False).vc()
vc.pause() vc.pause()
@at('resume') @at('resume')
async def resume(ctx: Context, _args: list[str]) -> None: async def resume(ctx: Context, _args: list[str]) -> None:
vc = await vc_for(ctx, create=False, force_play=True) vc = await mainservice.context(ctx, create=False, force_play=True).vc()
vc.resume() vc.resume()
return of

View File

@ -1,44 +0,0 @@
import asyncio
from ptvp35 import Db, KVJson
from v6d2ctx.lock_for import lock_for
from v6d3music.config import myroot
from v6d3music.utils.tor_prefix import tor_prefix
cache_root = myroot / 'cache'
cache_root.mkdir(exist_ok=True)
cache_db = Db(myroot / 'cache.db', kvfactory=KVJson())
async def cache_url(hurl: str, rurl: str, override: bool, tor: bool) -> None:
async with lock_for(('cache', hurl), 'cache failed'):
if not override and cache_db.get(f'url:{hurl}', None) is not None:
return
cachable: bool = cache_db.get(f'cachable:{hurl}', False)
if cachable:
print('caching', hurl)
path = cache_root / f'{hurl}.opus'
tmp_path = cache_root / f'{hurl}.tmp.opus'
args = []
if tor:
args.extend(tor_prefix())
args.extend(
[
'ffmpeg', '-hide_banner', '-loglevel', 'warning',
'-reconnect', '1', '-reconnect_at_eof', '0',
'-reconnect_streamed', '1', '-reconnect_delay_max', '10', '-copy_unknown',
'-y', '-i', rurl, '-b:a', '128k', str(tmp_path)
]
)
ap = await asyncio.create_subprocess_exec(*args)
code = await ap.wait()
if code:
print(f'caching {hurl} failed with {code}')
return
await asyncio.to_thread(tmp_path.rename, path)
await cache_db.set(f'url:{hurl}', str(path))
print('cached', hurl)
# await cache_db.set(f'cachable:{hurl}', False)
else:
await cache_db.set(f'cachable:{hurl}', True)

67
v6d3music/core/caching.py Normal file
View File

@ -0,0 +1,67 @@
import asyncio
from contextlib import AsyncExitStack
from v6d3music.config import myroot
from v6d3music.utils.tor_prefix import tor_prefix
from ptvp35 import *
from v6d2ctx.lock_for import lock_for
__all__ = ('Caching',)
cache_root = myroot / 'cache'
cache_root.mkdir(exist_ok=True)
class Caching:
async def cache_url(self, hurl: str, rurl: str, override: bool, tor: bool) -> None:
async with lock_for(('cache', hurl), 'cache failed'):
if not override and self.__db.get(f'url:{hurl}', None) is not None:
return
cachable: bool = self.__db.get(f'cachable:{hurl}', False)
if cachable:
print('caching', hurl)
path = cache_root / f'{hurl}.opus'
tmp_path = cache_root / f'{hurl}.tmp.opus'
args = []
if tor:
args.extend(tor_prefix())
args.extend(
[
'ffmpeg', '-hide_banner', '-loglevel', 'warning',
'-reconnect', '1', '-reconnect_at_eof', '0',
'-reconnect_streamed', '1', '-reconnect_delay_max', '10', '-copy_unknown',
'-y', '-i', rurl, '-b:a', '128k', str(tmp_path)
]
)
ap = await asyncio.create_subprocess_exec(*args)
code = await ap.wait()
if code:
print(f'caching {hurl} failed with {code}')
return
await asyncio.to_thread(tmp_path.rename, path)
await self.__db.set(f'url:{hurl}', str(path))
print('cached', hurl)
# await cache_db.set(f'cachable:{hurl}', False)
else:
await self.__db.set(f'cachable:{hurl}', True)
def get(self, hurl: str) -> str | None:
return self.__db.get(f'url:{hurl}', None)
async def __aenter__(self) -> 'Caching':
es = AsyncExitStack()
async with es:
self.__db = await es.enter_async_context(DbFactory(myroot / 'cache.db', kvfactory=KVJson()))
self.__tasks = set()
self.__es = es.pop_all()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self.__es:
del self.__es
def schedule_cache(self, hurl: str, rurl: str, override: bool, tor: bool):
task = asyncio.create_task(self.cache_url(hurl, rurl, override, tor))
self.__tasks.add(task)
task.add_done_callback(self.__tasks.discard)

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
from v6d2ctx.context import Context, Explicit, escape from v6d2ctx.context import Context, Explicit, escape
from v6d3music.core.real_url import real_url from v6d3music.core.real_url import real_url
from v6d3music.core.caching import Caching
from v6d3music.core.ytaudio import YTAudio from v6d3music.core.ytaudio import YTAudio
from v6d3music.utils.assert_admin import assert_admin from v6d3music.utils.assert_admin import assert_admin
from v6d3music.utils.options_for_effects import options_for_effects from v6d3music.utils.options_for_effects import options_for_effects
@ -12,7 +13,7 @@ from v6d3music.utils.argctx import InfoCtx
async def create_ytaudio( async def create_ytaudio(
ctx: Context, it: InfoCtx caching: Caching, ctx: Context, it: InfoCtx
) -> YTAudio: ) -> YTAudio:
assert ctx.member is not None assert ctx.member is not None
if it.effects: if it.effects:
@ -24,7 +25,8 @@ async def create_ytaudio(
else: else:
options = None options = None
return YTAudio( return YTAudio(
await real_url(it.info['url'], False, it.tor), caching,
await real_url(caching, it.info['url'], False, it.tor),
it.info['url'], it.info['url'],
f'{escape(it.info.get("title", "unknown"))} `Rby` {ctx.member}', f'{escape(it.info.get("title", "unknown"))} `Rby` {ctx.member}',
options, options,

View File

@ -1,21 +0,0 @@
import asyncio
from typing import AsyncIterable
from v6d2ctx.context import Context
from v6d3music.core.create_ytaudio import create_ytaudio
from v6d3music.core.ytaudio import YTAudio
from v6d3music.utils.argctx import InfoCtx
async def create_ytaudios(ctx: Context, infos: list[InfoCtx]) -> AsyncIterable[YTAudio]:
for audio in await asyncio.gather(
*[
create_ytaudio(ctx, it)
for
it
in
infos
]
):
yield audio

View File

@ -0,0 +1,33 @@
from contextlib import AsyncExitStack
from v6d3music.config import myroot
from v6d3music.utils.presets import allowed_effects
from ptvp35 import *
from v6d2ctx.context import Explicit
__all__ = ('DefaultEffects',)
class DefaultEffects:
def get(self, gid: int) -> str | None:
effects = self.__db.get(gid, None)
if effects in allowed_effects:
return effects
else:
return None
async def set(self, gid: int, effects: str | None) -> None:
if effects is not None and effects not in allowed_effects:
raise Explicit('these effects are not allowed')
await self.__db.set(gid, effects)
async def __aenter__(self) -> 'DefaultEffects':
async with AsyncExitStack() as es:
self.__db = await es.enter_async_context(DbFactory(myroot / 'effects.db', kvfactory=KVJson()))
self.__es = es.pop_all()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self.__es:
del self.__es

View File

@ -1,34 +0,0 @@
from typing import AsyncIterable
from ptvp35 import Db, KVJson
from v6d2ctx.context import Explicit
from v6d3music.utils.argctx import ArgCtx, InfoCtx
from v6d3music.config import myroot
from v6d3music.utils.presets import allowed_effects
__all__ = ('effects_db', 'default_effects', 'set_default_effects', 'entries_effects_for_args',)
effects_db = Db(myroot / 'effects.db', kvfactory=KVJson())
def default_effects(gid: int) -> str | None:
effects = effects_db.get(gid, None)
if effects in allowed_effects:
return effects
else:
return None
async def set_default_effects(gid: int, effects: str | None) -> None:
if effects is not None and effects not in allowed_effects:
raise Explicit('these effects are not allowed')
await effects_db.set(gid, effects)
async def entries_effects_for_args(args: list[str], gid: int) -> AsyncIterable[InfoCtx]:
for ctx in ArgCtx(default_effects(gid), args).sources:
async for it in ctx.entries():
yield it

View File

@ -1,69 +0,0 @@
import discord
from v6d2ctx.context import Context, Explicit
from v6d3music.core.mainaudio import MainAudio
from v6d3music.core.queueaudio import QueueAudio
mainasrcs: dict[discord.Guild, MainAudio] = {}
async def raw_vc_for_member(member: discord.Member) -> discord.VoiceClient:
vc: discord.VoiceProtocol | None = member.guild.voice_client
if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected():
vs: discord.VoiceState | None = member.voice
if vs is None:
raise Explicit('not connected')
vch: discord.VoiceChannel | None = vs.channel # type: ignore
if vch is None:
raise Explicit('not connected')
try:
vc = await vch.connect()
except discord.ClientException:
vc = member.guild.voice_client
assert vc is not None
await member.guild.fetch_channels()
await vc.disconnect(force=True)
raise Explicit('try again later')
assert isinstance(vc, discord.VoiceClient)
return vc
async def raw_vc_for(ctx: Context) -> discord.VoiceClient:
if ctx.member is None:
raise Explicit('not in a guild')
return await raw_vc_for_member(ctx.member)
async def main_for_raw_vc(vc: discord.VoiceClient, *, create: bool, force_play: bool) -> MainAudio:
if vc.guild in mainasrcs:
source = mainasrcs[vc.guild]
else:
if create:
source = mainasrcs.setdefault(
vc.guild,
await MainAudio.create(vc.guild)
)
else:
raise Explicit('not playing, use `queue pause` or `queue resume`')
if vc.source != source or create and not vc.is_playing() and (force_play or not vc.is_paused()):
vc.play(source)
return source
async def vc_main_for(ctx: Context, *, create: bool, force_play: bool) -> tuple[discord.VoiceClient, MainAudio]:
vc = await raw_vc_for(ctx)
return vc, await main_for_raw_vc(vc, create=create, force_play=force_play)
async def vc_for(ctx: Context, *, create: bool, force_play: bool) -> discord.VoiceClient:
vc, source = await vc_main_for(ctx, create=create, force_play=force_play)
return vc
async def main_for(ctx: Context, *, create: bool, force_play: bool) -> MainAudio:
vc, source = await vc_main_for(ctx, create=create, force_play=force_play)
return source
async def queue_for(ctx: Context, *, create: bool, force_play: bool) -> QueueAudio:
return (await main_for(ctx, create=create, force_play=force_play)).queue

View File

@ -1,16 +1,15 @@
import discord import discord
from ptvp35 import Db, KVJson from v6d3music.core.caching import Caching
from v6d2ctx.context import Explicit
from v6d3music.config import myroot
from v6d3music.core.queueaudio import QueueAudio from v6d3music.core.queueaudio import QueueAudio
from v6d3music.utils.assert_admin import assert_admin from v6d3music.utils.assert_admin import assert_admin
volume_db = Db(myroot / 'volume.db', kvfactory=KVJson()) from ptvp35 import *
from v6d2ctx.context import Explicit
class MainAudio(discord.PCMVolumeTransformer): class MainAudio(discord.PCMVolumeTransformer):
def __init__(self, queue: QueueAudio, volume: float): def __init__(self, db: DbConnection, queue: QueueAudio, volume: float):
self.db = db
self.queue = queue self.queue = queue
super().__init__(self.queue, volume=volume) super().__init__(self.queue, volume=volume)
@ -21,8 +20,8 @@ class MainAudio(discord.PCMVolumeTransformer):
if volume > 1: if volume > 1:
raise Explicit('volume too big') raise Explicit('volume too big')
self.volume = volume self.volume = volume
await volume_db.set(member.guild.id, volume) await self.db.set(member.guild.id, volume)
@classmethod @classmethod
async def create(cls, guild: discord.Guild) -> 'MainAudio': async def create(cls, caching: Caching, db: DbConnection, queues: DbConnection, guild: discord.Guild) -> 'MainAudio':
return cls(await QueueAudio.create(guild), volume=volume_db.get(guild.id, 0.2)) return cls(db, await QueueAudio.create(caching, queues, guild), volume=db.get(guild.id, 0.2))

View File

@ -0,0 +1,215 @@
import asyncio
import traceback
from contextlib import AsyncExitStack
import discord
from v6d3music.config import myroot
from v6d3music.core.caching import Caching
from v6d3music.core.mainaudio import MainAudio
from v6d3music.core.queueaudio import QueueAudio
from ptvp35 import *
from v6d2ctx.context import Context, Explicit
from v6d2ctx.lock_for import lock_for
class MainService:
def __init__(self, client: discord.Client) -> None:
self.client = client
self.mains: dict[discord.Guild, MainAudio] = {}
@staticmethod
async def raw_vc_for_member(member: discord.Member) -> discord.VoiceClient:
vc: discord.VoiceProtocol | None = member.guild.voice_client
if vc is None or isinstance(vc, discord.VoiceClient) and not vc.is_connected():
vs: discord.VoiceState | None = member.voice
if vs is None:
raise Explicit('not connected')
vch: discord.abc.Connectable | None = vs.channel
if vch is None:
raise Explicit('not connected')
try:
vc = await vch.connect()
except discord.ClientException:
vc = member.guild.voice_client
assert vc is not None
await member.guild.fetch_channels()
await vc.disconnect(force=True)
raise Explicit('try again later')
assert isinstance(vc, discord.VoiceClient)
return vc
async def raw_vc_for(self, ctx: Context) -> discord.VoiceClient:
if ctx.member is None:
raise Explicit('not in a guild')
return await self.raw_vc_for_member(ctx.member)
def descriptor(self, *, create: bool, force_play: bool) -> 'MainDescriptor':
return MainDescriptor(self, create=create, force_play=force_play)
def context(self, ctx: Context, *, create: bool, force_play: bool) -> 'MainContext':
return self.descriptor(create=create, force_play=force_play).context(ctx)
async def create(self, guild: discord.Guild) -> MainAudio:
return await MainAudio.create(self.caching, self.__volumes, self.queues, guild)
async def __aenter__(self) -> 'MainService':
async with AsyncExitStack() as es:
self.__volumes = await es.enter_async_context(DbFactory(myroot / 'volume.db', kvfactory=KVJson()))
self.queues = await es.enter_async_context(DbFactory(myroot / 'queue.db', kvfactory=KVJson()))
self.caching = await es.enter_async_context(Caching())
self.__vcs_restored: asyncio.Future[None] = asyncio.Future()
self.__es = es.pop_all()
self.__save_task = asyncio.create_task(self.save_daemon())
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self.__es:
await self.final_save()
del self.__es
async def save_queues(self, delay: bool) -> None:
for mainasrc in list(self.mains.values()):
if delay:
await asyncio.sleep(0.01)
await mainasrc.queue.save(delay)
async def save_vcs(self, delay: bool) -> None:
vcs = []
vc: discord.VoiceClient
for vc in (vcc for vcc in self.client.voice_clients if isinstance(vcc, discord.VoiceClient)):
if delay:
await asyncio.sleep(0.01)
if vc.is_playing():
if vc.guild is not None and vc.channel is not None:
vcs.append((vc.guild.id, vc.channel.id, vc.is_paused()))
self.queues.set_nowait('vcs', vcs)
async def save_commit(self) -> None:
await self.queues.commit()
async def _save_all(self, delay: bool, save_playing: bool) -> None:
await self.save_queues(delay)
if save_playing:
await self.save_vcs(delay)
await self.save_commit()
async def save_all(self, delay: bool, save_playing: bool) -> None:
await self._save_all(delay, save_playing)
async def save_job(self):
await self.__vcs_restored
print('starting saving')
while True:
await asyncio.sleep(1)
await self.save_all(True, not self.client.is_closed())
async def save_daemon(self):
try:
await self.save_job()
except asyncio.CancelledError:
pass
async def final_save(self):
self.__save_task.cancel()
if not self.__vcs_restored.done():
self.__vcs_restored.cancel()
else:
try:
await self.save_all(False, False)
print('saved')
except Exception:
traceback.print_exc()
async def _restore_vc(self, guild: discord.Guild, vccid: int, vc_is_paused: bool) -> None:
channels = await guild.fetch_channels()
channel: discord.VoiceChannel
channel, = [
ch for ch in
(
chc for chc in channels
if
isinstance(chc, discord.VoiceChannel)
)
if ch.id == vccid
]
vp: discord.VoiceProtocol = await channel.connect()
assert isinstance(vp, discord.VoiceClient)
vc = vp
await self.descriptor(create=True, force_play=True).main_for_raw_vc(vc)
if vc_is_paused:
vc.pause()
async def restore_vc(self, vcgid: int, vccid: int, vc_is_paused: bool) -> None:
try:
print(f'vc restoring {vcgid}')
guild: discord.Guild = await self.client.fetch_guild(vcgid)
async with lock_for(guild, 'not in a guild'):
await self._restore_vc(guild, vccid, vc_is_paused)
except Exception as e:
print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e)
else:
print(f'vc restored {vcgid} {vccid}')
async def restore_vcs(self) -> None:
vcs: list[tuple[int, int, bool]] = self.queues.get('vcs', [])
try:
tasks = []
for vcgid, vccid, vc_is_paused in vcs:
tasks.append(asyncio.create_task(self.restore_vc(vcgid, vccid, vc_is_paused)))
for task in tasks:
await task
finally:
self.__vcs_restored.set_result(None)
async def restore(self) -> None:
async with lock_for('vcs_restored', '...'):
if not self.__vcs_restored.done():
await self.restore_vcs()
class MainDescriptor:
def __init__(self, service: MainService, *, create: bool, force_play: bool) -> None:
self.mainservice = service
self.mains = service.mains
self.create = create
self.force_play = force_play
async def main_for_raw_vc(self, vc: discord.VoiceClient) -> MainAudio:
if vc.guild in self.mains:
source = self.mains[vc.guild]
elif self.create:
source = self.mains.setdefault(
vc.guild,
await self.mainservice.create(vc.guild)
)
else:
raise Explicit('not playing, use `queue pause` or `queue resume`')
if vc.source != source or self.create and not vc.is_playing() and (self.force_play or not vc.is_paused()):
vc.play(source)
return source
def context(self, ctx: Context) -> 'MainContext':
return MainContext(self, ctx)
class MainContext:
def __init__(self, descriptor: MainDescriptor, ctx: Context) -> None:
self.mainservice = descriptor.mainservice
self.descriptor = descriptor
self.ctx = ctx
async def vc_main(self) -> tuple[discord.VoiceClient, MainAudio]:
vc = await self.mainservice.raw_vc_for(self.ctx)
return vc, await self.descriptor.main_for_raw_vc(vc)
async def vc(self) -> discord.VoiceClient:
vc, _ = await self.vc_main()
return vc
async def main(self) -> MainAudio:
_, source = await self.vc_main()
return source
async def queue(self) -> QueueAudio:
return (await self.main()).queue

View File

@ -3,19 +3,19 @@ from collections import deque
from io import StringIO from io import StringIO
import discord import discord
from ptvp35 import Db, KVJson from v6d3music.core.caching import Caching
from v6d3music.config import myroot
from v6d3music.core.ytaudio import YTAudio from v6d3music.core.ytaudio import YTAudio
from v6d3music.utils.assert_admin import assert_admin from v6d3music.utils.assert_admin import assert_admin
from v6d3music.utils.fill import FILL from v6d3music.utils.fill import FILL
queue_db = Db(myroot / 'queue.db', kvfactory=KVJson()) from ptvp35 import *
PRE_SET_LENGTH = 24 PRE_SET_LENGTH = 24
class QueueAudio(discord.AudioSource): class QueueAudio(discord.AudioSource):
def __init__(self, guild: discord.Guild, respawned: list[YTAudio]): def __init__(self, db: DbConnection, guild: discord.Guild, respawned: list[YTAudio]):
self.db = db
self.queue: deque[YTAudio] = deque() self.queue: deque[YTAudio] = deque()
for audio in respawned: for audio in respawned:
self.append(audio) self.append(audio)
@ -30,12 +30,12 @@ class QueueAudio(discord.AudioSource):
return return
@staticmethod @staticmethod
async def respawned(guild: discord.Guild) -> list[YTAudio]: async def respawned(caching: Caching, db: DbConnection, guild: discord.Guild) -> list[YTAudio]:
respawned = [] respawned = []
try: try:
for audio_respawn in queue_db.get(guild.id, []): for audio_respawn in db.get(guild.id, []):
try: try:
respawned.append(await YTAudio.respawn(guild, audio_respawn)) respawned.append(await YTAudio.respawn(caching, guild, audio_respawn))
except Exception as e: except Exception as e:
print('audio respawn failed', e) print('audio respawn failed', e)
raise raise
@ -44,16 +44,16 @@ class QueueAudio(discord.AudioSource):
return respawned return respawned
@classmethod @classmethod
async def create(cls, guild: discord.Guild): async def create(cls, caching: Caching, db: DbConnection, guild: discord.Guild) -> 'QueueAudio':
return cls(guild, await cls.respawned(guild)) return cls(db, guild, await cls.respawned(caching, db, guild))
async def save(self, delay: bool): async def save(self, delay: bool) -> None:
hybernated = [] hybernated = []
for audio in list(self.queue): for audio in list(self.queue):
if delay: if delay:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
hybernated.append(audio.hybernate()) hybernated.append(audio.hybernate())
queue_db.set_nowait(self.guild.id, hybernated) self.db.set_nowait(self.guild.id, hybernated)
def append(self, audio: YTAudio): def append(self, audio: YTAudio):
if len(self.queue) < PRE_SET_LENGTH: if len(self.queue) < PRE_SET_LENGTH:
@ -139,6 +139,7 @@ class QueueAudio(discord.AudioSource):
except ValueError: except ValueError:
pass pass
async def pubjson(self, member: discord.Member) -> list: async def pubjson(self, member: discord.Member, limit: int) -> list:
import random
audios = list(self.queue) audios = list(self.queue)
return [await audio.pubjson(member) for audio in audios] return [await audio.pubjson(member) for audio, _ in zip(audios, range(limit))]

View File

@ -1,27 +1,17 @@
import asyncio import asyncio
import os import os
from typing import Optional
from adaas.cachedb import RemoteCache
from v6d3music.core.cache_url import cache_db, cache_url from v6d3music.core.caching import Caching
from v6d3music.utils.bytes_hash import bytes_hash from v6d3music.utils.bytes_hash import bytes_hash
from v6d3music.utils.tor_prefix import tor_prefix from v6d3music.utils.tor_prefix import tor_prefix
from adaas.cachedb import RemoteCache
adaas_available = bool(os.getenv('adaasurl')) adaas_available = bool(os.getenv('adaasurl'))
if adaas_available: if adaas_available:
print('running real_url through adaas') print('running real_url through adaas')
_tasks = set()
def _schedule_cache(hurl: str, rurl: str, override: bool, tor: bool):
task = asyncio.create_task(cache_url(hurl, rurl, override, tor))
_tasks.add(task)
task.add_done_callback(_tasks.discard)
async def _resolve_url(url: str, tor: bool) -> str: async def _resolve_url(url: str, tor: bool) -> str:
args = [] args = []
if tor: if tor:
@ -39,15 +29,15 @@ async def _resolve_url(url: str, tor: bool) -> str:
return (await ap.stdout.readline()).decode()[:-1] return (await ap.stdout.readline()).decode()[:-1]
async def real_url(url: str, override: bool, tor: bool) -> str: async def real_url(caching: Caching, url: str, override: bool, tor: bool) -> str:
if adaas_available and not tor: if adaas_available and not tor:
return await RemoteCache().real_url(url, override, tor) return await RemoteCache().real_url(url, override, tor)
hurl: str = bytes_hash(url.encode()) hurl: str = bytes_hash(url.encode())
if not override: if not override:
curl: Optional[str] = cache_db.get(f'url:{hurl}', None) curl: str | None = caching.get(hurl)
if curl is not None: if curl is not None:
print('using cached', hurl) print('using cached', hurl)
return curl return curl
rurl: str = await _resolve_url(url, tor) rurl: str = await _resolve_url(url, tor)
_schedule_cache(hurl, rurl, override, tor) caching.schedule_cache(hurl, rurl, override, tor)
return rurl return rurl

View File

@ -14,7 +14,8 @@ __all__ = ('YState',)
class YState: class YState:
def __init__(self, pool: Pool, ctx: Context, sources: Iterable[UrlCtx]) -> None: def __init__(self, caching: Caching, pool: Pool, ctx: Context, sources: Iterable[UrlCtx]) -> None:
self.caching = caching
self.pool = pool self.pool = pool
self.ctx = ctx self.ctx = ctx
self.sources: deque[UrlCtx] = deque(sources) self.sources: deque[UrlCtx] = deque(sources)
@ -52,7 +53,7 @@ class YState:
async def result(self, entry: InfoCtx) -> YTAudio | None: async def result(self, entry: InfoCtx) -> YTAudio | None:
try: try:
return await create_ytaudio(self.ctx, entry) return await create_ytaudio(self.caching, self.ctx, entry)
except Exception: except Exception:
if not entry.ignore: if not entry.ignore:
raise raise

View File

@ -1,17 +1,18 @@
from typing import AsyncIterable from typing import AsyncIterable
from v6d3music.core.entries_effects_for_args import * from v6d3music.core.default_effects import *
from v6d3music.core.ystate import * from v6d3music.core.ystate import *
from v6d3music.core.ytaudio import YTAudio from v6d3music.core.ytaudio import *
from v6d3music.core.caching import Caching
from v6d3music.processing.pool import * from v6d3music.processing.pool import *
from v6d3music.utils.argctx import * from v6d3music.utils.argctx import *
from v6d2ctx.context import Context from v6d2ctx.context import Context
async def yt_audios(ctx: Context, args: list[str]) -> AsyncIterable[YTAudio]: async def yt_audios(caching: Caching, defaulteffects: DefaultEffects, ctx: Context, args: list[str]) -> AsyncIterable[YTAudio]:
assert ctx.guild is not None assert ctx.guild is not None
argctx = ArgCtx(default_effects(ctx.guild.id), args) argctx = ArgCtx(defaulteffects.get(ctx.guild.id), args)
async with Pool(5) as pool: async with Pool(5) as pool:
async for audio in YState(pool, ctx, argctx.sources).iterate(): async for audio in YState(caching, pool, ctx, argctx.sources).iterate():
yield audio yield audio

View File

@ -5,6 +5,7 @@ from typing import Optional
import discord import discord
from v6d3music.core.ffmpegnormalaudio import FFmpegNormalAudio from v6d3music.core.ffmpegnormalaudio import FFmpegNormalAudio
from v6d3music.core.real_url import real_url from v6d3music.core.real_url import real_url
from v6d3music.core.caching import Caching
from v6d3music.utils.fill import FILL from v6d3music.utils.fill import FILL
from v6d3music.utils.sparq import sparq from v6d3music.utils.sparq import sparq
from v6d3music.utils.tor_prefix import tor_prefix from v6d3music.utils.tor_prefix import tor_prefix
@ -22,6 +23,7 @@ class YTAudio(discord.AudioSource):
def __init__( def __init__(
self, self,
caching: Caching,
url: str, url: str,
origin: str, origin: str,
description: str, description: str,
@ -33,6 +35,7 @@ class YTAudio(discord.AudioSource):
*, *,
stop_at: int | None = None stop_at: int | None = None
): ):
self.caching = caching
self.url = url self.url = url
self.origin = origin self.origin = origin
self.description = description self.description = description
@ -181,7 +184,7 @@ class YTAudio(discord.AudioSource):
} }
@classmethod @classmethod
async def respawn(cls, guild: discord.Guild, respawn: dict) -> 'YTAudio': async def respawn(cls, caching: Caching, guild: discord.Guild, respawn: dict) -> 'YTAudio':
member_id: int | None = respawn['rby'] member_id: int | None = respawn['rby']
if member_id is None: if member_id is None:
member = None member = None
@ -194,6 +197,7 @@ class YTAudio(discord.AudioSource):
except discord.NotFound: except discord.NotFound:
member = None member = None
audio = YTAudio( audio = YTAudio(
caching,
respawn['url'], respawn['url'],
respawn['origin'], respawn['origin'],
respawn['description'], respawn['description'],
@ -209,7 +213,7 @@ class YTAudio(discord.AudioSource):
async def regenerate(self): async def regenerate(self):
try: try:
print(f'regenerating {self.origin}') print(f'regenerating {self.origin}')
self.url = await real_url(self.origin, True, self.tor) self.url = await real_url(self.caching, self.origin, True, self.tor)
if hasattr(self, 'source'): if hasattr(self, 'source'):
self.source.cleanup() self.source.cleanup()
self.set_source() self.set_source()
@ -228,6 +232,7 @@ class YTAudio(discord.AudioSource):
def copy(self) -> 'YTAudio': def copy(self) -> 'YTAudio':
return YTAudio( return YTAudio(
self.caching,
self.url, self.url,
self.origin, self.origin,
self.description, self.description,
@ -242,6 +247,7 @@ class YTAudio(discord.AudioSource):
raise Explicit('already branched') raise Explicit('already branched')
self.stop_at = stop_at = self.already_read + 50 self.stop_at = stop_at = self.already_read + 50
audio = YTAudio( audio = YTAudio(
self.caching,
self.url, self.url,
self.origin, self.origin,
self.description, self.description,

View File

@ -3,22 +3,19 @@ import contextlib
import os import os
import sys import sys
import time import time
import traceback
import discord import discord
from v6d3music.app import MusicAppFactory, session_db from v6d3music.app import AppContext
from v6d3music.commands import of from v6d3music.commands import get_of
from v6d3music.config import prefix from v6d3music.config import prefix
from v6d3music.core.cache_url import cache_db from v6d3music.core.caching import *
from v6d3music.core.entries_effects_for_args import effects_db from v6d3music.core.default_effects import *
from v6d3music.core.mainasrc import main_for_raw_vc, mainasrcs from v6d3music.core.mainservice import MainService
from v6d3music.core.mainaudio import volume_db
from v6d3music.core.queueaudio import queue_db
from ptvp35 import *
from rainbowadn.instrument import Instrumentation from rainbowadn.instrument import Instrumentation
from v6d1tokens.client import request_token from v6d1tokens.client import request_token
from v6d2ctx.handle_content import handle_content from v6d2ctx.handle_content import handle_content
from v6d2ctx.lock_for import lock_for
from v6d2ctx.pain import ABlockMonitor from v6d2ctx.pain import ABlockMonitor
from v6d2ctx.serve import serve from v6d2ctx.serve import serve
@ -27,14 +24,7 @@ asyncio.set_event_loop(loop)
class MusicClient(discord.Client): class MusicClient(discord.Client):
async def close(self) -> None: pass
save_task.cancel()
await super().close()
try:
await save_all(False, False)
print('saved')
except Exception:
traceback.print_exc()
client = MusicClient( client = MusicClient(
@ -49,69 +39,9 @@ client = MusicClient(
reactions=True, reactions=True,
message_content=True, message_content=True,
), ),
loop=loop loop=loop,
) )
vcs_restored = False
async def _restore_vc(guild: discord.Guild, vccid: int, vc_is_paused: bool) -> None:
channels = await guild.fetch_channels()
channel: discord.VoiceChannel
channel, = [
ch for ch in
(
chc for chc in channels
if
isinstance(chc, discord.VoiceChannel)
)
if ch.id == vccid
]
vp: discord.VoiceProtocol = await channel.connect()
assert isinstance(vp, discord.VoiceClient)
vc = vp
await main_for_raw_vc(vc, create=True, force_play=True)
if vc_is_paused:
vc.pause()
async def restore_vc(vcgid: int, vccid: int, vc_is_paused: bool) -> None:
try:
print(f'vc restoring {vcgid}')
guild: discord.Guild = await client.fetch_guild(vcgid)
async with lock_for(guild, 'not in a guild'):
await _restore_vc(guild, vccid, vc_is_paused)
except Exception as e:
print(f'vc {vcgid} {vccid} {vc_is_paused} failed', e)
else:
print(f'vc restored {vcgid} {vccid}')
async def restore_vcs():
global vcs_restored
vcs: list[tuple[int, int, bool]] = queue_db.get('vcs', [])
try:
tasks = []
for vcgid, vccid, vc_is_paused in vcs:
tasks.append(asyncio.create_task(restore_vc(vcgid, vccid, vc_is_paused)))
for task in tasks:
await task
finally:
vcs_restored = True
@client.event
async def on_ready():
print('ready')
await client.change_presence(
activity=discord.Game(
name='феноменально',
)
)
async with lock_for('vcs_restored', '...'):
if not vcs_restored:
await restore_vcs()
banned_guilds = set(map(int, map(str.strip, os.getenv('banned_guilds', '').split(':')))) banned_guilds = set(map(int, map(str.strip, os.getenv('banned_guilds', '').split(':'))))
@ -124,64 +54,23 @@ def message_allowed(message: discord.Message) -> bool:
return guild_allowed(message.guild) return guild_allowed(message.guild)
def register_handlers(mainservice: MainService, defaulteffects: DefaultEffects):
of = get_of(mainservice, defaulteffects)
@client.event @client.event
async def on_message(message: discord.Message) -> None: async def on_message(message: discord.Message) -> None:
if message_allowed(message): if message_allowed(message):
await handle_content(of, message, message.content, prefix, client) await handle_content(of, message, message.content, prefix, client)
@client.event
async def save_queues(delay: bool): async def on_ready():
for mainasrc in list(mainasrcs.values()): print('ready')
if delay: await client.change_presence(
await asyncio.sleep(0.01) activity=discord.Game(
await mainasrc.queue.save(delay) name='феноменально',
)
)
async def save_vcs(delay: bool): await mainservice.restore()
if vcs_restored:
vcs = []
vc: discord.VoiceClient
for vc in (vcc for vcc in client.voice_clients if isinstance(vcc, discord.VoiceClient)):
if delay:
await asyncio.sleep(0.01)
if vc.is_playing():
if vc.guild is not None and vc.channel is not None:
vcs.append((vc.guild.id, vc.channel.id, vc.is_paused()))
queue_db.set_nowait('vcs', vcs)
async def save_commit():
await queue_db.commit()
async def save_all(delay: bool, save_playing: bool):
await save_queues(delay)
if save_playing:
await save_vcs(delay)
await save_commit()
async def save_job():
while True:
await asyncio.sleep(1)
await save_all(True, True)
async def save_daemon():
try:
await save_job()
except asyncio.CancelledError:
pass
async def start_app():
await MusicAppFactory.start(client)
async def setup_tasks():
global save_task
save_task = loop.create_task(save_daemon())
loop.create_task(start_app())
class UpgradeABMInit(Instrumentation): class UpgradeABMInit(Instrumentation):
@ -223,8 +112,37 @@ def _upgrade_abm() -> contextlib.ExitStack:
raise RuntimeError raise RuntimeError
class PathPrint(Instrumentation):
def __init__(self, methodname: str, pref: str):
super().__init__(DbConnection, methodname)
self.pref = pref
async def instrument(self, method, db: DbConnection, *args, **kwargs):
result = await method(db, *args, **kwargs)
try:
print(self.pref, db._DbConnection__path) # type: ignore
except Exception:
from traceback import print_exc
print_exc()
return result
def _db_ee() -> contextlib.ExitStack:
with contextlib.ExitStack() as es:
es.enter_context(PathPrint('_initialize', 'open :'))
es.enter_context(PathPrint('aclose', 'close:'))
return es.pop_all()
raise RuntimeError
async def main(): async def main():
async with volume_db, queue_db, cache_db, session_db, effects_db, ABlockMonitor(delta=0.5): async with (
DefaultEffects() as defaulteffects,
MainService(client) as mainservice,
AppContext(mainservice),
ABlockMonitor(delta=0.5)
):
register_handlers(mainservice, defaulteffects)
if 'guerilla' in sys.argv: if 'guerilla' in sys.argv:
from pathlib import Path from pathlib import Path
tokenpath = Path('.token.txt') tokenpath = Path('.token.txt')
@ -238,7 +156,6 @@ async def main():
else: else:
token = await request_token('music', 'token') token = await request_token('music', 'token')
await client.login(token) await client.login(token)
loop.create_task(setup_tasks())
if os.getenv('v6tor', None) is None: if os.getenv('v6tor', None) is None:
print('no tor') print('no tor')
await client.connect() await client.connect()
@ -246,5 +163,5 @@ async def main():
if __name__ == '__main__': if __name__ == '__main__':
with _upgrade_abm(): with _upgrade_abm(), _db_ee():
serve(main(), client, loop) serve(main(), client, loop)

View File

@ -1,24 +1,3 @@
__all__ = ('tor_prefix',) __all__ = ('tor_prefix',)
import os from adaas.tor_prefix import tor_prefix
if (address := os.getenv('v6tor', None)) is not None:
print('tor through torsocks')
try:
import socket
address = socket.gethostbyname_ex(address)[2][0]
except Exception:
print('failed tor resolution')
_tor_prefix = None
else:
print('tor successfully resolved')
_tor_prefix = ['torsocks', '--address', address]
else:
print('tor unavailable')
_tor_prefix = None
def tor_prefix():
if _tor_prefix is None:
raise ValueError('tor unavailable')
return _tor_prefix