This commit is contained in:
AF 2023-08-25 00:08:22 +00:00
parent ade19acee8
commit 7f1c025d57
3 changed files with 114 additions and 67 deletions

View File

@ -2,13 +2,14 @@ FROM python:3.11 as base
WORKDIR /app/
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
COPY starbot starbot
CMD ["python3", "-m", "starbot"]
FROM base as ptvp35
RUN pip install git+https://gitea.parrrate.ru/PTV/ptvp35.git@f8ee5d20f4e159df2e20c40dbf3b81e925c2db36
ENV DBF_MODULE=starbot.db_ptvp35
COPY starbot starbot
FROM base as sqlite
RUN pip install aiosqlite~=0.19
ENV DBF_MODULE=starbot.db_aiosqlite
COPY starbot starbot

View File

@ -1,17 +1,31 @@
from __future__ import annotations
from contextlib import AsyncExitStack, asynccontextmanager
from collections.abc import Hashable
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Callable, TypeVar
from weakref import WeakValueDictionary
import discord
from discord.ext import commands
from .db import AbstractConnection, AbstractDbFactory
T = TypeVar("T")
class Temporaries:
def __init__(self) -> None:
self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary()
def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any:
return self.__values.get(key) or self.__values.setdefault(key, factory())
class StarState:
def __init__(self, connection: AbstractConnection) -> None:
self.connection = connection
self.temporaries = Temporaries()
@asynccontextmanager

View File

@ -1,31 +1,38 @@
from __future__ import annotations
import asyncio
from collections.abc import Hashable
from typing import Callable, Generic, Hashable, Type, TypeVar
import discord
from discord.ext import commands
from .bot import StarBot, StarState
from .bot import StarBot, StarState, Temporaries
from .db import AbstractConnection
T = TypeVar("T")
class TypedTemporaries(Generic[T]):
def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None:
self.temporaries = temporaries
self.type = type_
def get(self, key: Hashable, factory: Callable[[], T]) -> T:
value = self.temporaries.get((self.type, key), factory)
if not isinstance(value, self.type):
raise TypeError(self.type, value)
return value
class Locks:
def __init__(self) -> None:
self.locks: dict[Hashable, asyncio.Lock] = {}
def __init__(self, state: StarState) -> None:
self.typed = TypedTemporaries(state.temporaries, asyncio.Lock)
def lock_for(self, key: Hashable) -> asyncio.Lock:
if key in self.locks:
return self.locks[key]
else:
return self.locks.setdefault(key, asyncio.Lock())
def lock(self, key: Hashable) -> asyncio.Lock:
return self.typed.get(key, asyncio.Lock)
locks = Locks()
lock_for = locks.lock_for
class StarCtx:
class AdminCtx:
def __init__(self, ctx: commands.Context) -> None:
self.ctx = ctx
assert ctx.guild
@ -41,8 +48,78 @@ class StarCtx:
await self.connection.set(("assign", self.guild.id), None)
class StarForwardCtx:
def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable) -> None:
self.message = message
self.star_channel = star_channel
async def on(self) -> None:
author = self.message.author
embed = discord.Embed(description=self.message.content or None)
avatar = author.avatar
embed.set_author(name=author.display_name, url=self.message.jump_url, icon_url=avatar and avatar.url)
image = next(
(
attachment.url
for attachment in self.message.attachments
if (attachment.content_type or "").startswith("image/")
),
None,
)
if image is not None:
embed.set_image(url=image)
await asyncio.gather(
self.message.add_reaction(""),
self.star_channel.send(embed=embed),
)
class StarMessageCtx:
def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable, count: int) -> None:
self.message = message
self.star_channel = star_channel
self.count = count
def triggered(self) -> bool:
reaction = next((reaction for reaction in self.message.reactions if reaction.emoji == ""), None)
return reaction is not None and not reaction.me and reaction.count >= self.count
async def on(self) -> None:
if self.triggered():
await StarForwardCtx(self.message, self.star_channel).on()
class StarEventCtx:
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
self.bot = reaction.bot
self.locks = Locks(reaction.state)
self.star_channel_id = star_channel_id
self.count = count
self.channel_id = reaction.channel_id
self.message_id = reaction.message_id
async def get_channel(self, id_: int) -> discord.abc.Messageable:
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
match channel:
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
raise TypeError
case _:
return channel
async def _on(self) -> None:
star_channel, event_channel = await asyncio.gather(
self.get_channel(self.star_channel_id), self.get_channel(self.channel_id)
)
message = await event_channel.fetch_message(self.message_id)
await StarMessageCtx(message, star_channel, self.count).on()
async def on(self) -> None:
async with self.locks.lock(self.message_id):
await self._on()
class ReactionCtx:
def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent):
def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent) -> None:
self.bot = bot
self.guild_id = event.guild_id
self.channel_id = event.channel_id
@ -51,61 +128,16 @@ class ReactionCtx:
self.state: StarState = self.bot.starstate
self.connection: AbstractConnection = self.state.connection
async def get_channel(self, id_: int) -> discord.abc.Messageable:
channel = self.bot.get_channel(id_)
if channel is None:
channel = await self.bot.fetch_channel(id_)
if isinstance(channel, discord.CategoryChannel):
raise TypeError
if isinstance(channel, discord.ForumChannel):
raise TypeError
if isinstance(channel, discord.abc.PrivateChannel):
raise TypeError
return channel
async def on(self) -> None:
if self.name != "":
return
assignment: dict[str, int] | None = self.connection.get(("assign", self.guild_id), None)
if assignment is None:
return
assigned_to, count = assignment["channel"], assignment["count"]
if self.channel_id == assigned_to:
star_channel_id, count = assignment["channel"], assignment["count"]
if self.channel_id == star_channel_id:
return
async with lock_for(self.message_id):
assigned_channel, event_channel = await asyncio.gather(
self.get_channel(assigned_to), self.get_channel(self.channel_id)
)
message = await event_channel.fetch_message(self.message_id)
reaction = next((reaction for reaction in message.reactions if reaction.emoji == ""), None)
if reaction is None:
return
if reaction.me:
return
if reaction.count >= count:
guild = message.guild
if guild is None:
return
member = guild.get_member(message.author.id)
if member is None:
member = await guild.fetch_member(message.author.id)
embed = discord.Embed(description=message.content or None)
avatar = member.avatar
embed.set_author(name=member.display_name, url=message.jump_url, icon_url=avatar and avatar.url)
image = next(
(
attachment.url
for attachment in message.attachments
if (attachment.content_type or "").startswith("image/")
),
None,
)
if image is not None:
embed.set_image(url=image)
await asyncio.gather(
message.add_reaction(""),
assigned_channel.send(embed=embed),
)
await StarEventCtx(self, star_channel_id, count).on()
class Stars(commands.Cog):
@ -140,13 +172,13 @@ class Stars(commands.Cog):
@commands.hybrid_command()
@commands.has_permissions(administrator=True)
async def assign(self, ctx: commands.Context, count: int):
await StarCtx(ctx).assign(count)
await AdminCtx(ctx).assign(count)
await ctx.reply("assigned")
@commands.hybrid_command()
@commands.has_permissions(administrator=True)
async def unassign(self, ctx: commands.Context):
await StarCtx(ctx).unassign()
await AdminCtx(ctx).unassign()
await ctx.reply("unassigned")
@commands.Cog.listener()