diff --git a/starbot/Dockerfile b/starbot/Dockerfile index e2baad0..0e01a7d 100644 --- a/starbot/Dockerfile +++ b/starbot/Dockerfile @@ -2,13 +2,14 @@ FROM python:3.11 as base WORKDIR /app/ COPY requirements.txt requirements.txt RUN pip install -r requirements.txt -COPY starbot starbot CMD ["python3", "-m", "starbot"] FROM base as ptvp35 RUN pip install git+https://gitea.parrrate.ru/PTV/ptvp35.git@f8ee5d20f4e159df2e20c40dbf3b81e925c2db36 ENV DBF_MODULE=starbot.db_ptvp35 +COPY starbot starbot FROM base as sqlite RUN pip install aiosqlite~=0.19 ENV DBF_MODULE=starbot.db_aiosqlite +COPY starbot starbot diff --git a/starbot/starbot/bot.py b/starbot/starbot/bot.py index b90e2ca..12872fe 100644 --- a/starbot/starbot/bot.py +++ b/starbot/starbot/bot.py @@ -1,17 +1,31 @@ from __future__ import annotations -from contextlib import AsyncExitStack, asynccontextmanager +from collections.abc import Hashable +from contextlib import asynccontextmanager from pathlib import Path +from typing import Any, Callable, TypeVar +from weakref import WeakValueDictionary import discord from discord.ext import commands from .db import AbstractConnection, AbstractDbFactory +T = TypeVar("T") + + +class Temporaries: + def __init__(self) -> None: + self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary() + + def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any: + return self.__values.get(key) or self.__values.setdefault(key, factory()) + class StarState: def __init__(self, connection: AbstractConnection) -> None: self.connection = connection + self.temporaries = Temporaries() @asynccontextmanager diff --git a/starbot/starbot/stars.py b/starbot/starbot/stars.py index 84f5392..c013b63 100644 --- a/starbot/starbot/stars.py +++ b/starbot/starbot/stars.py @@ -1,31 +1,38 @@ from __future__ import annotations import asyncio -from collections.abc import Hashable +from typing import Callable, Generic, Hashable, Type, TypeVar import discord from discord.ext import commands -from .bot import StarBot, StarState +from .bot import StarBot, StarState, Temporaries from .db import AbstractConnection +T = TypeVar("T") + + +class TypedTemporaries(Generic[T]): + def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None: + self.temporaries = temporaries + self.type = type_ + + def get(self, key: Hashable, factory: Callable[[], T]) -> T: + value = self.temporaries.get((self.type, key), factory) + if not isinstance(value, self.type): + raise TypeError(self.type, value) + return value + class Locks: - def __init__(self) -> None: - self.locks: dict[Hashable, asyncio.Lock] = {} + def __init__(self, state: StarState) -> None: + self.typed = TypedTemporaries(state.temporaries, asyncio.Lock) - def lock_for(self, key: Hashable) -> asyncio.Lock: - if key in self.locks: - return self.locks[key] - else: - return self.locks.setdefault(key, asyncio.Lock()) + def lock(self, key: Hashable) -> asyncio.Lock: + return self.typed.get(key, asyncio.Lock) -locks = Locks() -lock_for = locks.lock_for - - -class StarCtx: +class AdminCtx: def __init__(self, ctx: commands.Context) -> None: self.ctx = ctx assert ctx.guild @@ -41,8 +48,78 @@ class StarCtx: await self.connection.set(("assign", self.guild.id), None) +class StarForwardCtx: + def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable) -> None: + self.message = message + self.star_channel = star_channel + + async def on(self) -> None: + author = self.message.author + embed = discord.Embed(description=self.message.content or None) + avatar = author.avatar + embed.set_author(name=author.display_name, url=self.message.jump_url, icon_url=avatar and avatar.url) + image = next( + ( + attachment.url + for attachment in self.message.attachments + if (attachment.content_type or "").startswith("image/") + ), + None, + ) + if image is not None: + embed.set_image(url=image) + await asyncio.gather( + self.message.add_reaction("⭐"), + self.star_channel.send(embed=embed), + ) + + +class StarMessageCtx: + def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable, count: int) -> None: + self.message = message + self.star_channel = star_channel + self.count = count + + def triggered(self) -> bool: + reaction = next((reaction for reaction in self.message.reactions if reaction.emoji == "⭐"), None) + return reaction is not None and not reaction.me and reaction.count >= self.count + + async def on(self) -> None: + if self.triggered(): + await StarForwardCtx(self.message, self.star_channel).on() + + +class StarEventCtx: + def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None: + self.bot = reaction.bot + self.locks = Locks(reaction.state) + self.star_channel_id = star_channel_id + self.count = count + self.channel_id = reaction.channel_id + self.message_id = reaction.message_id + + async def get_channel(self, id_: int) -> discord.abc.Messageable: + channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_) + match channel: + case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel(): + raise TypeError + case _: + return channel + + async def _on(self) -> None: + star_channel, event_channel = await asyncio.gather( + self.get_channel(self.star_channel_id), self.get_channel(self.channel_id) + ) + message = await event_channel.fetch_message(self.message_id) + await StarMessageCtx(message, star_channel, self.count).on() + + async def on(self) -> None: + async with self.locks.lock(self.message_id): + await self._on() + + class ReactionCtx: - def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent): + def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent) -> None: self.bot = bot self.guild_id = event.guild_id self.channel_id = event.channel_id @@ -51,61 +128,16 @@ class ReactionCtx: self.state: StarState = self.bot.starstate self.connection: AbstractConnection = self.state.connection - async def get_channel(self, id_: int) -> discord.abc.Messageable: - channel = self.bot.get_channel(id_) - if channel is None: - channel = await self.bot.fetch_channel(id_) - if isinstance(channel, discord.CategoryChannel): - raise TypeError - if isinstance(channel, discord.ForumChannel): - raise TypeError - if isinstance(channel, discord.abc.PrivateChannel): - raise TypeError - return channel - async def on(self) -> None: if self.name != "⭐": return assignment: dict[str, int] | None = self.connection.get(("assign", self.guild_id), None) if assignment is None: return - assigned_to, count = assignment["channel"], assignment["count"] - if self.channel_id == assigned_to: + star_channel_id, count = assignment["channel"], assignment["count"] + if self.channel_id == star_channel_id: return - async with lock_for(self.message_id): - assigned_channel, event_channel = await asyncio.gather( - self.get_channel(assigned_to), self.get_channel(self.channel_id) - ) - message = await event_channel.fetch_message(self.message_id) - reaction = next((reaction for reaction in message.reactions if reaction.emoji == "⭐"), None) - if reaction is None: - return - if reaction.me: - return - if reaction.count >= count: - guild = message.guild - if guild is None: - return - member = guild.get_member(message.author.id) - if member is None: - member = await guild.fetch_member(message.author.id) - embed = discord.Embed(description=message.content or None) - avatar = member.avatar - embed.set_author(name=member.display_name, url=message.jump_url, icon_url=avatar and avatar.url) - image = next( - ( - attachment.url - for attachment in message.attachments - if (attachment.content_type or "").startswith("image/") - ), - None, - ) - if image is not None: - embed.set_image(url=image) - await asyncio.gather( - message.add_reaction("⭐"), - assigned_channel.send(embed=embed), - ) + await StarEventCtx(self, star_channel_id, count).on() class Stars(commands.Cog): @@ -140,13 +172,13 @@ class Stars(commands.Cog): @commands.hybrid_command() @commands.has_permissions(administrator=True) async def assign(self, ctx: commands.Context, count: int): - await StarCtx(ctx).assign(count) + await AdminCtx(ctx).assign(count) await ctx.reply("assigned") @commands.hybrid_command() @commands.has_permissions(administrator=True) async def unassign(self, ctx: commands.Context): - await StarCtx(ctx).unassign() + await AdminCtx(ctx).unassign() await ctx.reply("unassigned") @commands.Cog.listener()