# v6d3music/v6d3music/utils/argctx.py
# 2023-12-27 04:34:30 +00:00

# 171 lines
# 6.2 KiB
# Python

import string
from typing import Any, AsyncIterable
from v6d2ctx.context import Context, Explicit, escape
from v6d3music.config import prefix
from v6d3music.utils.assert_admin import assert_admin
from v6d3music.utils.effects_for_preset import effects_for_preset
from v6d3music.utils.entries_for_url import entries_for_url
from v6d3music.utils.options_for_effects import options_for_effects
from v6d3music.utils.presets import allowed_effects, presets
from v6d3music.utils.sparq import sparq
# Public API of this module (PostCtx is intentionally not re-exported).
__all__ = ("InfoCtx", "BoundCtx", "UrlCtx", "ArgCtx")
class PostCtx:
def __init__(self, effects: str | None) -> None:
self.effects: str | None = effects
self.already_read: int = 0
self.ignore: bool = False
class InfoCtx:
    """One media entry (``info``) paired with the `PostCtx` of its source.

    ``effects``, ``already_read`` and ``ignore`` are copied from ``post`` when
    the instance is created, so later changes to ``post`` do not show up in
    them (the ``post`` object itself is kept as ``self.post``).
    """

    def __init__(self, info: dict[str, Any], post: PostCtx) -> None:
        self.info, self.post = info, post
        # snapshot of the group-level settings at construction time
        self.effects = post.effects
        self.already_read = post.already_read
        self.ignore = post.ignore

    def bind(self, ctx: Context) -> "BoundCtx":
        """Wrap this entry together with ``ctx`` into a new `BoundCtx`."""
        return BoundCtx(self, ctx)
class BoundCtx:
    """An `InfoCtx` bound to a command `Context`, validated for playback.

    Construction raises `Explicit` in these cases:

    * ``ctx.member`` is ``None`` ("not in a guild");
    * the entry URL contains ``"yandex"`` and the context is neither the
      allow-listed guild nor the allow-listed author;
    * the effects string fails the checks in `_options`.
    """

    # URLs containing "yandex" are accepted only in this guild or from this user.
    _YANDEX_GUILD_ID = 541241763042689025
    _YANDEX_USER_ID = 264054779888533515
    # Characters permitted in an effects string that is not in `allowed_effects`.
    _EFFECT_CHARS = frozenset(string.ascii_letters + string.digits + "*,=+-/()|.^:_")

    def __init__(self, it: InfoCtx, ctx: Context, /) -> None:
        if ctx.member is None:
            raise Explicit("not in a guild")
        self.member = ctx.member
        self.ctx = ctx
        self.url = it.info["url"]
        if "yandex" in self.url and (
            not ctx.guild
            or (
                ctx.guild.id != self._YANDEX_GUILD_ID
                and ctx.author.id != self._YANDEX_USER_ID
            )
        ):
            raise Explicit("yandex is not allowed")
        # escaped entry title (or "unknown") followed by the context's member
        self.description = (
            f'{escape(it.info.get("title", "unknown"))} `Rby` {ctx.member}'
        )
        self.effects = it.effects
        self.already_read = it.already_read
        self.options = self._options()

    def _options(self) -> str | None:
        """Return ``options_for_effects(self.effects)``, or ``None`` if effects is empty.

        Effects strings outside `allowed_effects` are first passed through
        `assert_admin` for the member. They must also consist only of
        `_EFFECT_CHARS`, otherwise `Explicit` ("malformed effects") is raised.
        """
        if not self.effects:
            return None
        if self.effects not in allowed_effects:
            # self.member is known to be non-None: __init__ rejects a None member.
            assert_admin(self.member)
            if not set(self.effects) <= self._EFFECT_CHARS:
                raise Explicit("malformed effects")
        return options_for_effects(self.effects)
class UrlCtx:
    """One URL from the command arguments together with its group's `PostCtx`.

    ``effects``, ``already_read`` and ``ignore`` are copied from ``post`` at
    construction time.
    """

    def __init__(self, url: str, post: PostCtx) -> None:
        self.url, self.post = url, post
        self.effects: str | None = post.effects
        self.already_read, self.ignore = post.already_read, post.ignore

    async def entries(self) -> AsyncIterable[InfoCtx]:
        """Yield an `InfoCtx` for each entry `entries_for_url` produces for ``url``.

        Any ``Exception`` raised during iteration is re-raised unless
        ``self.ignore`` is set; in that case iteration simply ends. Because the
        ``try`` also encloses the ``yield``, this includes exceptions thrown
        into the generator.
        """
        try:
            async for info in entries_for_url(self.url):
                yield InfoCtx(info, self.post)
        except Exception:
            if self.ignore:
                return
            raise
class ArgCtx:
def __init__(self, default_effects: str | None, args: list[str]) -> None:
self.sources: list[UrlCtx] = []
while args:
match args:
case ["[[", *args]:
try:
close_ix = args.index("]]")
except ValueError:
raise Explicit("expected closing `]]`, not found")
urls = args[:close_ix]
assert isinstance(args, list)
args = args[close_ix + 1 :]
case ["]]", *args]:
raise Explicit("unexpected `]]`")
case [_url, *args]:
urls = [_url]
case _:
raise RuntimeError
for url in urls:
if url in presets:
raise Explicit(
"expected url, got preset. maybe you are missing `+`?"
)
if url in {"+", "-"}:
raise Explicit(
"expected url, got `+` or `-`. maybe you tried to use multiple effects?"
)
if url.startswith("+") or url.startswith('-"') or url.startswith("-'"):
raise Explicit(
"expected url, got `+` or `-\"` or `-'`. maybe you forgot to separate control symbol from the effects?"
)
if "youtu" in url and "watch" in url and "list" in url:
raise Explicit(
"bot cannot decide how to handle an URL with both `watch` and `list` in it.\n"
"instead, use either `youtube.com/watch` URL without `&list=` part or `youtube.com/playlist` URL."
)
if url == "skip":
raise Explicit(f"to skip, use `{prefix}skip`")
match args:
case ["-", effects, *args]:
pass
case ["+", preset, *args]:
effects = effects_for_preset(preset)
case [*args]:
effects = default_effects
case _:
raise RuntimeError
post = PostCtx(effects)
match args:
case [
h,
m,
s,
*args,
] if h.isdecimal() and m.isdecimal() and s.isdecimal():
seconds = 3600 * int(h) + 60 * int(m) + int(s)
case [m, s, *args] if m.isdecimal() and s.isdecimal():
seconds = 60 * int(m) + int(s)
case [s, *args] if s.isdecimal():
seconds = int(s)
case [*args]:
seconds = 0
post.already_read = round(seconds / sparq(options_for_effects(effects)))
while True:
match args:
case ["tor", *args]:
raise Explicit("tor support is temporarily suspended")
case ["ignore", *args]:
if post.ignore:
raise Explicit("duplicate ignore")
post.ignore = True
case [*args]:
break
for url in urls:
if url.startswith("<") and url.endswith(">"):
url = url[1:-1]
self.sources.append(UrlCtx(url, post))