diff --git a/ptvp35/__init__.py b/ptvp35/__init__.py index b14f0e3..01ee527 100644 --- a/ptvp35/__init__.py +++ b/ptvp35/__init__.py @@ -89,6 +89,15 @@ class ErrorRequest(Request): await self.future +class TransactionRequest(Request): + def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]): + super().__init__(future) + self.buffer = buffer + + def line(self) -> str: + return self.buffer.getvalue() + + class DbConnection: __mmdb: dict __loop: asyncio.AbstractEventLoop @@ -146,14 +155,14 @@ class DbConnection: self.__mmdb[key] = value self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None)) - async def _dump_buffer_or_request_so(self, request: KVRequest): + async def _dump_buffer_or_request_so(self, request: Request): if self.__buffer.tell() >= self.factory.buffersize: await self._dump_buffer() request.set_result(None) else: await self.__queue.put(DumpRequest(request.future)) - async def _write(self, line: str, request: KVRequest): + async def _write(self, line: str, request: Request): self.__buffer.write(line) await self._dump_buffer_or_request_so(request) @@ -197,6 +206,8 @@ class DbConnection: request.set_result(None) elif isinstance(request, ErrorRequest): await self._save_error(request.line) + elif isinstance(request, TransactionRequest): + await self._write(request.line(), request) else: raise UnkownRequestType @@ -282,6 +293,22 @@ class DbConnection: del self.__task del self.__initial_size + async def complete_transaction(self, delta: dict): + buffer = StringIO() + self.db2io(delta, buffer) + self.__mmdb.update(delta) + future = self.__loop.create_future() + self.__queue.put_nowait(TransactionRequest(buffer, future)) + await future + + async def commit(self): + future = self.__loop.create_future() + self.__queue.put_nowait(DumpRequest(future)) + await future + + def transaction(self) -> 'Transaction': + return Transaction(self) + class DbFactory: def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576): @@ -289,7 +316,7 @@ class DbFactory: self.kvrequest_type = kvrequest_type self.buffersize = buffersize - async def __aenter__(self): + async def __aenter__(self) -> DbConnection: self.db = await DbConnection.create(self) return self.db @@ -302,9 +329,37 @@ class Db(DbFactory, DbConnection): DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize) DbConnection.__init__(self, self) - async def __aenter__(self): + async def __aenter__(self) -> DbConnection: await self._initialize() return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.aclose() + + +class FallbackMapping: + def __init__(self, delta: dict, connection: DbConnection) -> None: + self.__delta = delta + self.__connection = connection + + def get(self, key: Any, default: Any): + return self.__delta.get(key, self.__connection.get(key, default)) + + def set_nowait(self, key: Any, value: Any): + self.__delta[key] = value + + +class Transaction: + delta: dict + + def __init__(self, connection: DbConnection) -> None: + self.connection = connection + + async def __aenter__(self) -> FallbackMapping: + self.delta = {} + return FallbackMapping(self.delta, self.connection) + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + await self.connection.complete_transaction(self.delta) + del self.delta