This commit is contained in:
AF 2023-10-02 16:57:36 +00:00
parent 21411d5f9a
commit 93818a3dc1
2 changed files with 68 additions and 83 deletions

View File

@ -2,8 +2,8 @@ import os
from v6d0auth.config import root
__all__ = ('prefix', 'myroot')
__all__ = ("prefix", "myroot")
prefix = os.getenv('v6prefix', '??')
myroot = root / 'v6d3vote'
prefix = os.getenv("v6prefix", "??")
myroot = root / "v6d3vote"
myroot.mkdir(exist_ok=True)

View File

@ -17,7 +17,7 @@ from v6d2ctx.serve import *
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
token = loop.run_until_complete(request_token('vote', 'token'))
token = loop.run_until_complete(request_token("vote", "token"))
client = discord.Client(
intents=discord.Intents(
members=True,
@ -30,30 +30,26 @@ client = discord.Client(
message_content=True,
)
)
vote_db = Db(myroot / 'vote.db', kvfactory=KVJson())
vote_db = Db(myroot / "vote.db", kvfactory=KVJson())
@client.event
async def on_ready():
print("ready")
await client.change_presence(
activity=discord.Game(
name='феноменально'
)
)
await client.change_presence(activity=discord.Game(name="феноменально"))
at_of: AtOf[str, command_type] = AtOf()
at, of = at_of()
@at('help')
@at("help")
async def help_(ctx: Context, args: list[str]) -> None:
match args:
case []:
await ctx.reply('poll bot')
await ctx.reply("poll bot")
case [name]:
await ctx.reply(f'help for {name}: `{name} help`')
await ctx.reply(f"help for {name}: `{name} help`")
class SavedPoll(TypedDict):
@ -67,12 +63,12 @@ class Poll:
tasks: dict[int, asyncio.Future[None]] = {}
def __init__(
self,
message: discord.Message,
votes: dict[discord.Member, str],
emojis: dict[str, str],
options: list[str],
title: str
self,
message: discord.Message,
votes: dict[discord.Member, str],
emojis: dict[str, str],
options: list[str],
title: str,
):
self.message = message
self.votes = votes
@ -83,50 +79,40 @@ class Poll:
def saved(self):
return {
'votes': {
str(member.id): option for member, option in self.votes.items()
},
'emojis': self.emojis,
'options': self.options,
'title': self.title
"votes": {str(member.id): option for member, option in self.votes.items()},
"emojis": self.emojis,
"options": self.options,
"title": self.title,
}
def content(self):
count: dict[str, int] = {}
for _, option in self.votes.items():
count[option] = count.get(option, 0) + 1
return (
f'{self.title}\n'
+
'\n'.join(f'{self.emojis[option]} `{count.get(option, 0)}` {option}' for option in self.options)
return f"{self.title}\n" + "\n".join(
f"{self.emojis[option]} `{count.get(option, 0)}` {option}" for option in self.options
)
async def save(self):
await vote_db.set(
self.message.id,
self.saved()
)
await self.message.edit(
content=self.content()
)
await vote_db.set(self.message.id, self.saved())
await self.message.edit(content=self.content())
@classmethod
async def create(
cls, ctx: Context, options: list[tuple[str, discord.Emoji | str]], title: str,
votes: dict[discord.Member, str]
cls, ctx: Context, options: list[tuple[str, discord.Emoji | str]], title: str, votes: dict[discord.Member, str]
):
message: discord.Message = await ctx.channel.send('creating poll...')
async with lock_for(message, 'failed to create poll'):
message: discord.Message = await ctx.channel.send("creating poll...")
async with lock_for(message, "failed to create poll"):
if len(set(emoji for option, emoji in options)) != len(options):
raise Explicit('duplicate emojis')
raise Explicit("duplicate emojis")
if len(set(option for option, emoji in options)) != len(options):
raise Explicit('duplicate options')
raise Explicit("duplicate options")
poll = cls(
message,
votes,
{option: str(emoji) for option, emoji in options},
[option for option, _ in options],
title
title,
)
for _, emoji in options:
await message.add_reaction(emoji)
@ -152,31 +138,26 @@ class Poll:
guild: discord.Guild = message.guild
return cls(
message,
await cls.load_votes(guild, saved['votes']),
saved['emojis'],
saved['options'],
saved.get('title', 'unnamed')
await cls.load_votes(guild, saved["votes"]),
saved["emojis"],
saved["options"],
saved.get("title", "unnamed"),
)
@classmethod
async def _scheduled_save(cls, message: discord.Message, /) -> None:
async with lock_for(message, 'no message'):
async with lock_for(message, "no message"):
del cls.tasks[message.id]
print('saving')
print("saving")
poll = await cls.load(message)
if poll is None:
return
await vote_db.commit()
await poll.message.edit(
content=poll.content()
)
print('saved')
await poll.message.edit(content=poll.content())
print("saved")
def schedule_save(self) -> asyncio.Future[None]:
vote_db.set_nowait(
self.message.id,
self.saved()
)
vote_db.set_nowait(self.message.id, self.saved())
if self.message.id not in self.tasks:
self.tasks[self.message.id] = asyncio.create_task(self._scheduled_save(self.message))
return self.tasks[self.message.id]
@ -192,19 +173,19 @@ class Poll:
_channel = guild.get_channel(rrae.channel_id)
assert isinstance(_channel, discord.TextChannel)
channel: discord.TextChannel = _channel
print('process? ', rrae.emoji, rrae.event_type)
async with lock_for(rrae.message_id, 'no message id'):
print("process? ", rrae.emoji, rrae.event_type)
async with lock_for(rrae.message_id, "no message id"):
message: discord.Message = await channel.fetch_message(rrae.message_id)
if message.author != client.user:
return
async with lock_for(message, 'no message'):
async with lock_for(message, "no message"):
poll = await cls.load(message)
if poll is None:
return
member: discord.Member = guild.get_member(rrae.user_id) or await guild.fetch_member(rrae.user_id)
print('processing', rrae.emoji, rrae.event_type)
await poll.vote(member, rrae.emoji, rrae.event_type == 'REACTION_REMOVE')
print('processed ', rrae.emoji, rrae.event_type)
print("processing", rrae.emoji, rrae.event_type)
await poll.vote(member, rrae.emoji, rrae.event_type == "REACTION_REMOVE")
print("processed ", rrae.emoji, rrae.event_type)
future = poll.schedule_save()
await future
@ -226,32 +207,35 @@ async def poll_options(args: list[str]) -> list[tuple[str, discord.Emoji | str]]
while args:
match args:
case [emoji, option, *args]:
if '<' in emoji and '>' in emoji:
if "<" in emoji and ">" in emoji:
try:
emoji = client.get_emoji(
int(''.join(c for c in emoji.rsplit(':', 1)[-1] if c.isdecimal()))
) or emoji
emoji = (
client.get_emoji(int("".join(c for c in emoji.rsplit(":", 1)[-1] if c.isdecimal())))
or emoji
)
except (discord.NotFound, ValueError):
pass
options.append((option, emoji))
case _:
raise Explicit('option not specified')
raise Explicit("option not specified")
return options
@at('poll')
@at("poll")
async def create_poll(ctx: Context, args: list[str]) -> None:
match args:
case ['help']:
await ctx.reply('`poll title emoji option [emoji option ...]`')
await ctx.reply('`poll emoji option [emoji option ...]` (reply fork)')
case ["help"]:
await ctx.reply("`poll title emoji option [emoji option ...]`")
await ctx.reply("`poll emoji option [emoji option ...]` (reply fork)")
case []:
raise Explicit('no options')
case [*args] if ctx.message.reference is not None and isinstance(ctx.message.reference.resolved, discord.Message):
raise Explicit("no options")
case [*args] if ctx.message.reference is not None and isinstance(
ctx.message.reference.resolved, discord.Message
):
refd: discord.Message = ctx.message.reference.resolved
poll = await Poll.load(refd)
if poll is None:
raise Explicit('referenced message is not a poll')
raise Explicit("referenced message is not a poll")
await Poll.create(ctx, await poll_options(args), poll.title, poll.votes)
case [title, *args]:
await Poll.create(ctx, await poll_options(args), title, {})
@ -276,20 +260,21 @@ async def main():
async with vote_db:
await client.login(token)
await client.connect()
print('exited')
print("exited")
if __name__ == '__main__':
if __name__ == "__main__":
from contextlib import ExitStack
with ExitStack() as es:
ALog(client, 'connect').enter(es)
ALog(client, 'close').enter(es)
ALog(Db, '__aenter__').enter(es)
ALog(Db, '__aexit__').enter(es)
ALog(Db, 'aclose').enter(es)
ALog(client, "connect").enter(es)
ALog(client, "close").enter(es)
ALog(Db, "__aenter__").enter(es)
ALog(Db, "__aexit__").enter(es)
ALog(Db, "aclose").enter(es)
# SLog(Db, '_build_file_sync').enter(es)
# SLog(Db, '_finish_recovery_sync').enter(es)
# SLog(Db, '_copy_sync').enter(es)
# ALog(loop, 'run_in_executor').enter(es)
serve(main(), client, loop)
print('after serve')
print("after serve")