DbFactory + DbConnection

This commit is contained in:
AF 2022-11-01 02:03:31 +00:00
parent d29d1b1395
commit 4ffae0a1ab

View File

@ -8,15 +8,20 @@ from io import StringIO
from typing import Any, Optional, IO, Type, Hashable from typing import Any, Optional, IO, Type, Hashable
__all__ = ('KVRequest', 'KVJson', 'DbConnection', 'DbFactory',)
class Request: class Request:
def __init__(self, future: Optional[asyncio.Future]): def __init__(self, future: Optional[asyncio.Future]):
self.future = future self.future = future
def set_result(self, result): def set_result(self, result):
self.future and self.future.set_result(result) if self.future:
self.future.set_result(result)
def set_exception(self, exception): def set_exception(self, exception):
self.future and self.future.set_exception(exception) if self.future:
self.future.set_exception(exception)
class KVRequest(Request): class KVRequest(Request):
@ -80,26 +85,30 @@ class ErrorRequest(Request):
self.line = line self.line = line
async def wait(self): async def wait(self):
if self.future:
await self.future await self.future
class Db: class DbConnection:
__mmdb: Optional[dict] __mmdb: dict
__loop: Optional[asyncio.AbstractEventLoop] __loop: asyncio.AbstractEventLoop
__queue: Optional[asyncio.Queue[Request]] __queue: asyncio.Queue[Request]
__file: Optional[IO[str]] __file: IO[str]
__buffer: Optional[StringIO] __buffer: StringIO
__task: Optional[asyncio.Future] __task: asyncio.Future
__initial_size: int
def __init__(self, path: str, *, kvrequest_type: Type[KVRequest], buffersize=1048576): def __init__(
path = str(path) self,
self.kvrequest_type = kvrequest_type factory: 'DbFactory',
self.buffersize = buffersize ) -> None:
self.factory = factory
path = self.factory.path
self.__path = pathlib.Path(path) self.__path = pathlib.Path(path)
self.__path_backup = pathlib.Path(path + '.backup') self.__path_backup = pathlib.Path(path + '.backup')
self.__path_recover = pathlib.Path(path + '.recover') self.__path_recover = pathlib.Path(path + '.recover')
self.__path_error = pathlib.Path(path + '.error') self.__path_error = pathlib.Path(path + '.error')
self.__task = None self.not_running = True
def _queue_error(self, line: str): def _queue_error(self, line: str):
request = ErrorRequest(line, self.__loop.create_future()) request = ErrorRequest(line, self.__loop.create_future())
@ -110,7 +119,7 @@ class Db:
size = 0 size = 0
for line in io: for line in io:
try: try:
request = self.kvrequest_type.fromline(line) request = self.factory.kvrequest_type.fromline(line)
db[request.key] = request.value db[request.key] = request.value
size += len(line) size += len(line)
except (json.JSONDecodeError, pickle.UnpicklingError, EOFError): except (json.JSONDecodeError, pickle.UnpicklingError, EOFError):
@ -121,7 +130,7 @@ class Db:
def db2io(self, db: dict, io: IO[str]) -> int: def db2io(self, db: dict, io: IO[str]) -> int:
size = 0 size = 0
for key, value in db.items(): for key, value in db.items():
size += io.write(self.kvrequest_type(key, value, None).line()) size += io.write(self.factory.kvrequest_type(key, value, None).line())
return size return size
def get(self, key: Any, default: Any): def get(self, key: Any, default: Any):
@ -130,15 +139,15 @@ class Db:
async def set(self, key: Any, value: Any): async def set(self, key: Any, value: Any):
self.__mmdb[key] = value self.__mmdb[key] = value
future = self.__loop.create_future() future = self.__loop.create_future()
self.__queue.put_nowait(self.kvrequest_type(key, value, future)) self.__queue.put_nowait(self.factory.kvrequest_type(key, value, future))
await future await future
def set_nowait(self, key: Any, value: Any): def set_nowait(self, key: Any, value: Any):
self.__mmdb[key] = value self.__mmdb[key] = value
self.__queue.put_nowait(self.kvrequest_type(key, value, None)) self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None))
async def _dump_buffer_or_request_so(self, request: KVRequest): async def _dump_buffer_or_request_so(self, request: KVRequest):
if self.__buffer.tell() >= self.buffersize: if self.__buffer.tell() >= self.factory.buffersize:
await self._dump_buffer() await self._dump_buffer()
request.set_result(None) request.set_result(None)
else: else:
@ -181,7 +190,7 @@ class Db:
await self.__loop.run_in_executor(None, file.write, line.strip() + '\n') await self.__loop.run_in_executor(None, file.write, line.strip() + '\n')
async def _handle_request(self, request: Request): async def _handle_request(self, request: Request):
if isinstance(request, self.kvrequest_type): if isinstance(request, self.factory.kvrequest_type):
await self._write(request.line(), request) await self._write(request.line(), request)
elif isinstance(request, DumpRequest): elif isinstance(request, DumpRequest):
await self._dump_buffer() await self._dump_buffer()
@ -244,47 +253,50 @@ class Db:
self.__task = self.__loop.create_task(self._background_task()) self.__task = self.__loop.create_task(self._background_task())
async def _initialize(self): async def _initialize(self):
assert self.__task is None assert self.not_running
self.__loop = asyncio.get_event_loop() self.__loop = asyncio.get_event_loop()
await self._initialize_queue() await self._initialize_queue()
await self._initialize_mmdb() await self._initialize_mmdb()
await self._start_task() await self._start_task()
self.not_running = False
async def __aenter__(self): @classmethod
await self._initialize() async def create(cls, factory: 'DbFactory') -> 'DbConnection':
return self dbconnection = DbConnection(factory)
await dbconnection._initialize()
return dbconnection
async def _aclose(self): async def aclose(self):
if not self.__task.done(): if not self.__task.done():
await self.__queue.join() await self.__queue.join()
self.__task.cancel() self.__task.cancel()
await self._dump_buffer() await self._dump_buffer()
self.__file.close() self.__file.close()
await self._build_file(self.__mmdb) await self._build_file(self.__mmdb)
self.not_running = True
del self.__mmdb
del self.__loop
del self.__queue
del self.__file
del self.__buffer
del self.__task
del self.__initial_size
def _uninitialize(self): class DbFactory:
self.__mmdb = None def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
self.__loop = None self.path = path = str(path)
self.__queue = None self.kvrequest_type = kvrequest_type
self.__file = None self.buffersize = buffersize
self.__buffer = None
self.__task = None async def __aenter__(self):
self.db = await DbConnection.create(self)
return self.db
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._aclose() await self.db.aclose()
self._uninitialize()
def cursor(self, **kwargs):
return Cursor(self, **kwargs)
class Cursor: class Db(DbFactory, DbConnection):
def __init__(self, db: Db, default=None): def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
self.default = default DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize)
self.db = db DbConnection.__init__(self, self)
def __getitem__(self, item):
return self.db.get(item, self.default)
def __setitem__(self, key, value):
self.db.set_nowait(key, value)