temporary reservation system

This commit is contained in:
AF 2023-08-25 01:52:48 +00:00
parent 7f1c025d57
commit 58ce900757
2 changed files with 156 additions and 28 deletions

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import asyncio
from collections.abc import Hashable
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Callable, TypeVar
from weakref import WeakValueDictionary
from weakref import ref
import discord
from discord.ext import commands
@ -14,18 +15,81 @@ from .db import AbstractConnection, AbstractDbFactory
T = TypeVar("T")
class Temporaries:
def __init__(self) -> None:
self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary()
class Reserved:
def __init__(self, value: Any, empty: asyncio.Future) -> None:
self.value = value
self.__empty = empty
def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any:
return self.__values.get(key) or self.__values.setdefault(key, factory())
async def empty(self) -> None:
await self.__empty
class ReservedContainer:
def __init__(self, value: Any, empty: asyncio.Future) -> None:
self.reserved = Reserved(value, empty)
self.empty = empty
self.rc = 0
class Reservation:
def __init__(
self,
container: ref[dict[Hashable, Reservation]],
key: Hashable,
loop: asyncio.AbstractEventLoop,
factory: Callable[[], Any],
) -> None:
self.__container = container
self.__key = key
self.__rc: ReservedContainer | None = None
self.__create_future = loop.create_future
self.__factory = factory
def __enter__(self) -> Reserved:
container = self.__container()
if container is None:
raise RuntimeError("can't reserve in a non-existent container")
if self.__rc is None:
other = container.get(self.__key)
if other is not None:
assert other.__rc
self.__rc = other.__rc
else:
self.__rc = ReservedContainer(self.__factory(), self.__create_future())
container[self.__key] = self
rc = self.__rc
rc.rc += 1
return rc.reserved
def __exit__(self, et, ev, tb, /):
rc = self.__rc
assert rc
assert rc.rc > 0
rc.rc -= 1
if rc.rc == 0:
rc.empty.set_result(None)
container = self.__container()
if container is not None:
del container[self.__key]
class RefableDict(dict):
pass
class Reservations:
def __init__(self) -> None:
self.__container: dict[Hashable, Reservation] = RefableDict()
self.__loop = asyncio.get_running_loop()
def reserve(self, key: Hashable, factory: Callable[[], Any]) -> Reservation:
return Reservation(ref(self.__container), key, self.__loop, factory)
class StarState:
def __init__(self, connection: AbstractConnection) -> None:
self.connection = connection
self.temporaries = Temporaries()
self.reservations = Reservations()
@asynccontextmanager

View File

@ -1,35 +1,95 @@
from __future__ import annotations
import asyncio
from typing import Callable, Generic, Hashable, Type, TypeVar
from contextlib import asynccontextmanager
from typing import AsyncContextManager, Callable, Generic, Hashable, Type, TypeVar
import discord
from discord.ext import commands
from .bot import StarBot, StarState, Temporaries
from .bot import Reservations, StarBot, StarState
from .db import AbstractConnection
T = TypeVar("T")
class TypedTemporaries(Generic[T]):
def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None:
self.temporaries = temporaries
class TypedReservations(Generic[T]):
def __init__(self, reservations: Reservations, type_: Type[T], factory: Callable[[], T]) -> None:
self.reservations = reservations
self.type = type_
self.factory = factory
def get(self, key: Hashable, factory: Callable[[], T]) -> T:
value = self.temporaries.get((self.type, key), factory)
if not isinstance(value, self.type):
raise TypeError(self.type, value)
return value
def reserve(self, key: Hashable) -> AsyncContextManager[T]:
@asynccontextmanager
async def inner():
factory_called = False
def factory():
nonlocal factory_called
factory_called = True
return self.factory()
while True:
with self.reservations.reserve(key, factory) as reserved:
if not isinstance(reserved.value, self.type):
if factory_called:
raise TypeError("factory seems to have returned a value of incorrect type")
else:
yield reserved.value
break
await reserved.empty()
return inner()
class Locks:
def __init__(self, state: StarState) -> None:
self.typed = TypedTemporaries(state.temporaries, asyncio.Lock)
class Increment:
def __init__(self) -> None:
self.__current = 0
def lock(self, key: Hashable) -> asyncio.Lock:
return self.typed.get(key, asyncio.Lock)
def next(self) -> int:
self.__current += 1
return self.__current
class AlreadyProcessed(Exception):
pass
class MessageState:
def __init__(self) -> None:
self.increment = Increment()
self.current = self.increment.next()
self.__lock = asyncio.Lock()
def _pull(self):
return self.increment.next()
def _push(self, token: int):
self.current = token
def _before(self, token: int):
return self.current < token
async def __aenter__(self) -> AsyncContextManager[None]:
@asynccontextmanager
async def inner():
token = self._pull()
async with self.__lock:
if self._before(token):
to_push = self._pull()
yield
self._push(to_push)
else:
raise AlreadyProcessed()
return inner()
async def __aexit__(self, et, ev, tb, /):
return et and issubclass(et, AlreadyProcessed)
def states(state: StarState):
return TypedReservations(state.reservations, MessageState, MessageState)
class AdminCtx:
@ -92,13 +152,13 @@ class StarMessageCtx:
class StarEventCtx:
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
self.bot = reaction.bot
self.locks = Locks(reaction.state)
self.states = states(reaction.state)
self.star_channel_id = star_channel_id
self.count = count
self.channel_id = reaction.channel_id
self.message_id = reaction.message_id
async def get_channel(self, id_: int) -> discord.abc.Messageable:
async def _get_channel(self, id_: int) -> discord.abc.Messageable:
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
match channel:
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
@ -106,16 +166,20 @@ class StarEventCtx:
case _:
return channel
async def _get_message(self) -> discord.Message:
event_channel = await self._get_channel(self.channel_id)
return await event_channel.fetch_message(self.message_id)
async def _on(self) -> None:
star_channel, event_channel = await asyncio.gather(
self.get_channel(self.star_channel_id), self.get_channel(self.channel_id)
star_channel, message = await asyncio.gather(
self._get_channel(self.star_channel_id), self._get_message()
)
message = await event_channel.fetch_message(self.message_id)
await StarMessageCtx(message, star_channel, self.count).on()
async def on(self) -> None:
async with self.locks.lock(self.message_id):
async with self.states.reserve(self.message_id) as state, state as guard, guard:
await self._on()
await asyncio.sleep(10)
class ReactionCtx: