transactions

This commit is contained in:
AF 2022-11-03 10:26:34 +00:00
parent 632569a135
commit d1564637f2

View File

@ -89,6 +89,15 @@ class ErrorRequest(Request):
await self.future await self.future
class TransactionRequest(Request):
def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]):
super().__init__(future)
self.buffer = buffer
def line(self) -> str:
return self.buffer.getvalue()
class DbConnection: class DbConnection:
__mmdb: dict __mmdb: dict
__loop: asyncio.AbstractEventLoop __loop: asyncio.AbstractEventLoop
@ -146,14 +155,14 @@ class DbConnection:
self.__mmdb[key] = value self.__mmdb[key] = value
self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None)) self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None))
async def _dump_buffer_or_request_so(self, request: KVRequest): async def _dump_buffer_or_request_so(self, request: Request):
if self.__buffer.tell() >= self.factory.buffersize: if self.__buffer.tell() >= self.factory.buffersize:
await self._dump_buffer() await self._dump_buffer()
request.set_result(None) request.set_result(None)
else: else:
await self.__queue.put(DumpRequest(request.future)) await self.__queue.put(DumpRequest(request.future))
async def _write(self, line: str, request: KVRequest): async def _write(self, line: str, request: Request):
self.__buffer.write(line) self.__buffer.write(line)
await self._dump_buffer_or_request_so(request) await self._dump_buffer_or_request_so(request)
@ -197,6 +206,8 @@ class DbConnection:
request.set_result(None) request.set_result(None)
elif isinstance(request, ErrorRequest): elif isinstance(request, ErrorRequest):
await self._save_error(request.line) await self._save_error(request.line)
elif isinstance(request, TransactionRequest):
await self._write(request.line(), request)
else: else:
raise UnkownRequestType raise UnkownRequestType
@ -282,6 +293,22 @@ class DbConnection:
del self.__task del self.__task
del self.__initial_size del self.__initial_size
async def complete_transaction(self, delta: dict):
buffer = StringIO()
self.db2io(delta, buffer)
self.__mmdb.update(delta)
future = self.__loop.create_future()
self.__queue.put_nowait(TransactionRequest(buffer, future))
await future
async def commit(self):
future = self.__loop.create_future()
self.__queue.put_nowait(DumpRequest(future))
await future
def transaction(self) -> 'Transaction':
return Transaction(self)
class DbFactory: class DbFactory:
def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576): def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
@ -289,7 +316,7 @@ class DbFactory:
self.kvrequest_type = kvrequest_type self.kvrequest_type = kvrequest_type
self.buffersize = buffersize self.buffersize = buffersize
async def __aenter__(self): async def __aenter__(self) -> DbConnection:
self.db = await DbConnection.create(self) self.db = await DbConnection.create(self)
return self.db return self.db
@ -302,9 +329,37 @@ class Db(DbFactory, DbConnection):
DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize) DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize)
DbConnection.__init__(self, self) DbConnection.__init__(self, self)
async def __aenter__(self): async def __aenter__(self) -> DbConnection:
await self._initialize() await self._initialize()
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose() await self.aclose()
class FallbackMapping:
def __init__(self, delta: dict, connection: DbConnection) -> None:
self.__delta = delta
self.__connection = connection
def get(self, key: Any, default: Any):
return self.__delta.get(key, self.__connection.get(key, default))
def set_nowait(self, key: Any, value: Any):
self.__delta[key] = value
class Transaction:
delta: dict
def __init__(self, connection: DbConnection) -> None:
self.connection = connection
async def __aenter__(self) -> FallbackMapping:
self.delta = {}
return FallbackMapping(self.delta, self.connection)
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
await self.connection.complete_transaction(self.delta)
del self.delta