refactor
This commit is contained in:
parent
ade19acee8
commit
7f1c025d57
@ -2,13 +2,14 @@ FROM python:3.11 as base
|
||||
WORKDIR /app/
|
||||
COPY requirements.txt requirements.txt
|
||||
RUN pip install -r requirements.txt
|
||||
COPY starbot starbot
|
||||
CMD ["python3", "-m", "starbot"]
|
||||
|
||||
FROM base as ptvp35
|
||||
RUN pip install git+https://gitea.parrrate.ru/PTV/ptvp35.git@f8ee5d20f4e159df2e20c40dbf3b81e925c2db36
|
||||
ENV DBF_MODULE=starbot.db_ptvp35
|
||||
COPY starbot starbot
|
||||
|
||||
FROM base as sqlite
|
||||
RUN pip install aiosqlite~=0.19
|
||||
ENV DBF_MODULE=starbot.db_aiosqlite
|
||||
COPY starbot starbot
|
||||
|
@ -1,17 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from collections.abc import Hashable
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TypeVar
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from .db import AbstractConnection, AbstractDbFactory
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Temporaries:
|
||||
def __init__(self) -> None:
|
||||
self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary()
|
||||
|
||||
def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any:
|
||||
return self.__values.get(key) or self.__values.setdefault(key, factory())
|
||||
|
||||
|
||||
class StarState:
|
||||
def __init__(self, connection: AbstractConnection) -> None:
|
||||
self.connection = connection
|
||||
self.temporaries = Temporaries()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
@ -1,31 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Hashable
|
||||
from typing import Callable, Generic, Hashable, Type, TypeVar
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from .bot import StarBot, StarState
|
||||
from .bot import StarBot, StarState, Temporaries
|
||||
from .db import AbstractConnection
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TypedTemporaries(Generic[T]):
|
||||
def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None:
|
||||
self.temporaries = temporaries
|
||||
self.type = type_
|
||||
|
||||
def get(self, key: Hashable, factory: Callable[[], T]) -> T:
|
||||
value = self.temporaries.get((self.type, key), factory)
|
||||
if not isinstance(value, self.type):
|
||||
raise TypeError(self.type, value)
|
||||
return value
|
||||
|
||||
|
||||
class Locks:
|
||||
def __init__(self) -> None:
|
||||
self.locks: dict[Hashable, asyncio.Lock] = {}
|
||||
def __init__(self, state: StarState) -> None:
|
||||
self.typed = TypedTemporaries(state.temporaries, asyncio.Lock)
|
||||
|
||||
def lock_for(self, key: Hashable) -> asyncio.Lock:
|
||||
if key in self.locks:
|
||||
return self.locks[key]
|
||||
else:
|
||||
return self.locks.setdefault(key, asyncio.Lock())
|
||||
def lock(self, key: Hashable) -> asyncio.Lock:
|
||||
return self.typed.get(key, asyncio.Lock)
|
||||
|
||||
|
||||
locks = Locks()
|
||||
lock_for = locks.lock_for
|
||||
|
||||
|
||||
class StarCtx:
|
||||
class AdminCtx:
|
||||
def __init__(self, ctx: commands.Context) -> None:
|
||||
self.ctx = ctx
|
||||
assert ctx.guild
|
||||
@ -41,8 +48,78 @@ class StarCtx:
|
||||
await self.connection.set(("assign", self.guild.id), None)
|
||||
|
||||
|
||||
class StarForwardCtx:
|
||||
def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable) -> None:
|
||||
self.message = message
|
||||
self.star_channel = star_channel
|
||||
|
||||
async def on(self) -> None:
|
||||
author = self.message.author
|
||||
embed = discord.Embed(description=self.message.content or None)
|
||||
avatar = author.avatar
|
||||
embed.set_author(name=author.display_name, url=self.message.jump_url, icon_url=avatar and avatar.url)
|
||||
image = next(
|
||||
(
|
||||
attachment.url
|
||||
for attachment in self.message.attachments
|
||||
if (attachment.content_type or "").startswith("image/")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if image is not None:
|
||||
embed.set_image(url=image)
|
||||
await asyncio.gather(
|
||||
self.message.add_reaction("⭐"),
|
||||
self.star_channel.send(embed=embed),
|
||||
)
|
||||
|
||||
|
||||
class StarMessageCtx:
|
||||
def __init__(self, message: discord.Message, star_channel: discord.abc.Messageable, count: int) -> None:
|
||||
self.message = message
|
||||
self.star_channel = star_channel
|
||||
self.count = count
|
||||
|
||||
def triggered(self) -> bool:
|
||||
reaction = next((reaction for reaction in self.message.reactions if reaction.emoji == "⭐"), None)
|
||||
return reaction is not None and not reaction.me and reaction.count >= self.count
|
||||
|
||||
async def on(self) -> None:
|
||||
if self.triggered():
|
||||
await StarForwardCtx(self.message, self.star_channel).on()
|
||||
|
||||
|
||||
class StarEventCtx:
|
||||
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
|
||||
self.bot = reaction.bot
|
||||
self.locks = Locks(reaction.state)
|
||||
self.star_channel_id = star_channel_id
|
||||
self.count = count
|
||||
self.channel_id = reaction.channel_id
|
||||
self.message_id = reaction.message_id
|
||||
|
||||
async def get_channel(self, id_: int) -> discord.abc.Messageable:
|
||||
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
|
||||
match channel:
|
||||
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
|
||||
raise TypeError
|
||||
case _:
|
||||
return channel
|
||||
|
||||
async def _on(self) -> None:
|
||||
star_channel, event_channel = await asyncio.gather(
|
||||
self.get_channel(self.star_channel_id), self.get_channel(self.channel_id)
|
||||
)
|
||||
message = await event_channel.fetch_message(self.message_id)
|
||||
await StarMessageCtx(message, star_channel, self.count).on()
|
||||
|
||||
async def on(self) -> None:
|
||||
async with self.locks.lock(self.message_id):
|
||||
await self._on()
|
||||
|
||||
|
||||
class ReactionCtx:
|
||||
def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent):
|
||||
def __init__(self, bot: StarBot, event: discord.RawReactionActionEvent) -> None:
|
||||
self.bot = bot
|
||||
self.guild_id = event.guild_id
|
||||
self.channel_id = event.channel_id
|
||||
@ -51,61 +128,16 @@ class ReactionCtx:
|
||||
self.state: StarState = self.bot.starstate
|
||||
self.connection: AbstractConnection = self.state.connection
|
||||
|
||||
async def get_channel(self, id_: int) -> discord.abc.Messageable:
|
||||
channel = self.bot.get_channel(id_)
|
||||
if channel is None:
|
||||
channel = await self.bot.fetch_channel(id_)
|
||||
if isinstance(channel, discord.CategoryChannel):
|
||||
raise TypeError
|
||||
if isinstance(channel, discord.ForumChannel):
|
||||
raise TypeError
|
||||
if isinstance(channel, discord.abc.PrivateChannel):
|
||||
raise TypeError
|
||||
return channel
|
||||
|
||||
async def on(self) -> None:
|
||||
if self.name != "⭐":
|
||||
return
|
||||
assignment: dict[str, int] | None = self.connection.get(("assign", self.guild_id), None)
|
||||
if assignment is None:
|
||||
return
|
||||
assigned_to, count = assignment["channel"], assignment["count"]
|
||||
if self.channel_id == assigned_to:
|
||||
star_channel_id, count = assignment["channel"], assignment["count"]
|
||||
if self.channel_id == star_channel_id:
|
||||
return
|
||||
async with lock_for(self.message_id):
|
||||
assigned_channel, event_channel = await asyncio.gather(
|
||||
self.get_channel(assigned_to), self.get_channel(self.channel_id)
|
||||
)
|
||||
message = await event_channel.fetch_message(self.message_id)
|
||||
reaction = next((reaction for reaction in message.reactions if reaction.emoji == "⭐"), None)
|
||||
if reaction is None:
|
||||
return
|
||||
if reaction.me:
|
||||
return
|
||||
if reaction.count >= count:
|
||||
guild = message.guild
|
||||
if guild is None:
|
||||
return
|
||||
member = guild.get_member(message.author.id)
|
||||
if member is None:
|
||||
member = await guild.fetch_member(message.author.id)
|
||||
embed = discord.Embed(description=message.content or None)
|
||||
avatar = member.avatar
|
||||
embed.set_author(name=member.display_name, url=message.jump_url, icon_url=avatar and avatar.url)
|
||||
image = next(
|
||||
(
|
||||
attachment.url
|
||||
for attachment in message.attachments
|
||||
if (attachment.content_type or "").startswith("image/")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if image is not None:
|
||||
embed.set_image(url=image)
|
||||
await asyncio.gather(
|
||||
message.add_reaction("⭐"),
|
||||
assigned_channel.send(embed=embed),
|
||||
)
|
||||
await StarEventCtx(self, star_channel_id, count).on()
|
||||
|
||||
|
||||
class Stars(commands.Cog):
|
||||
@ -140,13 +172,13 @@ class Stars(commands.Cog):
|
||||
@commands.hybrid_command()
|
||||
@commands.has_permissions(administrator=True)
|
||||
async def assign(self, ctx: commands.Context, count: int):
|
||||
await StarCtx(ctx).assign(count)
|
||||
await AdminCtx(ctx).assign(count)
|
||||
await ctx.reply("assigned")
|
||||
|
||||
@commands.hybrid_command()
|
||||
@commands.has_permissions(administrator=True)
|
||||
async def unassign(self, ctx: commands.Context):
|
||||
await StarCtx(ctx).unassign()
|
||||
await AdminCtx(ctx).unassign()
|
||||
await ctx.reply("unassigned")
|
||||
|
||||
@commands.Cog.listener()
|
||||
|
Loading…
Reference in New Issue
Block a user