diff --git a/ptvp35/__init__.py b/ptvp35/__init__.py index 5110ed7..6f62f16 100644 --- a/ptvp35/__init__.py +++ b/ptvp35/__init__.py @@ -8,15 +8,20 @@ from io import StringIO from typing import Any, Optional, IO, Type, Hashable +__all__ = ('KVRequest', 'KVJson', 'DbConnection', 'DbFactory',) + + class Request: def __init__(self, future: Optional[asyncio.Future]): self.future = future def set_result(self, result): - self.future and self.future.set_result(result) + if self.future: + self.future.set_result(result) def set_exception(self, exception): - self.future and self.future.set_exception(exception) + if self.future: + self.future.set_exception(exception) class KVRequest(Request): @@ -80,26 +85,30 @@ class ErrorRequest(Request): self.line = line async def wait(self): - await self.future + if self.future: + await self.future -class Db: - __mmdb: Optional[dict] - __loop: Optional[asyncio.AbstractEventLoop] - __queue: Optional[asyncio.Queue[Request]] - __file: Optional[IO[str]] - __buffer: Optional[StringIO] - __task: Optional[asyncio.Future] +class DbConnection: + __mmdb: dict + __loop: asyncio.AbstractEventLoop + __queue: asyncio.Queue[Request] + __file: IO[str] + __buffer: StringIO + __task: asyncio.Future + __initial_size: int - def __init__(self, path: str, *, kvrequest_type: Type[KVRequest], buffersize=1048576): - path = str(path) - self.kvrequest_type = kvrequest_type - self.buffersize = buffersize + def __init__( + self, + factory: 'DbFactory', + ) -> None: + self.factory = factory + path = self.factory.path self.__path = pathlib.Path(path) self.__path_backup = pathlib.Path(path + '.backup') self.__path_recover = pathlib.Path(path + '.recover') self.__path_error = pathlib.Path(path + '.error') - self.__task = None + self.not_running = True def _queue_error(self, line: str): request = ErrorRequest(line, self.__loop.create_future()) @@ -110,7 +119,7 @@ class Db: size = 0 for line in io: try: - request = self.kvrequest_type.fromline(line) + request = self.factory.kvrequest_type.fromline(line) db[request.key] = request.value size += len(line) except (json.JSONDecodeError, pickle.UnpicklingError, EOFError): @@ -121,7 +130,7 @@ class Db: def db2io(self, db: dict, io: IO[str]) -> int: size = 0 for key, value in db.items(): - size += io.write(self.kvrequest_type(key, value, None).line()) + size += io.write(self.factory.kvrequest_type(key, value, None).line()) return size def get(self, key: Any, default: Any): @@ -130,15 +139,15 @@ class Db: async def set(self, key: Any, value: Any): self.__mmdb[key] = value future = self.__loop.create_future() - self.__queue.put_nowait(self.kvrequest_type(key, value, future)) + self.__queue.put_nowait(self.factory.kvrequest_type(key, value, future)) await future def set_nowait(self, key: Any, value: Any): self.__mmdb[key] = value - self.__queue.put_nowait(self.kvrequest_type(key, value, None)) + self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None)) async def _dump_buffer_or_request_so(self, request: KVRequest): - if self.__buffer.tell() >= self.buffersize: + if self.__buffer.tell() >= self.factory.buffersize: await self._dump_buffer() request.set_result(None) else: @@ -181,7 +190,7 @@ class Db: await self.__loop.run_in_executor(None, file.write, line.strip() + '\n') async def _handle_request(self, request: Request): - if isinstance(request, self.kvrequest_type): + if isinstance(request, self.factory.kvrequest_type): await self._write(request.line(), request) elif isinstance(request, DumpRequest): await self._dump_buffer() @@ -244,47 +253,50 @@ class Db: self.__task = self.__loop.create_task(self._background_task()) async def _initialize(self): - assert self.__task is None + assert self.not_running self.__loop = asyncio.get_event_loop() await self._initialize_queue() await self._initialize_mmdb() await self._start_task() + self.not_running = False - async def __aenter__(self): - await self._initialize() - return self + @classmethod + async def create(cls, factory: 'DbFactory') -> 'DbConnection': + dbconnection = DbConnection(factory) + await dbconnection._initialize() + return dbconnection - async def _aclose(self): + async def aclose(self): if not self.__task.done(): await self.__queue.join() self.__task.cancel() await self._dump_buffer() self.__file.close() await self._build_file(self.__mmdb) + self.not_running = True + del self.__mmdb + del self.__loop + del self.__queue + del self.__file + del self.__buffer + del self.__task + del self.__initial_size - def _uninitialize(self): - self.__mmdb = None - self.__loop = None - self.__queue = None - self.__file = None - self.__buffer = None - self.__task = None +class DbFactory: + def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576): + self.path = path = str(path) + self.kvrequest_type = kvrequest_type + self.buffersize = buffersize + + async def __aenter__(self): + self.db = await DbConnection.create(self) + return self.db async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._aclose() - self._uninitialize() - - def cursor(self, **kwargs): - return Cursor(self, **kwargs) + await self.db.aclose() -class Cursor: - def __init__(self, db: Db, default=None): - self.default = default - self.db = db - - def __getitem__(self, item): - return self.db.get(item, self.default) - - def __setitem__(self, key, value): - self.db.set_nowait(key, value) +class Db(DbFactory, DbConnection): + def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576): + DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize) + DbConnection.__init__(self, self)