ptvp35/ptvp35/__init__.py
2023-08-28 09:56:59 +00:00

1202 lines
36 KiB
Python

# Licensed under MIT License. Copyright: 2021-2023 Alisa Feistel, PARRRATE TNV.
from __future__ import annotations
# public names exported by `from ptvp35 import *`.
# DbConnection and DbFactory are type aliases (of DbInterface and DbManager respectively).
__all__ = (
    "VDELETE",
    "KVProtocol",
    "KVFactory",
    "KVJson",
    "VirtualConnection",
    "ExtendedVirtualConnection",
    "DbInterface",
    "AbstractDbConnection",
    "DbConnection",
    "DbManager",
    "DbFactory",
    "Db",
    "Transaction",
    "TransactionView",
    "FutureContext",
)
import abc
import asyncio
import concurrent.futures
import json
import os
import pathlib
import threading
from collections.abc import Hashable
from io import StringIO, UnsupportedOperation
from typing import IO, Any, Protocol, TypeAlias
class Request:
    """\
    base class for requests passed through the writer queue.
    wraps an optional asyncio future that is resolved once the request has been processed;
    a request without a future is fire-and-forget."""

    __slots__ = ("__future",)

    def __init__(self, future: asyncio.Future | None, /) -> None:
        self.__future = future

    def waiting(self, /) -> bool:
        """whether anyone can observe the outcome of this request."""
        return self.__future is not None

    def set_result(self, result, /) -> None:
        """resolve the future with `result`; no-op without a future or if it is already done."""
        future = self.__future
        # the awaiting side may have been cancelled (which completes the future);
        # resolving a done future would raise InvalidStateError, and inside the background
        # queue task that error would escape the per-request exception handler.
        if future is not None and not future.done():
            future.set_result(result)

    def set_exception(self, exception, /) -> None:
        """fail the future with `exception`; no-op without a future or if it is already done."""
        future = self.__future
        if future is not None and not future.done():
            future.set_exception(exception)

    async def wait(self, /) -> None:
        """wait for the future (returns immediately when there is none)."""
        if self.__future is not None:
            await self.__future
class LineRequest(Request):
    """request carrying one serialised line that should end up in the database file."""

    __slots__ = ("line",)

    def __init__(self, line: str, /, *, future: asyncio.Future | None) -> None:
        self.line = line
        super().__init__(future)
class KVProtocol(abc.ABC):
    """read-side protocol: how a value is looked up in an in-memory db dict."""

    @abc.abstractmethod
    def dbget(self, db: dict, key: Any, default: Any, /) -> Any:
        """look `key` up in `db`; implementations are expected to fall back to `default`."""
        raise NotImplementedError
# sentinel value: storing it under a key marks that key as deleted.
# KVFactory drops such keys when reducing, and dbget maps a stored VDELETE back to the caller's default.
VDELETE = object()
class KVFactory(KVProtocol):
    """\
    this class is for working with already normalised data values, not for data transformation (e.g. reducing keys to a common form).
    that functionality may be added in the future, though, probably, only for custom DbConnection implementations.
    note: unstable signature."""

    __slots__ = ()

    @abc.abstractmethod
    def line(self, key: Any, value: Any, /) -> str:
        """\
        line must contain exactly one '\\n' at exactly the end if the line is not empty.
        note: other forms of requests will later be represented by different methods or by instances of Action class."""
        raise NotImplementedError

    @abc.abstractmethod
    def fromline(self, line: str, /) -> tuple[Any, Any]:
        """\
        inverse of line().
        note: unstable signature."""
        raise NotImplementedError

    def run(self, line: str, db: dict, reduce: bool, /) -> None:
        """\
        run request against the db.
        extensible to allow forms of requests other than set.
        note: unstable signature."""
        key, value = self.fromline(line)
        self._dbset(db, key, value, reduce)

    def _dbset(self, db: dict, key: Any, value: Any, reduce: bool, /) -> None:
        # with reduce=True a VDELETE value removes the key; with reduce=False the marker itself
        # is stored (this is how buffer compaction keeps deletions until they are written).
        if reduce and value is VDELETE:
            db.pop(key, None)
        else:
            db[key] = value

    def dbset(self, db: dict, key: Any, value: Any, /) -> None:
        """set `key` in `db`, removing it instead when `value` is VDELETE."""
        self._dbset(db, key, value, True)

    def dbget(self, db: dict, key: Any, default: Any, /) -> Any:
        """dict-like get that also maps a stored VDELETE to `default`."""
        value = db.get(key, default)
        return self.filter_value(value, default)

    def filter_value(self, value: Any, default: Any, /) -> Any:
        """return `default` when `value` is the VDELETE marker, else `value`."""
        if value is VDELETE:
            return default
        else:
            return value

    def request(self, key: Any, value: Any, /, *, future: asyncio.Future | None) -> KVRequest:
        """\
        form request with Future.
        low-level API.
        note: unstable signature."""
        return KVRequest(key, value, future=future, factory=self)

    def free(self, key: Any, value: Any, /) -> KVRequest:
        """\
        result free from Future.
        note: unstable signature."""
        return self.request(key, value, future=None)

    def io2db(self, io: IO[str], db: dict, reduce: bool, /) -> int:
        """\
        apply every line of `io` to `db`; returns the total number of characters read.
        note: unstable signature."""
        size = 0
        for line in io:
            self.run(line, db, reduce)
            size += len(line)
        return size

    def db2io(self, db: dict, io: IO[str], /) -> int:
        """\
        write one line per entry of `db` to `io`; returns the total number of characters written.
        note: unstable signature."""
        size = 0
        for key, value in db.items():
            size += io.write(self.line(key, value))
        return size

    def path2db_sync(self, path: pathlib.Path, db: dict, /) -> int:
        """load the file at `path` into `db` (creating an empty file first if it is missing)."""
        path.touch()
        with path.open("r") as file:
            return self.io2db(file, db, True)

    def db2path_sync(self, db: dict, path: pathlib.Path, /) -> int:
        """\
        overwrite `path` with the contents of `db` and fsync it.
        returns the number of characters written."""
        with path.open("w") as file:
            initial_size = self.db2io(db, file)
            # os.fsync only persists data the OS already has: push Python's own buffer down
            # first, otherwise the fsync can run before the data has left the process.
            file.flush()
            os.fsync(file.fileno())
        return initial_size
class KVRequest(LineRequest):
    """LineRequest built from a single key/value pair, serialised by the given KVFactory."""

    __slots__ = (
        "__factory",
        "key",
        "value",
    )

    def __init__(self, key: Any, value: Any, /, *, future: asyncio.Future | None, factory: KVFactory) -> None:
        serialised = factory.line(key, value)
        super().__init__(serialised, future=future)
        self.__factory = factory
        self.key, self.value = key, value
class CommitRequest(Request):
    """request asking the writer to flush whatever is currently buffered."""

    __slots__ = ()
class UnknownRequestType(TypeError):
    """raised when the writer receives a request it has no handler for."""

    __slots__ = ()
class KVJson(KVFactory):
    """\
    note: unstable signature.
    JSON-lines format: one `{"key": ..., "value": ...}` object per line; a deletion omits "value"."""

    __slots__ = ()

    def line(self, key: Any, value: Any, /) -> str:
        record = {"key": key}
        if value is not VDELETE:
            record["value"] = value
        return f"{json.dumps(record)}\n"

    def _load_key(self, key: Any, /) -> Hashable:
        """note: unstable signature."""
        # JSON arrays/objects are not hashable; turn them (recursively) into tuples
        if isinstance(key, Hashable):
            return key
        if isinstance(key, list):
            return tuple(self._load_key(item) for item in key)
        if isinstance(key, dict):
            return tuple((self._load_key(k), self._load_key(v)) for k, v in key.items())
        raise TypeError("unknown KVJson key type, cannot convert to hashable")

    def fromline(self, line: str, /) -> tuple[Any, Any]:
        parsed = json.loads(line)
        key = self._load_key(parsed["key"])
        return key, parsed.get("value", VDELETE)
class TransactionRequest(LineRequest):
    """LineRequest whose line is the complete contents of a prepared StringIO buffer."""

    __slots__ = ("buffer",)

    def __init__(self, buffer: StringIO, /, *, future: asyncio.Future | None) -> None:
        contents = buffer.getvalue()
        super().__init__(contents, future=future)
        self.buffer = buffer
class DbParameters:
    """\
    connection settings consumed by DbManager and Db.
    note: every attribute has an unstable signature."""

    __slots__ = (
        "path",
        "kvfactory",
        "buffersize",
    )

    def __init__(self, path: pathlib.Path, /, *, kvfactory: KVFactory, buffersize: int) -> None:
        # main database file location (unstable)
        self.path = path
        # serialisation format for stored lines (unstable)
        self.kvfactory = kvfactory
        # write-buffer threshold that triggers an immediate flush (unstable)
        self.buffersize = buffersize
class RequestToClosedConnection(asyncio.InvalidStateError):
    """a buffered request could not be served because the connection was closed."""
class VirtualConnection(abc.ABC):
    """\
    minimal intersection of DbConnection and TransactionView functionality.
    subclasses supply lookups, the KV protocol, transaction submission and the event loop;
    transaction() is built on top of them."""

    __slots__ = ()

    @abc.abstractmethod
    def get(self, key: Any, default: Any, /) -> Any:
        """look `key` up, falling back to `default`."""
        raise NotImplementedError

    @abc.abstractmethod
    def kvprotocol(self, /) -> KVProtocol:
        """protocol used to interpret stored values."""
        raise NotImplementedError

    @abc.abstractmethod
    async def commit_transaction(self, delta: dict, /) -> None:
        """apply `delta` and wait for it to be committed."""
        raise NotImplementedError

    @abc.abstractmethod
    def submit_transaction_request(self, delta: dict, future: asyncio.Future | None, /) -> None:
        """apply `delta` without waiting; `future`, when given, is resolved once it is committed."""
        raise NotImplementedError

    @abc.abstractmethod
    def loop(self, /) -> asyncio.AbstractEventLoop:
        """event loop this connection is bound to."""
        raise NotImplementedError

    def transaction(self, /) -> Transaction:
        """create a new Transaction over this connection."""
        return Transaction(self)
class ExtendedVirtualConnection(VirtualConnection, abc.ABC):
    """\
    maximal intersection of DbConnection and TransactionView functionality:
    adds single-key nowait writes, nowait transaction submission and commit().
    NOTE(review): unlike VirtualConnection, no __slots__ is declared here, so instances of
    subclasses carry a __dict__."""

    @abc.abstractmethod
    def set_nowait(self, key: Any, value: Any, /) -> None:
        """record `key` -> `value` without waiting for it to be written."""
        raise NotImplementedError

    @abc.abstractmethod
    def submit_transaction(self, delta: dict, /) -> None:
        """submit `delta` without waiting for it to be committed."""
        raise NotImplementedError

    @abc.abstractmethod
    async def commit(self, /) -> None:
        """wait for previously submitted changes."""
        raise NotImplementedError
class DbInterface(ExtendedVirtualConnection, abc.ABC):
    """ExtendedVirtualConnection plus an awaitable single-key set()."""

    @abc.abstractmethod
    async def set(self, key: Any, value: Any, /) -> None:
        """set `key` to `value`, waiting for the write to complete."""
        raise NotImplementedError
class AbstractDbConnection(Protocol):
    """structural description of a database connection's public surface."""

    def get(self, key: Any, default: Any, /) -> Any:
        """returns immediately."""
        raise NotImplementedError

    async def set(self, key: Any, value: Any, /) -> None:
        """\
        may take time to complete.
        ordering is not necessarily guaranteed (it depends on the event loop implementation)."""
        raise NotImplementedError

    def set_nowait(self, key: Any, value: Any, /) -> None:
        """\
        returns immediately.
        ordering is guaranteed."""
        raise NotImplementedError

    async def commit(self, /) -> None:
        """\
        may take time to complete.
        respects the ordering of :meth:`~ptvp35.AbstractDbConnection.set_nowait` calls made before it.
        depending on the event loop implementation, later changes may be executed as well."""
        raise NotImplementedError

    def transaction(self, /) -> Transaction:
        """open a transaction on this connection."""
        raise NotImplementedError
class _Loop:
__slots__ = ("__loop",)
def __init__(self, loop: asyncio.AbstractEventLoop, /) -> None:
self.__loop = loop
def create_future(self, /) -> asyncio.Future:
return self.__loop.create_future()
def loop(self, /) -> asyncio.AbstractEventLoop:
return self.__loop
def run_in_thread(self, name: str, fn, /, *args, **kwargs) -> asyncio.Future:
"""\
we are using our own thread to guarantee as much of autonomy and control as possible.
intended for heavy tasks."""
future = self.create_future()
def wrap() -> None:
try:
result = fn(*args, **kwargs)
except Exception as exception:
self.__loop.call_soon_threadsafe(future.set_exception, exception)
else:
self.__loop.call_soon_threadsafe(future.set_result, result)
fname = getattr(fn, "__name__", "?")
threading.Thread(target=wrap, name=f"persistence5-{name}-{fname}").start()
return future
class _Errors:
__slots__ = (
"__path",
"__loop",
"__event_loop",
)
def __init__(self, path: pathlib.Path, loop: _Loop, /) -> None:
self.__path = path.with_name(path.name + ".error")
self.__loop = loop
self.__event_loop = loop.loop()
def _save_sync(self, line: str, /) -> None:
with self.__path.open("a") as file:
file.write(line.strip() + "\n")
async def _save(self, line: str, /) -> None:
await self.__event_loop.run_in_executor(None, self._save_sync, line)
def _schedule(self, line: str, /) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(self._save(line), self.__event_loop)
def save_from_thread(self, line: str, /) -> None:
self._schedule(line).result()
class _File:
__slots__ = (
"__path",
"__file",
)
__file: IO[str]
def __init__(self, path: pathlib.Path, /) -> None:
self.__path = path
def path(self, /) -> pathlib.Path:
return self.__path
def tell(self, /) -> int:
return self.__file.tell()
def write_to_disk_sync(self, line: str, /) -> None:
self.__file.write(line)
self.__file.flush()
try:
os.fsync(self.__file.fileno())
except UnsupportedOperation:
pass
def open_sync(self, /) -> None:
self.__file = self.__path.open("a")
def close_sync(self, /) -> None:
self.__file.close()
del self.__file
class _Backup:
    """\
    owns the main database file together with its `.backup` and `.recover` siblings and
    rewrites the main file from an in-memory dict in a recoverable order:
    1. the full dict is written to `.backup`, then the `.recover` marker is created;
    2. the main file is overwritten from the backup;
    3. the marker and the backup are removed.
    the marker is created only after the backup has been written, so if it exists at startup
    the backup is complete and the main file is regenerated from it."""

    __slots__ = (
        "__file",
        "__kvfactory",
        "__loop",
        "__path",
        "__backup",
        "__recover",
        "__initial_size",
    )

    # characters written to the backup by the last rebuild (set in _recovery_set_sync);
    # reload_if_oversized uses it as the baseline for file growth.
    __initial_size: int

    def __init__(self, path: pathlib.Path, kvfactory: KVFactory, loop: _Loop, /) -> None:
        self.__file = _File(path)
        self.__kvfactory = kvfactory
        self.__loop = loop
        self.__path = path
        # full snapshot written before the main file is rewritten
        self.__backup = path.with_name(path.name + ".backup")
        # marker: present while the main file may be mid-rewrite from the backup
        self.__recover = path.with_name(path.name + ".recover")

    def file(self, /) -> _File:
        return self.__file

    def kvfactory(self, /) -> KVFactory:
        return self.__kvfactory

    def _copy_sync(self, db: dict | None, /) -> None:
        """load the backup into `db` (a fresh dict when None), then overwrite the main file with `db`."""
        if db is None:
            db = {}
        self.__kvfactory.path2db_sync(self.__backup, db)
        self.__kvfactory.db2path_sync(db, self.__path)

    def _recovery_unset_sync(self, /) -> None:
        # marker first: once it is gone, the backup is no longer needed for recovery
        self.__recover.unlink()
        self.__backup.unlink()

    def _finish_recovery_sync(self, db: dict | None, /) -> None:
        """steps 2-3: regenerate the main file from the backup, then drop marker and backup."""
        self._copy_sync(db)
        self._recovery_unset_sync()

    def _recovery_set_sync(self, db: dict, /) -> None:
        """step 1: write `db` to the backup (recording its size), then create the marker."""
        self.__initial_size = self.__kvfactory.db2path_sync(db, self.__backup)
        self.__recover.touch()

    def build_file_sync(self, db: dict, /) -> None:
        """rewrite the main file so that it contains exactly `db`, one line per key."""
        self._recovery_set_sync(db)
        self._finish_recovery_sync(db)

    def _rebuild_file_sync(self, db: dict, /) -> None:
        """complete an interrupted rewrite if needed, load the main file into `db`, then compact it."""
        if self.__recover.exists():
            self._finish_recovery_sync(None)
        self.__kvfactory.path2db_sync(self.__path, db)
        self.build_file_sync(db)

    def _reload_sync(self, /) -> None:
        # the append handle is closed around the rewrite and reopened on the new file
        self.__file.close_sync()
        self._rebuild_file_sync({})
        self.__file.open_sync()

    def run_in_thread(self, fn, /, *args, **kwargs) -> asyncio.Future:
        """run `fn` on a dedicated thread named after the main file."""
        return self.__loop.run_in_thread(self.__path.name, fn, *args, **kwargs)

    async def _reload(self, /) -> None:
        await self.run_in_thread(self._reload_sync)

    async def reload_if_oversized(self, /) -> None:
        """compact the main file once it has grown past twice its size at the last rebuild."""
        if self.__file.tell() > 2 * self.__initial_size:
            await self._reload()

    def load_mmdb_sync(self, /) -> dict:
        """read (and compact) the main file, returning its contents as a dict."""
        db = {}
        self._rebuild_file_sync(db)
        return db

    def uninitialize(self, /) -> None:
        del self.__initial_size
class _Guard:
    """\
    makes appends to the main file recoverable.
    before each write the current end-of-file position is stored in `<name>.truncate` and a
    `<name>.truncate_flag` marker is created; both are removed after the write. if the marker
    is still present (presumably because the process stopped mid-write), assure_sync() cuts the
    file back to the stored position and hands the cut-off tail to the error file."""

    __slots__ = (
        "__backup",
        "__error",
        "__file",
        "__path",
        "__truncate",
        "__flag",
    )

    def __init__(self, backup: _Backup, error: _Errors, /) -> None:
        self.__backup = backup
        self.__error = error
        self.__file = backup.file()
        self.__path = path = self.__file.path()
        # stores the position to truncate back to
        self.__truncate = path.with_name(path.name + ".truncate")
        # marker: a guarded write is in progress
        self.__flag = path.with_name(path.name + ".truncate_flag")

    def backup(self, /) -> _Backup:
        return self.__backup

    def _write_bytes_sync(self, s: bytes, /) -> None:
        # consider subclassing/rewriting to use `os.fsync`
        self.__truncate.write_bytes(s)

    def _write_value_sync(self, value: int, /) -> None:
        # fixed-width 16-byte little-endian encoding of the saved position
        self._write_bytes_sync(value.to_bytes(16, "little"))

    def _read_bytes_sync(self, /) -> bytes:
        return self.__truncate.read_bytes()

    def _read_value_sync(self, /) -> int:
        return int.from_bytes(self._read_bytes_sync(), "little")

    def _set_sync(self, /) -> None:
        # the position is saved before the marker appears, so a present marker always has a
        # matching saved position
        self._write_value_sync(self.__file.tell())
        self.__flag.touch()

    def _unset_sync(self, /) -> None:
        # marker removed first, mirroring _set_sync
        self.__flag.unlink(missing_ok=True)
        self.__truncate.unlink(missing_ok=True)

    def _truncate_sync(self, /) -> None:
        with self.__path.open("r+") as file:
            self._file_truncate_sync(file, self._read_value_sync())

    def assure_sync(self, /) -> None:
        """undo a partially completed guarded write, if one is recorded."""
        if self.__flag.exists():
            self._truncate_sync()
        self._unset_sync()

    def _file_truncate_sync(self, file: IO[str], pos: int, /) -> None:
        # NOTE(review): `pos` comes from a text-mode tell(); this relies on that value being
        # usable with seek()/truncate() on a fresh text handle — confirm for non-ASCII content.
        file.seek(pos)
        # save_from_thread blocks until the event loop has written the tail, so this must
        # not run on the loop thread itself
        self.__error.save_from_thread(file.read())
        file.truncate(pos)

    def file_write_sync(self, line: str, /) -> None:
        """append `line` to the main file between setting and clearing the truncate marker."""
        self._set_sync()
        self.__file.write_to_disk_sync(line)
        self._unset_sync()
class _ReceivingQueue:
__all__ = ("__queue",)
def __init__(self, queue: asyncio.Queue[Request], /) -> None:
self.__queue: asyncio.Queue[Request] = queue
def submit(self, request: Request, /) -> None:
self.__queue.put_nowait(request)
class _WriteableBuffer:
__slots__ = (
"__buffersize",
"__guard",
"__queue",
"__backup",
"__kvfactory",
"__loop",
"__event_loop",
"__buffer",
"__buffer_future",
"__buffer_requested",
)
__buffer: StringIO
__buffer_future: asyncio.Future
__buffer_requested: bool
def __init__(self, buffersize: int, guard: _Guard, queue: _ReceivingQueue, loop: _Loop, /) -> None:
self.__buffersize = buffersize
self.__guard = guard
self.__queue = queue
self.__backup = self.__guard.backup()
self.__kvfactory = self.__backup.kvfactory()
self.__loop = loop
self.__event_loop = self.__loop.loop()
self._clear()
def writeable(self, /) -> _Guard:
return self.__guard
def loop(self, /) -> _Loop:
return self.__loop
def _compressed(self, /) -> StringIO:
self.__buffer.seek(0)
bufferdb = {}
self.__kvfactory.io2db(self.__buffer, bufferdb, False)
buffer = StringIO()
self.__kvfactory.db2io(bufferdb, buffer)
return buffer
def _commit_compressed_sync(self, /) -> None:
self.__guard.file_write_sync(self._compressed().getvalue())
async def _commit_compressed(self, /) -> None:
await self.__event_loop.run_in_executor(None, self._commit_compressed_sync)
def _clear(self, /) -> None:
self.__buffer = StringIO()
self.__buffer_future = self.__loop.create_future()
self.__buffer_requested = False
def _satisfy_future(self, /) -> None:
self.__buffer_future.set_result(None)
self._clear()
def _fail_future(self, exception: BaseException, /) -> None:
self.__buffer_future.set_exception(exception)
self._clear()
async def _do_commit_buffer(self, /) -> None:
try:
await self._commit_compressed()
except Exception as e:
self._fail_future(e)
else:
self._satisfy_future()
def _request_buffer(self, request: Request, /) -> None:
if request.waiting():
def callback(bf: asyncio.Future) -> None:
if (e := bf.exception()) is not None:
request.set_exception(e)
else:
request.set_result(None)
self.__buffer_future.add_done_callback(callback)
if not self.__buffer_requested:
self.__buffer_requested = True
self.__queue.submit(CommitRequest(None))
async def _commit(self, /) -> None:
if self.__buffer.tell():
await self._do_commit_buffer()
await self.__backup.reload_if_oversized()
elif self.__buffer_requested:
self._satisfy_future()
async def _commit_or_request_so(self, request: Request, /) -> None:
if self.__buffer.tell() >= self.__buffersize:
await self._commit()
request.set_result(None)
else:
self._request_buffer(request)
async def _write(self, line: str, request: Request, /) -> None:
self.__buffer.write(line)
await self._commit_or_request_so(request)
async def _handle_request(self, request: Request, /) -> None:
match request:
case LineRequest():
await self._write(request.line, request)
case CommitRequest():
await self._commit()
request.set_result(None)
case _:
raise UnknownRequestType
async def _close(self, /) -> None:
await self._commit()
if not self.__buffer_future.done():
self.__buffer_future.set_exception(RequestToClosedConnection())
if not isinstance(self.__buffer_future.exception(), RequestToClosedConnection):
raise RuntimeError
del self.__buffer_requested
del self.__buffer_future
del self.__buffer
class _Memory:
__slots__ = (
"__backup",
"__guard",
"__file",
"__kvfactory",
"__loop",
"__mmdb",
)
__mmdb: dict
def __init__(self, guard: _Guard, /) -> None:
self.__guard = guard
self.__backup = guard.backup()
self.__file = self.__backup.file()
self.__kvfactory = self.__backup.kvfactory()
def _initialize_sync(self, /) -> None:
self.__mmdb = self.__backup.load_mmdb_sync()
def _load_from_file_sync(self, /) -> None:
self.__guard.assure_sync()
self._initialize_sync()
self.__file.open_sync()
async def _load_from_file(self, /) -> None:
await self.__backup.run_in_thread(self._load_from_file_sync)
def _close_sync(self, /) -> None:
self.__file.close_sync()
self.__backup.build_file_sync(self.__mmdb)
del self.__mmdb
self.__backup.uninitialize()
async def _close(self, /) -> None:
await self.__backup.run_in_thread(self._close_sync)
def _transaction_buffer(self, delta: dict, /) -> StringIO:
buffer = StringIO()
self.__kvfactory.db2io(delta, buffer)
for key, value in delta.items():
self.__kvfactory.dbset(self.__mmdb, key, value)
return buffer
def get(self, key: Any, default: Any, /) -> Any:
return self.__kvfactory.dbget(self.__mmdb, key, default)
def set(self, key: Any, value: Any, /) -> None:
self.__kvfactory.dbset(self.__mmdb, key, value)
class _QueueTask:
__slots__ = (
"__queue",
"__buffer",
"__event_loop",
"__task",
)
def __init__(self, queue: asyncio.Queue[Request], buffer: _WriteableBuffer, /) -> None:
self.__queue = queue
self.__buffer = buffer
self.__event_loop = buffer.loop().loop()
async def _background_cycle(self, /) -> None:
request: Request = await self.__queue.get()
try:
await self.__buffer._handle_request(request)
except Exception as e:
request.set_exception(e)
finally:
self.__queue.task_done()
async def _background_task(self, /) -> None:
while True:
await self._background_cycle()
async def close(self, /) -> None:
if not self.__task.done():
await self.__queue.join()
self.__task.cancel()
del self.__task
if not self.__queue.empty():
raise RuntimeError
del self.__queue
await self.__buffer._close()
def start(self, /) -> None:
self.__task = self.__event_loop.create_task(self._background_task())
class _DbConnection(DbInterface):
    """\
    note: unstable constructor signature.
    default DbInterface implementation: reads are served from an in-memory mirror; writes are
    applied to the mirror immediately and queued for a background task that appends them to disk."""

    __slots__ = (
        "__kvfactory",
        "__buffersize",
        "__path",
        # was listed as "__not_running" while the code uses `__running`
        "__running",
        "__mmdb",
        "__loop",
        "__queue",
        "__task",
    )

    __mmdb: _Memory
    __loop: _Loop
    __queue: _ReceivingQueue
    __task: _QueueTask

    def __init__(self, parameters: DbParameters, /) -> None:
        self.__kvfactory = parameters.kvfactory
        self.__buffersize = parameters.buffersize
        self.__path = parameters.path
        # guards against initialising the same connection twice
        self.__running = False

    def kvprotocol(self, /) -> KVProtocol:
        return self.__kvfactory

    def get(self, key: Any, default: Any, /) -> Any:
        """dict-like get with mandatory default parameter."""
        return self.__mmdb.get(key, default)

    async def set(self, key: Any, value: Any, /) -> None:
        """set the value and wait until it's written to disk."""
        future = self.__loop.create_future()
        request = self.__kvfactory.request(key, value, future=future)
        self.__mmdb.set(key, value)
        self.__queue.submit(request)
        await future

    def set_nowait(self, key: Any, value: Any, /) -> None:
        """set value and add write-to-disk request to queue."""
        # the request is built first so serialisation errors surface before the mirror changes
        request = self.__kvfactory.free(key, value)
        self.__mmdb.set(key, value)
        self.__queue.submit(request)

    async def _initialize_running(self, /) -> None:
        self.__loop = _Loop(asyncio.get_running_loop())
        guard = _Guard(
            _Backup(self.__path, self.__kvfactory, self.__loop),
            _Errors(self.__path, self.__loop),
        )
        queue: asyncio.Queue[Request] = asyncio.Queue()
        self.__queue = _ReceivingQueue(queue)
        self.__mmdb = _Memory(guard)
        await self.__mmdb._load_from_file()
        self.__task = _QueueTask(queue, _WriteableBuffer(self.__buffersize, guard, self.__queue, self.__loop))
        self.__task.start()

    async def _initialize(self, /) -> None:
        if self.__running:
            raise RuntimeError
        self.__running = True
        await self._initialize_running()

    @classmethod
    async def create(cls, parameters: DbParameters, /) -> _DbConnection:
        """\
        connect to the factory.
        note: unstable signature."""
        # deliberately not `cls(...)`: subclasses such as Db take a path, not DbParameters
        dbconnection = _DbConnection(parameters)
        await dbconnection._initialize()
        return dbconnection

    async def _close_running(self, /) -> None:
        mmdb = self.__mmdb
        del self.__mmdb
        del self.__queue
        del self.__loop
        await self.__task.close()
        del self.__task
        await mmdb._close()

    async def aclose(self, /) -> None:
        """\
        close the connection.
        note: unstable signature."""
        await self._close_running()
        self.__running = False

    async def commit_transaction(self, delta: dict, /) -> None:
        """\
        hybrid of set() and dict.update().
        note: unstable signature."""
        if not delta:
            return
        buffer = self.__mmdb._transaction_buffer(delta)
        future = self.__loop.create_future()
        self.__queue.submit(TransactionRequest(buffer, future=future))
        await future

    def submit_transaction(self, delta: dict, /) -> None:
        """\
        hybrid of set_nowait() and dict.update().
        _nowait analogue of commit_transaction().
        note: this method was added only for async-sync symmetry with commit_transaction().
        note: unstable signature."""
        if not delta:
            return
        buffer = self.__mmdb._transaction_buffer(delta)
        self.__queue.submit(TransactionRequest(buffer, future=None))

    def submit_transaction_request(self, delta: dict, future: asyncio.Future | None, /) -> None:
        """\
        low-level API.
        for high-level synchronisation use transaction() instead.
        note: unstable signature."""
        if not delta:
            # an explicit None check: futures are not meant to be tested for truthiness
            if future is not None:
                future.set_result(None)
            return
        buffer = self.__mmdb._transaction_buffer(delta)
        self.__queue.submit(TransactionRequest(buffer, future=future))

    async def commit(self, /) -> None:
        """wait until all requests queued before are completed."""
        future = self.__loop.create_future()
        self.__queue.submit(CommitRequest(future))
        await future

    def loop(self, /) -> asyncio.AbstractEventLoop:
        return self.__loop.loop()


DbConnection: TypeAlias = DbInterface
class DbManager:
    """async context manager: opens a connection on enter and closes it on exit."""

    __slots__ = (
        "__parameters",
        "__db",
    )

    def __init__(self, path: pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576) -> None:
        parameters = DbParameters(path, kvfactory=kvfactory, buffersize=buffersize)
        self.__parameters = parameters

    async def __aenter__(self) -> DbInterface:
        connection = await _DbConnection.create(self.__parameters)
        self.__db = connection
        return connection

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.__db.aclose()


DbFactory: TypeAlias = DbManager
class Db(_DbConnection):
    """simplified usecase combining the factory and the connection in one class."""

    __slots__ = ()

    def __init__(self, path: str | pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576) -> None:
        parameters = DbParameters(pathlib.Path(path), kvfactory=kvfactory, buffersize=buffersize)
        _DbConnection.__init__(self, parameters)

    async def __aenter__(self) -> _DbConnection:
        await self._initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.aclose()
class FutureContext:
    """async context manager that waits, on exit, for an optional future captured at construction."""

    def __init__(self, future: asyncio.Future | None) -> None:
        self.__future = future

    async def __aenter__(self) -> None:
        return None

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.wait()

    async def wait(self, /) -> None:
        """wait for the captured future; returns immediately if there is none."""
        future = self.__future
        if future is None:
            return
        await future
class TransactionView(ExtendedVirtualConnection):
    """\
    note: unstable constructor signature.
    layered view over a VirtualConnection:
    - delta: changes made through this view that have not been submitted yet;
    - shadow: changes already submitted to the connection, consulted by get() before the connection;
    - subfuture: a future covering submissions that may still be pending (None when nothing is)."""

    __slots__ = (
        "__delta",
        "__shadow",
        "__connection",
        "__loop",
        "__kvprotocol",
        "__subfuture",
    )

    def __init__(self, delta: dict, connection: VirtualConnection, /) -> None:
        self.__delta = delta
        self.__shadow = {}
        self.__connection = connection
        self.__loop = connection.loop()
        self.__kvprotocol = connection.kvprotocol()
        self.__subfuture: asyncio.Future | None = None

    def future_context(self, /) -> FutureContext:
        """do something (inside of async with), then wait for submitted changes to be committed."""
        # the current subfuture is captured now; submissions made inside the block are not waited for
        return FutureContext(self.__subfuture)

    def rollback(self, /) -> None:
        """clear unsubmitted changes."""
        self.__delta.clear()

    def illuminate(self, /) -> None:
        """clear submitted changes, thus syncing the view (underlying the delta) with the connection."""
        self.__shadow.clear()

    async def ailluminate(self, /) -> None:
        """illuminate, then wait for submitted changes to be committed."""
        async with self.future_context():
            self.illuminate()

    def fork(self, /) -> None:
        """keep delta, but forget about the shadow entirely (including making sure it's committed)."""
        self.illuminate()
        self.__subfuture = None

    async def afork(self, /) -> None:
        """fork, then wait for submitted changes to be committed."""
        async with self.future_context():
            self.fork()

    def clear(self, /) -> None:
        """clear all changes (including the shadow)."""
        self.rollback()
        self.illuminate()

    async def aclear(self, /) -> None:
        """clear, then wait for submitted changes to be committed."""
        async with self.future_context():
            self.clear()

    def reset(self, /) -> None:
        """reset transaction."""
        self.clear()
        self.__subfuture = None

    async def areset(self, /) -> None:
        """reset, then wait for submitted changes to be committed."""
        async with self.future_context():
            self.reset()

    def kvprotocol(self, /) -> KVProtocol:
        return self.__kvprotocol

    def get(self, key: Any, default: Any, /) -> Any:
        """get from the delta (unsubmitted), else from the shadow (submitted), else from the connection."""
        if key in self.__delta:
            return self.__kvprotocol.dbget(self.__delta, key, default)
        elif key in self.__shadow:
            return self.__kvprotocol.dbget(self.__shadow, key, default)
        else:
            return self.__connection.get(key, default)

    def set_nowait(self, key: Any, value: Any, /) -> None:
        """note: unlike the corresponding db method, this one does not catch serialisation errors early."""
        self.__delta[key] = value

    def _delta(self, /) -> dict:
        """move the delta into the shadow and return a copy of it (the delta is left empty)."""
        delta = self.__delta.copy()
        self.__shadow |= delta
        self.__delta.clear()
        return delta

    async def commit(self, /) -> None:
        """bulk analogue of DbConnection.set()."""
        # for persistence5('s forks) developers:
        # q: why not self.__subfuture = None here?
        # a: run two commit calls concurrently. one will quit early and fail semantically.
        # we also never implicitly reset self.__subfuture because newly created future may depend on it.
        # q: why not self.submit() inside FC block?
        # a: that would require using FC block later once more + that future may include extra submitted changes;
        # so one would need to do submit, then do an empty FC block. that maybe introduced in the future
        # q: why use if delta?
        # a: to have code symmetric to that of submit + to not create an extra coroutine.
        # note: q&a comments above may become obsolete
        async with self.future_context():
            delta = self._delta()
            if delta:
                await self.__connection.commit_transaction(delta)

    def submit(self, /) -> None:
        """\
        submit changes.
        _nowait analogue of commit().
        bulk analogue of DbConnection.set_nowait()."""
        # for persistence5('s forks) developers:
        # q: why use if delta?
        # a: to have code symmetric to that of commit + to not create an extra future.
        # note: q&a comments above may become obsolete
        delta = self._delta()
        if delta:
            future = self.__loop.create_future()
            self.__connection.submit_transaction_request(delta, future)
            # chain the new submission after any still-pending ones
            self.__subfuture = self._gather(self.__subfuture, future)

    def _do_gather(self, left: asyncio.Future, right: asyncio.Future) -> asyncio.Future:
        """\
        future that completes after `left` and then `right`.
        fails with left's exception as soon as `left` fails (without waiting for `right`),
        otherwise with right's exception, otherwise succeeds with None."""
        future = self.__loop.create_future()

        def rcallback(fr: asyncio.Future) -> None:
            if (e := fr.exception()) is not None:
                future.set_exception(e)
            else:
                future.set_result(None)

        def lcallback(fl: asyncio.Future) -> None:
            if (e := fl.exception()) is not None:
                future.set_exception(e)
            else:
                # only start listening to `right` once `left` has succeeded
                right.add_done_callback(rcallback)

        left.add_done_callback(lcallback)
        return future

    def _reduce_future(self, future: asyncio.Future | None) -> asyncio.Future | None:
        """\
        None for "nothing to wait for": either no future or one that already succeeded.
        pending or failed futures are returned unchanged.
        NOTE(review): for a cancelled future, `future.exception()` raises CancelledError here."""
        if future is None or future.done() and future.exception() is None:
            return None
        else:
            return future

    def _gather(self, left: asyncio.Future | None, right: asyncio.Future | None) -> asyncio.Future | None:
        """combine two optional futures into one (or None if neither needs waiting for)."""
        match (self._reduce_future(left), self._reduce_future(right)):
            case None, ofr:
                return ofr
            case ofl, None:
                return ofl
            case asyncio.Future() as fl, asyncio.Future() as fr:
                return self._do_gather(fl, fr)
            case _:
                raise TypeError

    async def commit_transaction(self, delta: dict, /) -> None:
        """merge `delta` into this view's delta and commit it (together with earlier unsubmitted changes)."""
        if not delta:
            return
        self.__delta.update(delta)
        await self.commit()

    def submit_transaction(self, delta: dict, /) -> None:
        """merge `delta` into this view's delta and submit it (together with earlier unsubmitted changes)."""
        if not delta:
            return
        self.__delta.update(delta)
        self.submit()

    def submit_transaction_request(self, delta: dict, future: asyncio.Future | None, /) -> None:
        """\
        submit `delta` and resolve `future` (if given) once every pending submission of this view
        has completed, propagating the combined exception if there is one."""
        def set_result(sf: asyncio.Future | None):
            if future is None:
                pass
            elif sf is None or (e := sf.exception()) is None:
                future.set_result(None)
            else:
                future.set_exception(e)

        if not delta:
            set_result(None)
            return
        self.submit_transaction(delta)
        if self.__subfuture is None:
            set_result(None)
            return
        self.__subfuture.add_done_callback(set_result)

    def loop(self, /) -> asyncio.AbstractEventLoop:
        return self.__loop
class Transaction:
    """\
    note: unstable signature.
    reusable (not re-entrant) context manager producing a fresh TransactionView per use.
    leaving a `with` block submits the view's changes; leaving an `async with` block commits them;
    on an exception the unsubmitted changes are rolled back instead.
    the transaction is always released on exit, even if submit/commit itself fails."""

    __slots__ = (
        "__connection",
        "__view",
        "__running",
    )

    __view: TransactionView

    def __init__(self, connection: VirtualConnection, /) -> None:
        self.__connection = connection
        self.__running = False

    async def __aenter__(self) -> TransactionView:
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # release in `finally`: a failing (or cancelled) commit must not leave the
        # transaction marked as running, or it could never be entered again.
        try:
            if exc_type is None:
                await self.__view.commit()
            else:
                self.__view.rollback()
        finally:
            self._clean()

    def _clean(self, /) -> None:
        del self.__view
        self.__running = False

    def __enter__(self) -> TransactionView:
        """raises RuntimeError if this transaction is already in use."""
        if self.__running:
            raise RuntimeError
        # build the view before marking the transaction as running, so a failing
        # constructor does not leave it stuck in the running state
        view = TransactionView({}, self.__connection)
        self.__view = view
        self.__running = True
        return view

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        try:
            if exc_type is None:
                self.__view.submit()
            else:
                self.__view.rollback()
        finally:
            self._clean()