From 58ce90075714fefe5f9ff92ba43ffbe43155ed5c Mon Sep 17 00:00:00 2001 From: timofey Date: Fri, 25 Aug 2023 01:52:48 +0000 Subject: [PATCH] temporary reservation system --- starbot/starbot/bot.py | 78 +++++++++++++++++++++++++--- starbot/starbot/stars.py | 106 +++++++++++++++++++++++++++++++-------- 2 files changed, 156 insertions(+), 28 deletions(-) diff --git a/starbot/starbot/bot.py b/starbot/starbot/bot.py index 12872fe..e53e9b0 100644 --- a/starbot/starbot/bot.py +++ b/starbot/starbot/bot.py @@ -1,10 +1,11 @@ from __future__ import annotations +import asyncio from collections.abc import Hashable from contextlib import asynccontextmanager from pathlib import Path from typing import Any, Callable, TypeVar -from weakref import WeakValueDictionary +from weakref import ref import discord from discord.ext import commands @@ -14,18 +15,81 @@ from .db import AbstractConnection, AbstractDbFactory T = TypeVar("T") -class Temporaries: - def __init__(self) -> None: - self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary() +class Reserved: + def __init__(self, value: Any, empty: asyncio.Future) -> None: + self.value = value + self.__empty = empty - def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any: - return self.__values.get(key) or self.__values.setdefault(key, factory()) + async def empty(self) -> None: + await self.__empty + + +class ReservedContainer: + def __init__(self, value: Any, empty: asyncio.Future) -> None: + self.reserved = Reserved(value, empty) + self.empty = empty + self.rc = 0 + + +class Reservation: + def __init__( + self, + container: ref[dict[Hashable, Reservation]], + key: Hashable, + loop: asyncio.AbstractEventLoop, + factory: Callable[[], Any], + ) -> None: + self.__container = container + self.__key = key + self.__rc: ReservedContainer | None = None + self.__create_future = loop.create_future + self.__factory = factory + + def __enter__(self) -> Reserved: + container = self.__container() + if container is None: + raise RuntimeError("can't reserve in a non-existent container") + if self.__rc is None: + other = container.get(self.__key) + if other is not None: + assert other.__rc + self.__rc = other.__rc + else: + self.__rc = ReservedContainer(self.__factory(), self.__create_future()) + container[self.__key] = self + rc = self.__rc + rc.rc += 1 + return rc.reserved + + def __exit__(self, et, ev, tb, /): + rc = self.__rc + assert rc + assert rc.rc > 0 + rc.rc -= 1 + if rc.rc == 0: + rc.empty.set_result(None) + container = self.__container() + if container is not None: + del container[self.__key] + + +class RefableDict(dict): + pass + + +class Reservations: + def __init__(self) -> None: + self.__container: dict[Hashable, Reservation] = RefableDict() + self.__loop = asyncio.get_running_loop() + + def reserve(self, key: Hashable, factory: Callable[[], Any]) -> Reservation: + return Reservation(ref(self.__container), key, self.__loop, factory) class StarState: def __init__(self, connection: AbstractConnection) -> None: self.connection = connection - self.temporaries = Temporaries() + self.reservations = Reservations() @asynccontextmanager diff --git a/starbot/starbot/stars.py b/starbot/starbot/stars.py index c013b63..645ad2c 100644 --- a/starbot/starbot/stars.py +++ b/starbot/starbot/stars.py @@ -1,35 +1,95 @@ from __future__ import annotations import asyncio -from typing import Callable, Generic, Hashable, Type, TypeVar +from contextlib import asynccontextmanager +from typing import AsyncContextManager, Callable, Generic, Hashable, Type, TypeVar import discord from discord.ext import commands -from .bot import StarBot, StarState, Temporaries +from .bot import Reservations, StarBot, StarState from .db import AbstractConnection T = TypeVar("T") -class TypedTemporaries(Generic[T]): - def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None: - self.temporaries = temporaries +class TypedReservations(Generic[T]): + def __init__(self, reservations: Reservations, type_: Type[T], factory: Callable[[], T]) -> None: + self.reservations = reservations self.type = type_ + self.factory = factory - def get(self, key: Hashable, factory: Callable[[], T]) -> T: - value = self.temporaries.get((self.type, key), factory) - if not isinstance(value, self.type): - raise TypeError(self.type, value) - return value + def reserve(self, key: Hashable) -> AsyncContextManager[T]: + @asynccontextmanager + async def inner(): + factory_called = False + + def factory(): + nonlocal factory_called + factory_called = True + return self.factory() + + while True: + with self.reservations.reserve(key, factory) as reserved: + if not isinstance(reserved.value, self.type): + if factory_called: + raise TypeError("factory seems to have returned a value of incorrect type") + else: + yield reserved.value + break + await reserved.empty() + + return inner() -class Locks: - def __init__(self, state: StarState) -> None: - self.typed = TypedTemporaries(state.temporaries, asyncio.Lock) +class Increment: + def __init__(self) -> None: + self.__current = 0 - def lock(self, key: Hashable) -> asyncio.Lock: - return self.typed.get(key, asyncio.Lock) + def next(self) -> int: + self.__current += 1 + return self.__current + + +class AlreadyProcessed(Exception): + pass + + +class MessageState: + def __init__(self) -> None: + self.increment = Increment() + self.current = self.increment.next() + self.__lock = asyncio.Lock() + + def _pull(self): + return self.increment.next() + + def _push(self, token: int): + self.current = token + + def _before(self, token: int): + return self.current < token + + async def __aenter__(self) -> AsyncContextManager[None]: + @asynccontextmanager + async def inner(): + token = self._pull() + async with self.__lock: + if self._before(token): + to_push = self._pull() + yield + self._push(to_push) + else: + raise AlreadyProcessed() + + return inner() + + async def __aexit__(self, et, ev, tb, /): + return et and issubclass(et, AlreadyProcessed) + + +def states(state: StarState): + return TypedReservations(state.reservations, MessageState, MessageState) class AdminCtx: @@ -92,13 +152,13 @@ class StarMessageCtx: class StarEventCtx: def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None: self.bot = reaction.bot - self.locks = Locks(reaction.state) + self.states = states(reaction.state) self.star_channel_id = star_channel_id self.count = count self.channel_id = reaction.channel_id self.message_id = reaction.message_id - async def get_channel(self, id_: int) -> discord.abc.Messageable: + async def _get_channel(self, id_: int) -> discord.abc.Messageable: channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_) match channel: case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel(): @@ -106,16 +166,20 @@ class StarEventCtx: case _: return channel + async def _get_message(self) -> discord.Message: + event_channel = await self._get_channel(self.channel_id) + return await event_channel.fetch_message(self.message_id) + async def _on(self) -> None: - star_channel, event_channel = await asyncio.gather( - self.get_channel(self.star_channel_id), self.get_channel(self.channel_id) + star_channel, message = await asyncio.gather( + self._get_channel(self.star_channel_id), self._get_message() ) - message = await event_channel.fetch_message(self.message_id) await StarMessageCtx(message, star_channel, self.count).on() async def on(self) -> None: - async with self.locks.lock(self.message_id): + async with self.states.reserve(self.message_id) as state, state as guard, guard: await self._on() + await asyncio.sleep(10) class ReactionCtx: