sessionservice/app/cacheddictionary.py
2023-03-01 19:01:35 +00:00

42 lines
1.3 KiB
Python

# Names exported by ``from <module> import *``: the two cache classes defined below.
__all__ = ("CachedEntry", "CachedDictionary")
import asyncio
import functools
from typing import Any, Callable, Coroutine, Generic, Hashable, TypeVar
# T is the type of a cached value; TKey is the (hashable) type of a cache key.
T = TypeVar("T")
TKey = TypeVar("TKey", bound=Hashable)


class CachedEntry(Generic[T]):
    """Hold a value together with the coroutine function that refreshes it.

    The entry starts out with the caller-supplied ``value``.  Every read
    (``get_nowait`` or ``get``) starts a refresh through ``getter`` unless a
    refresh is already running, so at most one refresh task exists per entry
    at any time.
    """

    def __init__(self, value: T, getter: Callable[[], Coroutine[Any, Any, T]]) -> None:
        """Store the initial value and the zero-argument async ``getter``.

        No asyncio object is created here, so an entry can be constructed
        before any event loop is running.
        """
        self.__value: T = value
        self.__getter = getter
        # ``None`` means "no refresh started yet".  A pre-resolved
        # ``asyncio.Future()`` would attach to whatever loop asyncio considers
        # current at construction time, and was never awaited anyway.
        self.__task: "asyncio.Task[T] | None" = None

    async def _set(self) -> T:
        """Await the getter, store its result as the cached value, return it."""
        self.__value = await self.__getter()
        return self.__value

    def _ensure_refresh(self) -> "asyncio.Task[T]":
        """Return the running refresh task, starting a new one if none is running.

        Needs a running event loop whenever a new task has to be created
        (``asyncio.create_task`` raises ``RuntimeError`` otherwise).
        """
        task = self.__task
        if task is None or task.done():
            task = self.__task = asyncio.create_task(self._set())
        return task

    def get_nowait(self) -> T:
        """Return the currently cached value without waiting.

        A refresh is started in the background if none is running; the value
        returned is the one cached *before* that refresh completes.
        """
        self._ensure_refresh()
        return self.__value

    async def get(self) -> T:
        """Return the value produced by the current (or a newly started) refresh.

        Raises whatever the getter raises.  The shared refresh task is
        shielded: cancelling one caller of ``get`` (for example through a
        timeout) must not cancel the refresh that other callers are awaiting.
        """
        return await asyncio.shield(self._ensure_refresh())
class CachedDictionary(Generic[TKey, T]):
    """Keyed collection of ``CachedEntry`` objects created on first access.

    ``factory`` is an async callable taking a key; each entry refreshes itself
    by calling ``factory`` with the key it was created for.
    """

    def __init__(self, factory: Callable[[TKey], Coroutine[Any, Any, T]]) -> None:
        """Remember ``factory`` and start with no entries."""
        self.__entries: dict[TKey, CachedEntry[T]] = {}
        self.__factory = factory

    def entry(self, key: TKey, default: T) -> CachedEntry[T]:
        """Return the entry for ``key``, creating it on first request.

        ``default`` seeds a newly created entry and is ignored when an entry
        for ``key`` already exists.
        """
        existing = self.__entries.get(key)
        if existing is not None:
            return existing
        # Bind the key now so the entry's getter needs no arguments.
        refresher = functools.partial(self.__factory, key)
        created = CachedEntry(default, refresher)
        self.__entries[key] = created
        return created