temporary reservation system

This commit is contained in:
AF 2023-08-25 01:52:48 +00:00
parent 7f1c025d57
commit 58ce900757
2 changed files with 156 additions and 28 deletions

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Hashable from collections.abc import Hashable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Callable, TypeVar from typing import Any, Callable, TypeVar
from weakref import WeakValueDictionary from weakref import ref
import discord import discord
from discord.ext import commands from discord.ext import commands
@ -14,18 +15,81 @@ from .db import AbstractConnection, AbstractDbFactory
T = TypeVar("T") T = TypeVar("T")
class Temporaries: class Reserved:
def __init__(self) -> None: def __init__(self, value: Any, empty: asyncio.Future) -> None:
self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary() self.value = value
self.__empty = empty
def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any: async def empty(self) -> None:
return self.__values.get(key) or self.__values.setdefault(key, factory()) await self.__empty
class ReservedContainer:
def __init__(self, value: Any, empty: asyncio.Future) -> None:
self.reserved = Reserved(value, empty)
self.empty = empty
self.rc = 0
class Reservation:
def __init__(
self,
container: ref[dict[Hashable, Reservation]],
key: Hashable,
loop: asyncio.AbstractEventLoop,
factory: Callable[[], Any],
) -> None:
self.__container = container
self.__key = key
self.__rc: ReservedContainer | None = None
self.__create_future = loop.create_future
self.__factory = factory
def __enter__(self) -> Reserved:
container = self.__container()
if container is None:
raise RuntimeError("can't reserve in a non-existent container")
if self.__rc is None:
other = container.get(self.__key)
if other is not None:
assert other.__rc
self.__rc = other.__rc
else:
self.__rc = ReservedContainer(self.__factory(), self.__create_future())
container[self.__key] = self
rc = self.__rc
rc.rc += 1
return rc.reserved
def __exit__(self, et, ev, tb, /):
rc = self.__rc
assert rc
assert rc.rc > 0
rc.rc -= 1
if rc.rc == 0:
rc.empty.set_result(None)
container = self.__container()
if container is not None:
del container[self.__key]
class RefableDict(dict):
pass
class Reservations:
def __init__(self) -> None:
self.__container: dict[Hashable, Reservation] = RefableDict()
self.__loop = asyncio.get_running_loop()
def reserve(self, key: Hashable, factory: Callable[[], Any]) -> Reservation:
return Reservation(ref(self.__container), key, self.__loop, factory)
class StarState: class StarState:
def __init__(self, connection: AbstractConnection) -> None: def __init__(self, connection: AbstractConnection) -> None:
self.connection = connection self.connection = connection
self.temporaries = Temporaries() self.reservations = Reservations()
@asynccontextmanager @asynccontextmanager

View File

@ -1,35 +1,95 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import Callable, Generic, Hashable, Type, TypeVar from contextlib import asynccontextmanager
from typing import AsyncContextManager, Callable, Generic, Hashable, Type, TypeVar
import discord import discord
from discord.ext import commands from discord.ext import commands
from .bot import StarBot, StarState, Temporaries from .bot import Reservations, StarBot, StarState
from .db import AbstractConnection from .db import AbstractConnection
T = TypeVar("T") T = TypeVar("T")
class TypedTemporaries(Generic[T]): class TypedReservations(Generic[T]):
def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None: def __init__(self, reservations: Reservations, type_: Type[T], factory: Callable[[], T]) -> None:
self.temporaries = temporaries self.reservations = reservations
self.type = type_ self.type = type_
self.factory = factory
def get(self, key: Hashable, factory: Callable[[], T]) -> T: def reserve(self, key: Hashable) -> AsyncContextManager[T]:
value = self.temporaries.get((self.type, key), factory) @asynccontextmanager
if not isinstance(value, self.type): async def inner():
raise TypeError(self.type, value) factory_called = False
return value
def factory():
nonlocal factory_called
factory_called = True
return self.factory()
while True:
with self.reservations.reserve(key, factory) as reserved:
if not isinstance(reserved.value, self.type):
if factory_called:
raise TypeError("factory seems to have returned a value of incorrect type")
else:
yield reserved.value
break
await reserved.empty()
return inner()
class Locks: class Increment:
def __init__(self, state: StarState) -> None: def __init__(self) -> None:
self.typed = TypedTemporaries(state.temporaries, asyncio.Lock) self.__current = 0
def lock(self, key: Hashable) -> asyncio.Lock: def next(self) -> int:
return self.typed.get(key, asyncio.Lock) self.__current += 1
return self.__current
class AlreadyProcessed(Exception):
pass
class MessageState:
def __init__(self) -> None:
self.increment = Increment()
self.current = self.increment.next()
self.__lock = asyncio.Lock()
def _pull(self):
return self.increment.next()
def _push(self, token: int):
self.current = token
def _before(self, token: int):
return self.current < token
async def __aenter__(self) -> AsyncContextManager[None]:
@asynccontextmanager
async def inner():
token = self._pull()
async with self.__lock:
if self._before(token):
to_push = self._pull()
yield
self._push(to_push)
else:
raise AlreadyProcessed()
return inner()
async def __aexit__(self, et, ev, tb, /):
return et and issubclass(et, AlreadyProcessed)
def states(state: StarState):
return TypedReservations(state.reservations, MessageState, MessageState)
class AdminCtx: class AdminCtx:
@ -92,13 +152,13 @@ class StarMessageCtx:
class StarEventCtx: class StarEventCtx:
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None: def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
self.bot = reaction.bot self.bot = reaction.bot
self.locks = Locks(reaction.state) self.states = states(reaction.state)
self.star_channel_id = star_channel_id self.star_channel_id = star_channel_id
self.count = count self.count = count
self.channel_id = reaction.channel_id self.channel_id = reaction.channel_id
self.message_id = reaction.message_id self.message_id = reaction.message_id
async def get_channel(self, id_: int) -> discord.abc.Messageable: async def _get_channel(self, id_: int) -> discord.abc.Messageable:
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_) channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
match channel: match channel:
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel(): case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
@ -106,16 +166,20 @@ class StarEventCtx:
case _: case _:
return channel return channel
async def _get_message(self) -> discord.Message:
event_channel = await self._get_channel(self.channel_id)
return await event_channel.fetch_message(self.message_id)
async def _on(self) -> None: async def _on(self) -> None:
star_channel, event_channel = await asyncio.gather( star_channel, message = await asyncio.gather(
self.get_channel(self.star_channel_id), self.get_channel(self.channel_id) self._get_channel(self.star_channel_id), self._get_message()
) )
message = await event_channel.fetch_message(self.message_id)
await StarMessageCtx(message, star_channel, self.count).on() await StarMessageCtx(message, star_channel, self.count).on()
async def on(self) -> None: async def on(self) -> None:
async with self.locks.lock(self.message_id): async with self.states.reserve(self.message_id) as state, state as guard, guard:
await self._on() await self._on()
await asyncio.sleep(10)
class ReactionCtx: class ReactionCtx: