This commit is contained in:
AF 2023-08-25 00:08:22 +00:00
parent ade19acee8
commit 7f1c025d57
3 changed files with 114 additions and 67 deletions

View File

@ -2,13 +2,14 @@ FROM python:3.11 as base
WORKDIR /app/ WORKDIR /app/
COPY requirements.txt requirements.txt COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt RUN pip install -r requirements.txt
COPY starbot starbot
CMD ["python3", "-m", "starbot"] CMD ["python3", "-m", "starbot"]
FROM base as ptvp35 FROM base as ptvp35
RUN pip install git+https://gitea.parrrate.ru/PTV/ptvp35.git@f8ee5d20f4e159df2e20c40dbf3b81e925c2db36 RUN pip install git+https://gitea.parrrate.ru/PTV/ptvp35.git@f8ee5d20f4e159df2e20c40dbf3b81e925c2db36
ENV DBF_MODULE=starbot.db_ptvp35 ENV DBF_MODULE=starbot.db_ptvp35
COPY starbot starbot
FROM base as sqlite FROM base as sqlite
RUN pip install aiosqlite~=0.19 RUN pip install aiosqlite~=0.19
ENV DBF_MODULE=starbot.db_aiosqlite ENV DBF_MODULE=starbot.db_aiosqlite
COPY starbot starbot

View File

@ -1,17 +1,31 @@
from __future__ import annotations from __future__ import annotations
from contextlib import AsyncExitStack, asynccontextmanager from collections.abc import Hashable
from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Callable, TypeVar
from weakref import WeakValueDictionary
import discord import discord
from discord.ext import commands from discord.ext import commands
from .db import AbstractConnection, AbstractDbFactory from .db import AbstractConnection, AbstractDbFactory
T = TypeVar("T")
class Temporaries:
def __init__(self) -> None:
self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary()
def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any:
return self.__values.get(key) or self.__values.setdefault(key, factory())
class StarState: class StarState:
def __init__(self, connection: AbstractConnection) -> None: def __init__(self, connection: AbstractConnection) -> None:
self.connection = connection self.connection = connection
self.temporaries = Temporaries()
@asynccontextmanager @asynccontextmanager

View File

@ -1,31 +1,38 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Hashable from typing import Callable, Generic, Hashable, Type, TypeVar
import discord import discord
from discord.ext import commands from discord.ext import commands
from .bot import StarBot, StarState from .bot import StarBot, StarState, Temporaries
from .db import AbstractConnection from .db import AbstractConnection
T = TypeVar("T")
class TypedTemporaries(Generic[T]):
def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None:
self.temporaries = temporaries
self.type = type_
def get(self, key: Hashable, factory: Callable[[], T]) -> T:
value = self.temporaries.get((self.type, key), factory)
if not isinstance(value, self.type):
raise TypeError(self.type, value)
return value
class Locks: class Locks:
def __init__(self) -> None: def __init__(self, state: StarState) -> None:
self.locks: dict[Hashable, asyncio.Lock] = {} self.typed = TypedTemporaries(state.temporaries, asyncio.Lock)
def lock_for(self, key: Hashable) -> asyncio.Lock: def lock(self, key: Hashable) -> asyncio.Lock:
if key in self.locks: return self.typed.get(key, asyncio.Lock)
return self.locks[key]
else:
return self.locks.setdefault(key, asyncio.Lock())
locks = Locks() class AdminCtx:
lock_for = locks.lock_for
class StarCtx:
def __init__(self, ctx: commands.Context) -> None: def __init__(self, ctx: commands.Context) -> None:
self.ctx = ctx self.ctx = ctx
assert ctx.guild assert ctx.guild
@ -41,8 +48,78 @@ class StarCtx:
await self.connection.set(("assign", self.guild.id), None) await self.connection.set(("assign", self.guild.id), None)
class StarForwardCtx:
def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable) -> None:
self.message = message
self.star_channel = star_channel
async def on(self) -> None:
author = self.message.author
embed = discord.Embed(description=self.message.content or None)
avatar = author.avatar
embed.set_author(name=author.display_name, url=self.message.jump_url, icon_url=avatar and avatar.url)
image = next(
(
attachment.url
for attachment in self.message.attachments
if (attachment.content_type or "").startswith("image/")
),
None,
)
if image is not None:
embed.set_image(url=image)
await asyncio.gather(
self.message.add_reaction(""),
self.star_channel.send(embed=embed),
)
class StarMessageCtx:
def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable, count: int) -> None:
self.message = message
self.star_channel = star_channel
self.count = count
def triggered(self) -> bool:
reaction = next((reaction for reaction in self.message.reactions if reaction.emoji == ""), None)
return reaction is not None and not reaction.me and reaction.count >= self.count
async def on(self) -> None:
if self.triggered():
await StarForwardCtx(self.message, self.star_channel).on()
class StarEventCtx:
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
self.bot = reaction.bot
self.locks = Locks(reaction.state)
self.star_channel_id = star_channel_id
self.count = count
self.channel_id = reaction.channel_id
self.message_id = reaction.message_id
async def get_channel(self, id_: int) -> discord.abc.Messageable:
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
match channel:
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
raise TypeError
case _:
return channel
async def _on(self) -> None:
star_channel, event_channel = await asyncio.gather(
self.get_channel(self.star_channel_id), self.get_channel(self.channel_id)
)
message = await event_channel.fetch_message(self.message_id)
await StarMessageCtx(message, star_channel, self.count).on()
async def on(self) -> None:
async with self.locks.lock(self.message_id):
await self._on()
class ReactionCtx: class ReactionCtx:
def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent): def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent) -> None:
self.bot = bot self.bot = bot
self.guild_id = event.guild_id self.guild_id = event.guild_id
self.channel_id = event.channel_id self.channel_id = event.channel_id
@ -51,61 +128,16 @@ class ReactionCtx:
self.state: StarState = self.bot.starstate self.state: StarState = self.bot.starstate
self.connection: AbstractConnection = self.state.connection self.connection: AbstractConnection = self.state.connection
async def get_channel(self, id_: int) -> discord.abc.Messageable:
channel = self.bot.get_channel(id_)
if channel is None:
channel = await self.bot.fetch_channel(id_)
if isinstance(channel, discord.CategoryChannel):
raise TypeError
if isinstance(channel, discord.ForumChannel):
raise TypeError
if isinstance(channel, discord.abc.PrivateChannel):
raise TypeError
return channel
async def on(self) -> None: async def on(self) -> None:
if self.name != "": if self.name != "":
return return
assignment: dict[str, int] | None = self.connection.get(("assign", self.guild_id), None) assignment: dict[str, int] | None = self.connection.get(("assign", self.guild_id), None)
if assignment is None: if assignment is None:
return return
assigned_to, count = assignment["channel"], assignment["count"] star_channel_id, count = assignment["channel"], assignment["count"]
if self.channel_id == assigned_to: if self.channel_id == star_channel_id:
return return
async with lock_for(self.message_id): await StarEventCtx(self, star_channel_id, count).on()
assigned_channel, event_channel = await asyncio.gather(
self.get_channel(assigned_to), self.get_channel(self.channel_id)
)
message = await event_channel.fetch_message(self.message_id)
reaction = next((reaction for reaction in message.reactions if reaction.emoji == ""), None)
if reaction is None:
return
if reaction.me:
return
if reaction.count >= count:
guild = message.guild
if guild is None:
return
member = guild.get_member(message.author.id)
if member is None:
member = await guild.fetch_member(message.author.id)
embed = discord.Embed(description=message.content or None)
avatar = member.avatar
embed.set_author(name=member.display_name, url=message.jump_url, icon_url=avatar and avatar.url)
image = next(
(
attachment.url
for attachment in message.attachments
if (attachment.content_type or "").startswith("image/")
),
None,
)
if image is not None:
embed.set_image(url=image)
await asyncio.gather(
message.add_reaction(""),
assigned_channel.send(embed=embed),
)
class Stars(commands.Cog): class Stars(commands.Cog):
@ -140,13 +172,13 @@ class Stars(commands.Cog):
@commands.hybrid_command() @commands.hybrid_command()
@commands.has_permissions(administrator=True) @commands.has_permissions(administrator=True)
async def assign(self, ctx: commands.Context, count: int): async def assign(self, ctx: commands.Context, count: int):
await StarCtx(ctx).assign(count) await AdminCtx(ctx).assign(count)
await ctx.reply("assigned") await ctx.reply("assigned")
@commands.hybrid_command() @commands.hybrid_command()
@commands.has_permissions(administrator=True) @commands.has_permissions(administrator=True)
async def unassign(self, ctx: commands.Context): async def unassign(self, ctx: commands.Context):
await StarCtx(ctx).unassign() await AdminCtx(ctx).unassign()
await ctx.reply("unassigned") await ctx.reply("unassigned")
@commands.Cog.listener() @commands.Cog.listener()