v6d3vote/v6d3vote/run-bot.py

257 lines
7.9 KiB
Python

import asyncio
import os
import shlex
import weakref
from typing import Optional

# noinspection PyPackageRequirements
import discord
from ptvp35 import Db, KVJson
from v6d0auth.config import root
from v6d1tokens.client import request_token
from v6d3vote.config import prefix
from v6d3vote.context import Context, of, at, Implicit, monitor, Explicit
# One event loop for the whole process: it fetches the token synchronously at
# import time here and later drives main() in the __main__ guard.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
token = loop.run_until_complete(request_token('vote'))

# Gateway intents the bot subscribes to.
# NOTE(review): on discord.py 2.x, reading message.content in on_message also
# requires the message_content intent — confirm which library version is used.
client = discord.Client(
    intents=discord.Intents(
        members=True, guilds=True, bans=True, emojis=True,
        invites=True, guild_messages=True, reactions=True,
    ),
)

# Bot data directory and the persistent poll store (message id -> saved poll).
myroot = root / 'v6d3vote'
myroot.mkdir(exist_ok=True)
vote_db = Db(myroot / 'vote.db', kvrequest_type=KVJson)
@client.event
async def on_ready():
    """Print a readiness marker and set a Game activity as the bot's presence."""
    print("ready")
    status = discord.Game(name='феноменально')
    await client.change_presence(activity=status)
async def handle_command(ctx: Context, name: str, args: list[str]) -> None:
    """Resolve the command handler via ``of('commands', name)`` and await it with *ctx* and *args*."""
    handler = of('commands', name)
    await handler(ctx, args)
@at('commands', 'help')
async def help_(ctx: Context, args: list[str]) -> None:
match args:
case []:
await ctx.reply('poll bot')
case [name]:
await ctx.reply(f'help for {name}: `{name} help`')
class Poll:
    """A reaction-driven poll displayed by one bot-authored message.

    The poll state is stored in ``vote_db`` under the message id. The message
    text is re-rendered from that state every time the poll is saved.
    """

    def __init__(
            self,
            message: discord.Message,
            votes: dict[discord.Member, str],
            emojis: dict[str, str],
            options: list[str],
            title: str
    ):
        """Build a poll from already-resolved state.

        :param message: the bot message that displays the poll.
        :param votes: each member's current vote, as an option name.
        :param emojis: option name -> stringified emoji used to vote for it.
        :param options: option names in display order.
        :param title: first line of the rendered poll.
        """
        self.message = message
        self.votes = votes
        self.emojis = emojis
        # emoji string -> option name, used to translate incoming reactions.
        self.reverse: dict[str, str] = {emoji: option for option, emoji in emojis.items()}
        self.options = options
        self.title = title

    def saved(self):
        """Return the JSON-serialisable form of the poll that is written to ``vote_db``."""
        return {
            'votes': {
                str(member.id): option for member, option in self.votes.items()
            },
            'emojis': self.emojis,
            'options': self.options,
            'title': self.title
        }

    def content(self):
        """Render the message text: the title, then one line per option with its vote count."""
        count: dict[str, int] = {}
        for option in self.votes.values():
            count[option] = count.get(option, 0) + 1
        return (
            f'{self.title}\n'
            +
            '\n'.join(f'{self.emojis[option]} `{count.get(option, 0)}` {option}' for option in self.options)
        )

    async def save(self):
        """Persist the poll to ``vote_db`` and edit the message to show the current counts."""
        await vote_db.set(
            self.message.id,
            self.saved()
        )
        await self.message.edit(
            content=self.content()
        )

    @classmethod
    async def create(cls, ctx: Context, options: list[tuple[str, discord.Emoji | str]], title: str):
        """Post a new poll in reply to *ctx*, add one reaction per option, then delete the command message.

        :param options: ``(option name, emoji)`` pairs in display order.
        :param title: first line of the rendered poll.
        """
        message: discord.Message = await ctx.reply('creating poll...')
        # The lock is held until the poll is saved. Reaction events that global_vote
        # handles meanwhile wait on the same lock and then find the stored poll.
        async with lock_for(message):
            poll = cls(
                message,
                {},
                {option: str(emoji) for option, emoji in options},
                [option for option, _ in options],
                title
            )
            for _, emoji in options:
                await message.add_reaction(emoji)
            await poll.save()
        await ctx.message.delete()

    @staticmethod
    async def load_votes(guild: discord.Guild, votes: dict[str, str]) -> dict[discord.Member, str]:
        """Resolve stored member-id strings to members of *guild*.

        An entry is dropped if its id is not an integer or the member cannot be
        fetched (``discord.HTTPException``).
        """
        loaded: dict[discord.Member, str] = {}
        for member, option in votes.items():
            try:
                member_id = int(member)
                loaded[guild.get_member(member_id) or await guild.fetch_member(member_id)] = option
            except (ValueError, discord.HTTPException):
                pass
        return loaded

    @classmethod
    async def load(cls, message: discord.Message) -> Optional['Poll']:
        """Rebuild the poll stored for *message*, or return None if none is stored."""
        saved: Optional[dict[str, dict[str, str] | list[str]]] = vote_db.get(message.id, None)
        if saved is None:
            return None
        # noinspection PyTypeChecker
        guild: discord.Guild = message.guild
        return cls(
            message,
            await cls.load_votes(guild, saved['votes']),
            saved['emojis'],
            saved['options'],
            # presumably records saved before titles existed have no 'title' key
            saved.get('title', 'unnamed')
        )

    @classmethod
    async def global_vote(cls, rrae: discord.RawReactionActionEvent):
        """Apply a raw reaction add/remove event to the poll on the reacted message, if any.

        The event is ignored in four cases: the bot's own reactions, reactions
        outside a guild, messages not authored by the bot, and messages with no
        stored poll.
        """
        if rrae.user_id == client.user.id:
            return
        if rrae.guild_id is None:
            # The ``reactions`` intent also delivers DM reactions. load() needs
            # message.guild to resolve voters, and fetch_guild(None) would only
            # fail with an HTTP error.
            return
        guild: discord.Guild = client.get_guild(rrae.guild_id) or await client.fetch_guild(rrae.guild_id)
        member: discord.Member = guild.get_member(rrae.user_id) or await guild.fetch_member(rrae.user_id)
        # A guild returned by fetch_guild carries no channels, and Guild.get_channel
        # may not return some channel kinds (e.g. threads on discord.py 2.x),
        # so fall back to an API fetch instead of calling fetch_message on None.
        channel: discord.TextChannel = (
            guild.get_channel(rrae.channel_id) or await client.fetch_channel(rrae.channel_id)
        )
        message: discord.Message = await channel.fetch_message(rrae.message_id)
        if message.author != client.user:
            return
        async with lock_for(message):
            poll = await cls.load(message)
            if poll is None:
                return
            await poll.vote(member, rrae.emoji, rrae.event_type == 'REACTION_REMOVE')
            await poll.save()

    async def vote(self, member: discord.Member, emoji: discord.Emoji | str, remove: bool):
        """Record or retract *member*'s vote for the option mapped to *emoji*.

        Emojis that are not poll options are ignored. Removing a reaction clears
        the vote only if it was for that option, so removing an old reaction after
        switching does not erase the new vote. Adding a reaction replaces any
        previous vote. It also removes the member's reactions for every other
        emoji on the message.
        """
        if str(emoji) in self.reverse:
            option = self.reverse[str(emoji)]
            if remove:
                if self.votes.get(member) == option:
                    del self.votes[member]
            else:
                self.votes[member] = option
                for other_reaction in self.message.reactions:
                    if str(other_reaction.emoji) != str(emoji):
                        await self.message.remove_reaction(other_reaction.emoji, member)
# One lock per poll message, serialising the load -> vote -> save sequence on its state.
# Weak values: once no coroutine holds or awaits a lock, its entry disappears, so the
# table no longer keeps every message ever locked (and its Message object) alive forever.
locks: weakref.WeakValueDictionary[discord.Message, asyncio.Lock] = weakref.WeakValueDictionary()


def lock_for(message: discord.Message) -> asyncio.Lock:
    """Return the lock guarding *message*'s poll state, creating it on first use.

    The returned lock is kept alive by whoever is using it (``async with``), which is
    what keeps its entry in ``locks`` present while it is contended.

    :raises Explicit: if *message* is None.
    """
    if message is None:
        raise Explicit('not in a guild')
    lock = locks.get(message)
    if lock is None:
        lock = asyncio.Lock()
        locks[message] = lock
    return lock
async def poll_options(args: list[str]) -> list[tuple[str, discord.Emoji | str]]:
options: list[tuple[str, discord.Emoji | str]] = []
while args:
match args:
case [emoji, option, *args]:
try:
emoji = client.get_emoji(
int(''.join(c for c in emoji.rsplit(':', 1)[-1] if c.isdecimal()))
)
except (discord.NotFound, ValueError):
pass
options.append((option, emoji))
case _:
raise Explicit('option not specified')
return options
@at('commands', 'poll')
async def create_poll(ctx: Context, args: list[str]) -> None:
match args:
case ['help']:
await ctx.reply('`poll title emoji option [emoji option ...]`')
case []:
raise Explicit('no options')
case [title, *args]:
await Poll.create(ctx, await poll_options(args), title)
async def handle_args(message: discord.Message, args: list[str]):
    """Dispatch a tokenised command line to its handler.

    An empty line is ignored. ``Implicit`` errors are swallowed, and ``Explicit``
    errors are replied to the author with their ``msg``.
    """
    if not args:
        return
    command_name, *command_args = args
    ctx = Context(message)
    try:
        await handle_command(ctx, command_name, command_args)
    except Implicit:
        pass
    except Explicit as e:
        await ctx.reply(e.msg)
@client.event
async def on_message(message: discord.Message) -> None:
    """Tokenise messages that start with ``prefix`` (ignoring bot authors) and dispatch them."""
    if message.author.bot:
        return
    content: str = message.content
    if not content.startswith(prefix):
        return
    content = content.removeprefix(prefix)
    try:
        args = shlex.split(content)
    except ValueError as e:
        # shlex rejects unbalanced quotes and trailing escapes ("No closing quotation").
        # Report this to the user instead of letting the event handler raise.
        await Context(message).reply(f'could not parse command: {e}')
        return
    await handle_args(message, args)
@client.event
async def on_raw_reaction_add(rrae: discord.RawReactionActionEvent) -> None:
    """Gateway hook for added reactions; all poll logic lives in ``Poll.global_vote``."""
    dispatch = Poll.global_vote
    await dispatch(rrae)
@client.event
async def on_raw_reaction_remove(rrae: discord.RawReactionActionEvent) -> None:
    """Gateway hook for removed reactions; all poll logic lives in ``Poll.global_vote``."""
    dispatch = Poll.global_vote
    await dispatch(rrae)
async def main():
    """Run the bot inside the vote database context.

    Steps: open the database, log in, start the monitor if ``v6monitor`` is set,
    then run the gateway connection until it returns.
    """
    async with vote_db:
        await client.login(token)
        # Hold a reference for the lifetime of main(). The event loop keeps only weak
        # references to tasks, so an unreferenced task can be garbage-collected mid-run.
        monitor_task = loop.create_task(monitor()) if os.getenv('v6monitor') else None
        await client.connect()
        del monitor_task
if __name__ == '__main__':
    # Drive the bot to completion on the loop created at import time.
    entry = main()
    loop.run_until_complete(entry)