42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
# Public API of this module: the single-value async cache entry and the
# keyed container that creates such entries on demand.
__all__ = (
    "CachedEntry",
    "CachedDictionary",
)
|
|
|
|
import asyncio
|
|
import functools
|
|
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
|
|
|
|
T = TypeVar("T")
TKey = TypeVar("TKey", bound=Hashable)


class CachedEntry(Generic[T]):
    """A single cached value that is refreshed asynchronously by a getter.

    The entry starts out holding ``value``. At most one refresh (one call of
    ``getter``) is in flight at any time; concurrent callers share it.
    """

    def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
        """Create an entry that holds ``value`` and refreshes via ``getter``.

        Args:
            value: Value reported by :meth:`get_nowait` until a refresh
                completes.
            getter: Zero-argument coroutine function that produces a fresh
                value.

        No event loop is used here, so entries can be built outside a running
        loop. The first refresh is started lazily by :meth:`get_nowait` or
        :meth:`get`.
        """
        self.__value: T = value
        self.__getter = getter
        # In-flight (or most recently finished) refresh task; ``None`` until
        # the first refresh is started.
        self.__task: "asyncio.Task[T] | None" = None

    def _refresh_task(self) -> "asyncio.Task[T]":
        """Return the in-flight refresh task, starting a new one if idle.

        Must be called with a running event loop, because
        ``asyncio.create_task`` raises ``RuntimeError`` otherwise.
        """
        task = self.__task
        if task is None or task.done():
            # Keep a strong reference on ``self`` so the task is not
            # garbage-collected while it is still running.
            task = self.__task = asyncio.create_task(self._set())
        return task

    async def _set(self) -> T:
        """Fetch a fresh value from the getter, store it, and return it.

        If the getter raises, the stored value is left unchanged and the
        exception propagates to whoever awaits the refresh task.
        """
        self.__value = await self.__getter()
        return self.__value

    def get_nowait(self) -> T:
        """Return the currently cached value without waiting.

        Starts a background refresh if none is in flight, so a later call can
        observe a newer value. Requires a running event loop.
        """
        self._refresh_task()
        return self.__value

    async def get(self) -> T:
        """Return a freshly fetched value.

        Joins the in-flight refresh if there is one; otherwise starts a new
        refresh and waits for it. Exceptions raised by the getter propagate.
        """
        # Shield the shared task: cancelling this caller must not cancel the
        # refresh that other concurrent callers are also waiting on.
        return await asyncio.shield(self._refresh_task())
|
|
|
|
|
|
class CachedDictionary(Generic[TKey, T]):
    """Lazily populated mapping from keys to ``CachedEntry`` objects.

    Each entry refreshes itself by awaiting ``factory(key)``.
    """

    def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
        """Remember ``factory``; entries are only created by :meth:`entry`."""
        self.__entries: dict[TKey, CachedEntry[T]] = {}
        self.__factory = factory

    def entry(self, key: TKey, default: T) -> CachedEntry[T]:
        """Return the entry for ``key``, creating it on first use.

        ``default`` seeds only a newly created entry; it is ignored when an
        entry for ``key`` already exists.
        """
        existing = self.__entries.get(key)
        if existing is not None:
            return existing
        # Bind the key now so the entry's getter takes no arguments.
        refresher = functools.partial(self.__factory, key)
        created = CachedEntry(default, refresher)
        self.__entries[key] = created
        return created
|