temporary reservation system
This commit is contained in:
parent
7f1c025d57
commit
58ce900757
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Hashable
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TypeVar
|
||||
from weakref import WeakValueDictionary
|
||||
from weakref import ref
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
@ -14,18 +15,81 @@ from .db import AbstractConnection, AbstractDbFactory
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Temporaries:
|
||||
def __init__(self) -> None:
|
||||
self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary()
|
||||
class Reserved:
|
||||
def __init__(self, value: Any, empty: asyncio.Future) -> None:
|
||||
self.value = value
|
||||
self.__empty = empty
|
||||
|
||||
def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any:
|
||||
return self.__values.get(key) or self.__values.setdefault(key, factory())
|
||||
async def empty(self) -> None:
|
||||
await self.__empty
|
||||
|
||||
|
||||
class ReservedContainer:
|
||||
def __init__(self, value: Any, empty: asyncio.Future) -> None:
|
||||
self.reserved = Reserved(value, empty)
|
||||
self.empty = empty
|
||||
self.rc = 0
|
||||
|
||||
|
||||
class Reservation:
|
||||
def __init__(
|
||||
self,
|
||||
container: ref[dict[Hashable, Reservation]],
|
||||
key: Hashable,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
factory: Callable[[], Any],
|
||||
) -> None:
|
||||
self.__container = container
|
||||
self.__key = key
|
||||
self.__rc: ReservedContainer | None = None
|
||||
self.__create_future = loop.create_future
|
||||
self.__factory = factory
|
||||
|
||||
def __enter__(self) -> Reserved:
|
||||
container = self.__container()
|
||||
if container is None:
|
||||
raise RuntimeError("can't reserve in a non-existent container")
|
||||
if self.__rc is None:
|
||||
other = container.get(self.__key)
|
||||
if other is not None:
|
||||
assert other.__rc
|
||||
self.__rc = other.__rc
|
||||
else:
|
||||
self.__rc = ReservedContainer(self.__factory(), self.__create_future())
|
||||
container[self.__key] = self
|
||||
rc = self.__rc
|
||||
rc.rc += 1
|
||||
return rc.reserved
|
||||
|
||||
def __exit__(self, et, ev, tb, /):
|
||||
rc = self.__rc
|
||||
assert rc
|
||||
assert rc.rc > 0
|
||||
rc.rc -= 1
|
||||
if rc.rc == 0:
|
||||
rc.empty.set_result(None)
|
||||
container = self.__container()
|
||||
if container is not None:
|
||||
del container[self.__key]
|
||||
|
||||
|
||||
class RefableDict(dict):
|
||||
pass
|
||||
|
||||
|
||||
class Reservations:
|
||||
def __init__(self) -> None:
|
||||
self.__container: dict[Hashable, Reservation] = RefableDict()
|
||||
self.__loop = asyncio.get_running_loop()
|
||||
|
||||
def reserve(self, key: Hashable, factory: Callable[[], Any]) -> Reservation:
|
||||
return Reservation(ref(self.__container), key, self.__loop, factory)
|
||||
|
||||
|
||||
class StarState:
|
||||
def __init__(self, connection: AbstractConnection) -> None:
|
||||
self.connection = connection
|
||||
self.temporaries = Temporaries()
|
||||
self.reservations = Reservations()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
@ -1,35 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Generic, Hashable, Type, TypeVar
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncContextManager, Callable, Generic, Hashable, Type, TypeVar
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from .bot import StarBot, StarState, Temporaries
|
||||
from .bot import Reservations, StarBot, StarState
|
||||
from .db import AbstractConnection
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TypedTemporaries(Generic[T]):
|
||||
def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None:
|
||||
self.temporaries = temporaries
|
||||
class TypedReservations(Generic[T]):
|
||||
def __init__(self, reservations: Reservations, type_: Type[T], factory: Callable[[], T]) -> None:
|
||||
self.reservations = reservations
|
||||
self.type = type_
|
||||
self.factory = factory
|
||||
|
||||
def get(self, key: Hashable, factory: Callable[[], T]) -> T:
|
||||
value = self.temporaries.get((self.type, key), factory)
|
||||
if not isinstance(value, self.type):
|
||||
raise TypeError(self.type, value)
|
||||
return value
|
||||
def reserve(self, key: Hashable) -> AsyncContextManager[T]:
|
||||
@asynccontextmanager
|
||||
async def inner():
|
||||
factory_called = False
|
||||
|
||||
def factory():
|
||||
nonlocal factory_called
|
||||
factory_called = True
|
||||
return self.factory()
|
||||
|
||||
while True:
|
||||
with self.reservations.reserve(key, factory) as reserved:
|
||||
if not isinstance(reserved.value, self.type):
|
||||
if factory_called:
|
||||
raise TypeError("factory seems to have returned a value of incorrect type")
|
||||
else:
|
||||
yield reserved.value
|
||||
break
|
||||
await reserved.empty()
|
||||
|
||||
return inner()
|
||||
|
||||
|
||||
class Locks:
|
||||
def __init__(self, state: StarState) -> None:
|
||||
self.typed = TypedTemporaries(state.temporaries, asyncio.Lock)
|
||||
class Increment:
|
||||
def __init__(self) -> None:
|
||||
self.__current = 0
|
||||
|
||||
def lock(self, key: Hashable) -> asyncio.Lock:
|
||||
return self.typed.get(key, asyncio.Lock)
|
||||
def next(self) -> int:
|
||||
self.__current += 1
|
||||
return self.__current
|
||||
|
||||
|
||||
class AlreadyProcessed(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MessageState:
|
||||
def __init__(self) -> None:
|
||||
self.increment = Increment()
|
||||
self.current = self.increment.next()
|
||||
self.__lock = asyncio.Lock()
|
||||
|
||||
def _pull(self):
|
||||
return self.increment.next()
|
||||
|
||||
def _push(self, token: int):
|
||||
self.current = token
|
||||
|
||||
def _before(self, token: int):
|
||||
return self.current < token
|
||||
|
||||
async def __aenter__(self) -> AsyncContextManager[None]:
|
||||
@asynccontextmanager
|
||||
async def inner():
|
||||
token = self._pull()
|
||||
async with self.__lock:
|
||||
if self._before(token):
|
||||
to_push = self._pull()
|
||||
yield
|
||||
self._push(to_push)
|
||||
else:
|
||||
raise AlreadyProcessed()
|
||||
|
||||
return inner()
|
||||
|
||||
async def __aexit__(self, et, ev, tb, /):
|
||||
return et and issubclass(et, AlreadyProcessed)
|
||||
|
||||
|
||||
def states(state: StarState):
|
||||
return TypedReservations(state.reservations, MessageState, MessageState)
|
||||
|
||||
|
||||
class AdminCtx:
|
||||
@ -92,13 +152,13 @@ class StarMessageCtx:
|
||||
class StarEventCtx:
|
||||
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
|
||||
self.bot = reaction.bot
|
||||
self.locks = Locks(reaction.state)
|
||||
self.states = states(reaction.state)
|
||||
self.star_channel_id = star_channel_id
|
||||
self.count = count
|
||||
self.channel_id = reaction.channel_id
|
||||
self.message_id = reaction.message_id
|
||||
|
||||
async def get_channel(self, id_: int) -> discord.abc.Messageable:
|
||||
async def _get_channel(self, id_: int) -> discord.abc.Messageable:
|
||||
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
|
||||
match channel:
|
||||
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
|
||||
@ -106,16 +166,20 @@ class StarEventCtx:
|
||||
case _:
|
||||
return channel
|
||||
|
||||
async def _get_message(self) -> discord.Message:
|
||||
event_channel = await self._get_channel(self.channel_id)
|
||||
return await event_channel.fetch_message(self.message_id)
|
||||
|
||||
async def _on(self) -> None:
|
||||
star_channel, event_channel = await asyncio.gather(
|
||||
self.get_channel(self.star_channel_id), self.get_channel(self.channel_id)
|
||||
star_channel, message = await asyncio.gather(
|
||||
self._get_channel(self.star_channel_id), self._get_message()
|
||||
)
|
||||
message = await event_channel.fetch_message(self.message_id)
|
||||
await StarMessageCtx(message, star_channel, self.count).on()
|
||||
|
||||
async def on(self) -> None:
|
||||
async with self.locks.lock(self.message_id):
|
||||
async with self.states.reserve(self.message_id) as state, state as guard, guard:
|
||||
await self._on()
|
||||
await asyncio.sleep(10)
|
||||
|
||||
|
||||
class ReactionCtx:
|
||||
|
Loading…
Reference in New Issue
Block a user