import asyncio import json import pathlib import pickle import shutil import traceback from io import StringIO from typing import Any, Optional, IO, Type, Hashable __all__ = ('KVRequest', 'KVJson', 'DbConnection', 'DbFactory', 'Db',) class Request: def __init__(self, future: Optional[asyncio.Future]): self.future = future def set_result(self, result): if self.future: self.future.set_result(result) def set_exception(self, exception): if self.future: self.future.set_exception(exception) class KVRequest(Request): def __init__(self, key: Any, value: Any, future: Optional[asyncio.Future]): super().__init__(future) self.key = key self.value = value def free(self): return type(self)(self.key, self.value, None) def line(self) -> str: raise NotImplementedError @classmethod def fromline(cls, line: str) -> 'KVRequest': raise NotImplementedError class DumpRequest(Request): pass class UnkownRequestType(TypeError): pass class KVPickle(KVRequest): def line(self) -> str: return pickle.dumps(self.free()).hex() + "\n" @classmethod def fromline(cls, line: str) -> 'KVPickle': return pickle.loads(bytes.fromhex(line.strip())) class KVJson(KVRequest): def line(self) -> str: return json.dumps({'key': self.key, 'value': self.value}) + "\n" @classmethod def _load_key(cls, key: Any) -> Hashable: if isinstance(key, Hashable): return key elif isinstance(key, list): return tuple(map(cls._load_key, key)) elif isinstance(key, dict): return tuple((cls._load_key(k), cls._load_key(v)) for k, v in key) else: raise TypeError("unknown KVJson key type, cannot convert to hashable") @classmethod def fromline(cls, line: str) -> 'KVJson': d = json.loads(line) return KVJson(cls._load_key(d['key']), d['value'], None) class ErrorRequest(Request): def __init__(self, line: str, future: Optional[asyncio.Future]): super().__init__(future) self.line = line async def wait(self): if self.future: await self.future class TransactionRequest(Request): def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]): super().__init__(future) self.buffer = buffer def line(self) -> str: return self.buffer.getvalue() class DbConnection: __mmdb: dict __loop: asyncio.AbstractEventLoop __queue: asyncio.Queue[Request] __file: IO[str] __buffer: StringIO __task: asyncio.Future __initial_size: int def __init__( self, factory: 'DbFactory', ) -> None: self.factory = factory path = self.factory.path self.__path = pathlib.Path(path) self.__path_backup = pathlib.Path(path + '.backup') self.__path_recover = pathlib.Path(path + '.recover') self.__path_error = pathlib.Path(path + '.error') self.not_running = True def _queue_error(self, line: str): request = ErrorRequest(line, self.__loop.create_future()) self.__queue.put_nowait(request) self.__loop.create_task(request.wait()) def io2db(self, io: IO[str], db: dict) -> int: size = 0 for line in io: try: request = self.factory.kvrequest_type.fromline(line) db[request.key] = request.value size += len(line) except (json.JSONDecodeError, pickle.UnpicklingError, EOFError): traceback.print_exc() self._queue_error(line) return size def db2io(self, db: dict, io: IO[str]) -> int: size = 0 for key, value in db.items(): size += io.write(self.factory.kvrequest_type(key, value, None).line()) return size def get(self, key: Any, default: Any): return self.__mmdb.get(key, default) async def set(self, key: Any, value: Any): self.__mmdb[key] = value future = self.__loop.create_future() self.__queue.put_nowait(self.factory.kvrequest_type(key, value, future)) await future def set_nowait(self, key: Any, value: Any): self.__mmdb[key] = value self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None)) async def _dump_buffer_or_request_so(self, request: Request): if self.__buffer.tell() >= self.factory.buffersize: await self._dump_buffer() request.set_result(None) else: await self.__queue.put(DumpRequest(request.future)) async def _write(self, line: str, request: Request): self.__buffer.write(line) await self._dump_buffer_or_request_so(request) def _clear_buffer(self): self.__buffer = StringIO() def _compress_buffer(self) -> StringIO: self.__buffer.seek(0) bufferdb = {} self.io2db(self.__buffer, bufferdb) buffer = StringIO() self.db2io(bufferdb, buffer) return buffer async def _dump_compressed_buffer(self): buffer = self._compress_buffer() await self.__loop.run_in_executor(None, self.__file.write, buffer.getvalue()) async def _do_dump_buffer(self): await self._dump_compressed_buffer() self._clear_buffer() async def _reload_if_oversized(self): if self.__file.tell() > 2 * self.__initial_size: await self._reload() async def _dump_buffer(self): if self.__buffer.tell(): await self._do_dump_buffer() await self._reload_if_oversized() async def _save_error(self, line: str): with open(self.__path_error, 'a') as file: await self.__loop.run_in_executor(None, file.write, line.strip() + '\n') async def _handle_request(self, request: Request): if isinstance(request, self.factory.kvrequest_type): await self._write(request.line(), request) elif isinstance(request, DumpRequest): await self._dump_buffer() request.set_result(None) elif isinstance(request, ErrorRequest): await self._save_error(request.line) elif isinstance(request, TransactionRequest): await self._write(request.line(), request) else: raise UnkownRequestType async def _background_cycle(self): request: Request = await self.__queue.get() try: await self._handle_request(request) except Exception as e: request.set_exception(e) finally: self.__queue.task_done() async def _background_task(self): while True: await self._background_cycle() async def _finish_recovery(self): await self.__loop.run_in_executor(None, shutil.copy, self.__path_backup, self.__path) self.__path_recover.unlink() self.__path_backup.unlink() async def _build_file(self, db: dict): with open(self.__path_backup, "w") as file: self.__initial_size = await self.__loop.run_in_executor(None, self.db2io, db, file) self.__path_recover.touch() await self._finish_recovery() async def _rebuild_file(self, db: dict): if self.__path_recover.exists(): await self._finish_recovery() self.__path.touch() with open(self.__path) as file: await self.__loop.run_in_executor(None, self.io2db, file, db) await self._build_file(db) async def _reload(self): self.__file.close() await self._rebuild_file({}) self.__file = open(self.__path, "a") async def _load_from_file(self): await self._rebuild_file(self.__mmdb) async def _initialize_mmdb(self): self.__mmdb = {} await self._load_from_file() self.__file = open(self.__path, "a") async def _initialize_queue(self): self.__queue = asyncio.Queue() self.__buffer = StringIO() async def _start_task(self): self.__task = self.__loop.create_task(self._background_task()) async def _initialize(self): assert self.not_running self.__loop = asyncio.get_event_loop() await self._initialize_queue() await self._initialize_mmdb() await self._start_task() self.not_running = False @classmethod async def create(cls, factory: 'DbFactory') -> 'DbConnection': dbconnection = DbConnection(factory) await dbconnection._initialize() return dbconnection async def aclose(self): if not self.__task.done(): await self.__queue.join() self.__task.cancel() await self._dump_buffer() self.__file.close() await self._build_file(self.__mmdb) self.not_running = True del self.__mmdb del self.__loop del self.__queue del self.__file del self.__buffer del self.__task del self.__initial_size async def complete_transaction(self, delta: dict): buffer = StringIO() self.db2io(delta, buffer) self.__mmdb.update(delta) future = self.__loop.create_future() self.__queue.put_nowait(TransactionRequest(buffer, future)) await future async def commit(self): future = self.__loop.create_future() self.__queue.put_nowait(DumpRequest(future)) await future def transaction(self) -> 'Transaction': return Transaction(self) class DbFactory: def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576): self.path = path = str(path) self.kvrequest_type = kvrequest_type self.buffersize = buffersize async def __aenter__(self) -> DbConnection: self.db = await DbConnection.create(self) return self.db async def __aexit__(self, exc_type, exc_val, exc_tb): await self.db.aclose() class Db(DbFactory, DbConnection): def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576): DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize) DbConnection.__init__(self, self) async def __aenter__(self) -> DbConnection: await self._initialize() return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.aclose() class FallbackMapping: def __init__(self, delta: dict, connection: DbConnection) -> None: self.__delta = delta self.__connection = connection def get(self, key: Any, default: Any): return self.__delta.get(key, self.__connection.get(key, default)) def set_nowait(self, key: Any, value: Any): self.__delta[key] = value class Transaction: delta: dict def __init__(self, connection: DbConnection) -> None: self.connection = connection async def __aenter__(self) -> FallbackMapping: self.delta = {} return FallbackMapping(self.delta, self.connection) async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_type is None: await self.connection.complete_transaction(self.delta) del self.delta