temporary reservation system
This commit is contained in:
parent
7f1c025d57
commit
58ce900757
@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import Hashable
|
from collections.abc import Hashable
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, TypeVar
|
from typing import Any, Callable, TypeVar
|
||||||
from weakref import WeakValueDictionary
|
from weakref import ref
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
@ -14,18 +15,81 @@ from .db import AbstractConnection, AbstractDbFactory
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Temporaries:
|
class Reserved:
|
||||||
def __init__(self) -> None:
|
def __init__(self, value: Any, empty: asyncio.Future) -> None:
|
||||||
self.__values: WeakValueDictionary[Hashable, Any] = WeakValueDictionary()
|
self.value = value
|
||||||
|
self.__empty = empty
|
||||||
|
|
||||||
def get(self, key: Hashable, factory: Callable[[], T]) -> T | Any:
|
async def empty(self) -> None:
|
||||||
return self.__values.get(key) or self.__values.setdefault(key, factory())
|
await self.__empty
|
||||||
|
|
||||||
|
|
||||||
|
class ReservedContainer:
|
||||||
|
def __init__(self, value: Any, empty: asyncio.Future) -> None:
|
||||||
|
self.reserved = Reserved(value, empty)
|
||||||
|
self.empty = empty
|
||||||
|
self.rc = 0
|
||||||
|
|
||||||
|
|
||||||
|
class Reservation:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
container: ref[dict[Hashable, Reservation]],
|
||||||
|
key: Hashable,
|
||||||
|
loop: asyncio.AbstractEventLoop,
|
||||||
|
factory: Callable[[], Any],
|
||||||
|
) -> None:
|
||||||
|
self.__container = container
|
||||||
|
self.__key = key
|
||||||
|
self.__rc: ReservedContainer | None = None
|
||||||
|
self.__create_future = loop.create_future
|
||||||
|
self.__factory = factory
|
||||||
|
|
||||||
|
def __enter__(self) -> Reserved:
|
||||||
|
container = self.__container()
|
||||||
|
if container is None:
|
||||||
|
raise RuntimeError("can't reserve in a non-existent container")
|
||||||
|
if self.__rc is None:
|
||||||
|
other = container.get(self.__key)
|
||||||
|
if other is not None:
|
||||||
|
assert other.__rc
|
||||||
|
self.__rc = other.__rc
|
||||||
|
else:
|
||||||
|
self.__rc = ReservedContainer(self.__factory(), self.__create_future())
|
||||||
|
container[self.__key] = self
|
||||||
|
rc = self.__rc
|
||||||
|
rc.rc += 1
|
||||||
|
return rc.reserved
|
||||||
|
|
||||||
|
def __exit__(self, et, ev, tb, /):
|
||||||
|
rc = self.__rc
|
||||||
|
assert rc
|
||||||
|
assert rc.rc > 0
|
||||||
|
rc.rc -= 1
|
||||||
|
if rc.rc == 0:
|
||||||
|
rc.empty.set_result(None)
|
||||||
|
container = self.__container()
|
||||||
|
if container is not None:
|
||||||
|
del container[self.__key]
|
||||||
|
|
||||||
|
|
||||||
|
class RefableDict(dict):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Reservations:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.__container: dict[Hashable, Reservation] = RefableDict()
|
||||||
|
self.__loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
def reserve(self, key: Hashable, factory: Callable[[], Any]) -> Reservation:
|
||||||
|
return Reservation(ref(self.__container), key, self.__loop, factory)
|
||||||
|
|
||||||
|
|
||||||
class StarState:
|
class StarState:
|
||||||
def __init__(self, connection: AbstractConnection) -> None:
|
def __init__(self, connection: AbstractConnection) -> None:
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.temporaries = Temporaries()
|
self.reservations = Reservations()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
@ -1,35 +1,95 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Callable, Generic, Hashable, Type, TypeVar
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncContextManager, Callable, Generic, Hashable, Type, TypeVar
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from .bot import StarBot, StarState, Temporaries
|
from .bot import Reservations, StarBot, StarState
|
||||||
from .db import AbstractConnection
|
from .db import AbstractConnection
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class TypedTemporaries(Generic[T]):
|
class TypedReservations(Generic[T]):
|
||||||
def __init__(self, temporaries: Temporaries, type_: Type[T]) -> None:
|
def __init__(self, reservations: Reservations, type_: Type[T], factory: Callable[[], T]) -> None:
|
||||||
self.temporaries = temporaries
|
self.reservations = reservations
|
||||||
self.type = type_
|
self.type = type_
|
||||||
|
self.factory = factory
|
||||||
|
|
||||||
def get(self, key: Hashable, factory: Callable[[], T]) -> T:
|
def reserve(self, key: Hashable) -> AsyncContextManager[T]:
|
||||||
value = self.temporaries.get((self.type, key), factory)
|
@asynccontextmanager
|
||||||
if not isinstance(value, self.type):
|
async def inner():
|
||||||
raise TypeError(self.type, value)
|
factory_called = False
|
||||||
return value
|
|
||||||
|
def factory():
|
||||||
|
nonlocal factory_called
|
||||||
|
factory_called = True
|
||||||
|
return self.factory()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
with self.reservations.reserve(key, factory) as reserved:
|
||||||
|
if not isinstance(reserved.value, self.type):
|
||||||
|
if factory_called:
|
||||||
|
raise TypeError("factory seems to have returned a value of incorrect type")
|
||||||
|
else:
|
||||||
|
yield reserved.value
|
||||||
|
break
|
||||||
|
await reserved.empty()
|
||||||
|
|
||||||
|
return inner()
|
||||||
|
|
||||||
|
|
||||||
class Locks:
|
class Increment:
|
||||||
def __init__(self, state: StarState) -> None:
|
def __init__(self) -> None:
|
||||||
self.typed = TypedTemporaries(state.temporaries, asyncio.Lock)
|
self.__current = 0
|
||||||
|
|
||||||
def lock(self, key: Hashable) -> asyncio.Lock:
|
def next(self) -> int:
|
||||||
return self.typed.get(key, asyncio.Lock)
|
self.__current += 1
|
||||||
|
return self.__current
|
||||||
|
|
||||||
|
|
||||||
|
class AlreadyProcessed(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MessageState:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.increment = Increment()
|
||||||
|
self.current = self.increment.next()
|
||||||
|
self.__lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def _pull(self):
|
||||||
|
return self.increment.next()
|
||||||
|
|
||||||
|
def _push(self, token: int):
|
||||||
|
self.current = token
|
||||||
|
|
||||||
|
def _before(self, token: int):
|
||||||
|
return self.current < token
|
||||||
|
|
||||||
|
async def __aenter__(self) -> AsyncContextManager[None]:
|
||||||
|
@asynccontextmanager
|
||||||
|
async def inner():
|
||||||
|
token = self._pull()
|
||||||
|
async with self.__lock:
|
||||||
|
if self._before(token):
|
||||||
|
to_push = self._pull()
|
||||||
|
yield
|
||||||
|
self._push(to_push)
|
||||||
|
else:
|
||||||
|
raise AlreadyProcessed()
|
||||||
|
|
||||||
|
return inner()
|
||||||
|
|
||||||
|
async def __aexit__(self, et, ev, tb, /):
|
||||||
|
return et and issubclass(et, AlreadyProcessed)
|
||||||
|
|
||||||
|
|
||||||
|
def states(state: StarState):
|
||||||
|
return TypedReservations(state.reservations, MessageState, MessageState)
|
||||||
|
|
||||||
|
|
||||||
class AdminCtx:
|
class AdminCtx:
|
||||||
@ -92,13 +152,13 @@ class StarMessageCtx:
|
|||||||
class StarEventCtx:
|
class StarEventCtx:
|
||||||
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
|
def __init__(self, reaction: ReactionCtx, star_channel_id: int, count: int) -> None:
|
||||||
self.bot = reaction.bot
|
self.bot = reaction.bot
|
||||||
self.locks = Locks(reaction.state)
|
self.states = states(reaction.state)
|
||||||
self.star_channel_id = star_channel_id
|
self.star_channel_id = star_channel_id
|
||||||
self.count = count
|
self.count = count
|
||||||
self.channel_id = reaction.channel_id
|
self.channel_id = reaction.channel_id
|
||||||
self.message_id = reaction.message_id
|
self.message_id = reaction.message_id
|
||||||
|
|
||||||
async def get_channel(self, id_: int) -> discord.abc.Messageable:
|
async def _get_channel(self, id_: int) -> discord.abc.Messageable:
|
||||||
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
|
channel = self.bot.get_channel(id_) or await self.bot.fetch_channel(id_)
|
||||||
match channel:
|
match channel:
|
||||||
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
|
case discord.CategoryChannel() | discord.ForumChannel() | discord.abc.PrivateChannel():
|
||||||
@ -106,16 +166,20 @@ class StarEventCtx:
|
|||||||
case _:
|
case _:
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
|
async def _get_message(self) -> discord.Message:
|
||||||
|
event_channel = await self._get_channel(self.channel_id)
|
||||||
|
return await event_channel.fetch_message(self.message_id)
|
||||||
|
|
||||||
async def _on(self) -> None:
|
async def _on(self) -> None:
|
||||||
star_channel, event_channel = await asyncio.gather(
|
star_channel, message = await asyncio.gather(
|
||||||
self.get_channel(self.star_channel_id), self.get_channel(self.channel_id)
|
self._get_channel(self.star_channel_id), self._get_message()
|
||||||
)
|
)
|
||||||
message = await event_channel.fetch_message(self.message_id)
|
|
||||||
await StarMessageCtx(message, star_channel, self.count).on()
|
await StarMessageCtx(message, star_channel, self.count).on()
|
||||||
|
|
||||||
async def on(self) -> None:
|
async def on(self) -> None:
|
||||||
async with self.locks.lock(self.message_id):
|
async with self.states.reserve(self.message_id) as state, state as guard, guard:
|
||||||
await self._on()
|
await self._on()
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
class ReactionCtx:
|
class ReactionCtx:
|
||||||
|
Loading…
Reference in New Issue
Block a user