diff --git a/ptvp35/__init__.py b/ptvp35/__init__.py
index 908b221..6cfec13 100644
--- a/ptvp35/__init__.py
+++ b/ptvp35/__init__.py
@@ -1,77 +1,93 @@
 import asyncio
+import concurrent.futures
 import json
 import os
 import pathlib
-import pickle
 import threading
 import traceback
-from io import StringIO
-from typing import Any, Optional, IO, Type, Hashable
+from io import StringIO, UnsupportedOperation
+from typing import Any, Optional, IO, Hashable
 
 
 __all__ = ('KVRequest', 'KVJson', 'DbConnection', 'DbFactory', 'Db',)
 
 
 class Request:
-    def __init__(self, future: Optional[asyncio.Future]):
+    __slots__ = (
+        '__future',
+    )
+
+    def __init__(self, future: Optional[asyncio.Future], /) -> None:
         self.__future = future
 
-    def waiting(self) -> bool:
+    def waiting(self, /) -> bool:
         return self.__future is not None
 
-    def set_result(self, result):
+    def set_result(self, result, /) -> None:
         if self.__future is not None:
             self.__future.set_result(result)
 
-    def set_exception(self, exception):
+    def set_exception(self, exception, /):
         if self.__future is not None:
             self.__future.set_exception(exception)
 
-    async def wait(self):
+    async def wait(self, /):
         if self.__future is not None:
             await self.__future
 
 
+class KVFactory:
+    __slots__ = ()
+
+    def line(self, key: Any, value: Any, /) -> str:
+        raise NotImplementedError
+
+    def fromline(self, line: str, /) -> 'KVRequest':
+        raise NotImplementedError
+
+    def request(self, key: Any, value: Any, /, *, future: Optional[asyncio.Future]) -> 'KVRequest':
+        return KVRequest(key, value, future=future, factory=self)
+
+    def free(self, key: Any, value: Any, /) -> 'KVRequest':
+        return self.request(key, value, future=None)
+
+
 class KVRequest(Request):
-    def __init__(self, key: Any, value: Any, future: Optional[asyncio.Future]):
+    __slots__ = (
+        '__factory',
+        'key',
+        'value',
+    )
+
+    def __init__(self, key: Any, value: Any, /, *, future: Optional[asyncio.Future], factory: KVFactory):
         super().__init__(future)
+        self.__factory = factory
         self.key = key
         self.value = value
 
-    def free(self):
-        return type(self)(self.key, self.value, None)
+    def free(self, /):
+        return self.__factory.free(self.key, self.value)
 
-    def line(self) -> str:
-        raise NotImplementedError
-
-    @classmethod
-    def fromline(cls, line: str) -> 'KVRequest':
-        raise NotImplementedError
+    def line(self, /) -> str:
+        return self.__factory.line(self.key, self.value)
 
 
 class DumpRequest(Request):
+    __slots__ = ()
+
+
+class UnknownRequestType(TypeError):
     pass
 
 
-class UnkownRequestType(TypeError):
-    pass
+class KVJson(KVFactory):
+    __slots__ = ()
 
-
-class KVPickle(KVRequest):
-    def line(self) -> str:
-        return pickle.dumps(self.free()).hex() + "\n"
+    def line(self, key: Any, value: Any, /) -> str:
+        return json.dumps({'key': key, 'value': value}) + '\n'
 
     @classmethod
-    def fromline(cls, line: str) -> 'KVPickle':
-        return pickle.loads(bytes.fromhex(line.strip()))
-
-
-class KVJson(KVRequest):
-    def line(self) -> str:
-        return json.dumps({'key': self.key, 'value': self.value}) + "\n"
-
-    @classmethod
-    def _load_key(cls, key: Any) -> Hashable:
+    def _load_key(cls, key: Any, /) -> Hashable:
         if isinstance(key, Hashable):
             return key
         elif isinstance(key, list):
@@ -79,30 +95,49 @@ class KVJson(KVRequest):
         elif isinstance(key, dict):
             return tuple((cls._load_key(k), cls._load_key(v)) for k, v in key.items())
         else:
-            raise TypeError("unknown KVJson key type, cannot convert to hashable")
+            raise TypeError(
+                'unknown KVJson key type, cannot convert to hashable'
+            )
 
-    @classmethod
-    def fromline(cls, line: str) -> 'KVJson':
+    def fromline(self, line: str, /) -> 'KVRequest':
         d = json.loads(line)
-        return KVJson(cls._load_key(d['key']), d['value'], None)
-
-
-class ErrorRequest(Request):
-    def __init__(self, line: str, future: Optional[asyncio.Future]):
-        super().__init__(future)
-        self.line = line
+        return self.free(self._load_key(d['key']), d['value'])
 
 
 class TransactionRequest(Request):
-    def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]):
+    __slots__ = (
+        'buffer',
+    )
+
+    def __init__(self, buffer: StringIO, /, *, future: Optional[asyncio.Future]):
         super().__init__(future)
         self.buffer = buffer
-
-    def line(self) -> str:
+
+    def line(self, /) -> str:
         return self.buffer.getvalue()
 
 
 class DbConnection:
+    __slots__ = (
+        '__factory',
+        '__path',
+        '__path_backup',
+        '__path_recover',
+        '__path_error',
+        '__path_truncate',
+        '__path_truncate_flag',
+        '__not_running',
+        '__mmdb',
+        '__loop',
+        '__queue',
+        '__file',
+        '__buffer',
+        '__buffer_future',
+        '__buffer_requested',
+        '__task',
+        '__initial_size',
+    )
+
     __mmdb: dict
     __loop: asyncio.AbstractEventLoop
     __queue: asyncio.Queue[Request]
@@ -116,9 +151,10 @@ class DbConnection:
     def __init__(
         self,
         factory: 'DbFactory',
+        /
     ) -> None:
-        self.factory = factory
-        self.__path = path = self.factory.path
+        self.__factory = factory
+        self.__path = path = self.__factory.path
         name = self.__path.name
         self.__path_backup = path.with_name(name + '.backup')
         self.__path_recover = path.with_name(name + '.recover')
@@ -127,66 +163,73 @@ class DbConnection:
         self.__path_truncate_flag = path.with_name(name + '.truncate_flag')
         self.__not_running = True
 
-    def _queue_error(self, line: str):
-        request = ErrorRequest(line, self.__loop.create_future())
-        self.__queue.put_nowait(request)
-        self.__loop.create_task(request.wait())
+    def _create_future(self, /) -> asyncio.Future:
+        return self.__loop.create_future()
 
-    def io2db(self, io: IO[str], db: dict) -> int:
+    def _save_error_sync(self, line: str, /) -> None:
+        with self.__path_error.open('a') as file:
+            file.write(line.strip() + '\n')
+
+    async def _save_error(self, line: str, /) -> None:
+        await self.__loop.run_in_executor(None, self._save_error_sync, line)
+
+    def _queue_error(self, line: str, /) -> concurrent.futures.Future:
+        return asyncio.run_coroutine_threadsafe(
+            self._save_error(line), self.__loop
+        )
+
+    def _save_error_from_thread(self, line: str, /) -> None:
+        self._queue_error(line).result()
+
+    def io2db(self, io: IO[str], db: dict, /) -> int:
+        """there are no guarantees about .error file if error occurs here"""
         size = 0
         for line in io:
             try:
-                request = self.factory.kvrequest_type.fromline(line)
+                request = self.__factory.kvfactory.fromline(line)
                 db[request.key] = request.value
                 size += len(line)
-            except (json.JSONDecodeError, pickle.UnpicklingError, EOFError):
+            except (json.JSONDecodeError, EOFError):
                 traceback.print_exc()
+                # this condition should never occur, but we should be able to handle this UB as best as we can
                 self._queue_error(line)
         return size
 
-    def db2io(self, db: dict, io: IO[str]) -> int:
+    def db2io(self, db: dict, io: IO[str], /) -> int:
         size = 0
         for key, value in db.items():
-            size += io.write(self.factory.kvrequest_type(key, value, None).line())
+            size += io.write(self.__factory.kvfactory.free(key, value).line())
         return size
 
-    def get(self, key: Any, default: Any):
+    def _path2db_sync(self, path: pathlib.Path, db: dict, /) -> int:
+        path.touch()
+        with path.open('r') as file:
+            return self.io2db(file, db)
+
+    def _db2path_sync(self, db: dict, path: pathlib.Path, /) -> int:
+        with path.open('w') as file:
+            return self.db2io(db, file)
+
+    def get(self, key: Any, default: Any, /):
         return self.__mmdb.get(key, default)
 
-    async def set(self, key: Any, value: Any):
+    async def set(self, key: Any, value: Any, /) -> None:
         self.__mmdb[key] = value
-        future = self.__loop.create_future()
-        self.__queue.put_nowait(self.factory.kvrequest_type(key, value, future))
+        future = self._create_future()
+        self.__queue.put_nowait(
+            self.__factory.kvfactory.request(key, value, future=future))
         await future
 
-    def set_nowait(self, key: Any, value: Any):
+    def set_nowait(self, key: Any, value: Any, /) -> None:
         self.__mmdb[key] = value
-        self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None))
+        self.__queue.put_nowait(self.__factory.kvfactory.free(key, value))
 
-    def _request_buffer(self, request: Request):
-        if request.waiting():
-            self.__buffer_future.add_done_callback(lambda _: request.set_result(None))
-        if not self.__buffer_requested:
-            self.__buffer_requested = True
-            self.__queue.put_nowait(DumpRequest(None))
-
-    async def _dump_buffer_or_request_so(self, request: Request):
-        if self.__buffer.tell() >= self.factory.buffersize:
-            await self._dump_buffer()
-            request.set_result(None)
-        else:
-            self._request_buffer(request)
-
-    async def _write(self, line: str, request: Request):
-        self.__buffer.write(line)
-        await self._dump_buffer_or_request_so(request)
-
-    def _clear_buffer(self):
+    def _clear_buffer(self, /) -> None:
         self.__buffer = StringIO()
-        self.__buffer_future = self.__loop.create_future()
+        self.__buffer_future = self._create_future()
         self.__buffer_requested = False
 
-    def _compress_buffer(self) -> StringIO:
+    def _compress_buffer(self, /) -> StringIO:
         self.__buffer.seek(0)
         bufferdb = {}
         self.io2db(self.__buffer, bufferdb)
@@ -194,70 +237,105 @@ class DbConnection:
         self.db2io(bufferdb, buffer)
         return buffer
 
-    def _file_truncate_set(self):
-        self.__path_truncate.write_bytes(self.__file.tell().to_bytes(16, 'little'))
-        self.__path_truncate_flag.touch()
+    def _dump_compressed_buffer_sync(self, /) -> None:
+        self._file_write_sync(self._compress_buffer().getvalue())
 
-    def _file_truncate_unset(self):
-        self.__path_truncate_flag.unlink(missing_ok=True)
-        self.__path_truncate.unlink(missing_ok=True)
-
-    def _write_to_disk(self, line: str):
-        self.__file.write(line)
-        self.__file.flush()
-        os.fsync(self.__file.fileno())
-
-    def _file_write(self, line: str):
-        self._file_truncate_set()
-        self._write_to_disk(line)
-        self._file_truncate_unset()
-
-    def _dump_compressed_buffer_sync(self):
-        self._file_write(self._compress_buffer().getvalue())
-
-    async def _dump_compressed_buffer(self):
+    async def _dump_compressed_buffer(self, /) -> None:
         await self.__loop.run_in_executor(None, self._dump_compressed_buffer_sync)
 
-    def _satisfy_buffer_future(self):
+    def _satisfy_buffer_future(self, /) -> None:
         self.__buffer_future.set_result(None)
         self._clear_buffer()
 
-    async def _do_dump_buffer(self):
+    async def _do_dump_buffer(self, /) -> None:
         await self._dump_compressed_buffer()
         self._satisfy_buffer_future()
 
-    async def _reload_if_oversized(self):
-        if self.__file.tell() > 2 * self.__initial_size:
-            await self._reload()
-
-    async def _dump_buffer(self):
+    async def _dump_buffer(self, /) -> None:
         if self.__buffer.tell():
             await self._do_dump_buffer()
             await self._reload_if_oversized()
         elif self.__buffer_requested:
             self._satisfy_buffer_future()
 
-    def _save_error_sync(self, line: str):
-        with open(self.__path_error, 'a') as file:
-            file.write(line.strip() + '\n')
+    def _request_buffer(self, request: Request, /) -> None:
+        if request.waiting():
+            self.__buffer_future.exception
+            self.__buffer_future.add_done_callback(
+                lambda bf: request.set_exception(e) if (
+                    e := bf.exception()) is not None else request.set_result(None)
+            )
+        if not self.__buffer_requested:
+            self.__buffer_requested = True
+            self.__queue.put_nowait(DumpRequest(None))
 
-    async def _save_error(self, line: str):
-        await self.__loop.run_in_executor(None, self._save_error_sync, line)
+    async def _dump_buffer_or_request_so(self, request: Request, /) -> None:
+        if self.__buffer.tell() >= self.__factory.buffersize:
+            await self._dump_buffer()
+            request.set_result(None)
+        else:
+            self._request_buffer(request)
 
-    async def _handle_request(self, request: Request):
-        if isinstance(request, self.factory.kvrequest_type):
+    async def _write(self, line: str, request: Request, /) -> None:
+        self.__buffer.write(line)
+        await self._dump_buffer_or_request_so(request)
+
+    def _truncation_set_sync(self, /) -> None:
+        self.__path_truncate.write_bytes(
+            self.__file.tell().to_bytes(16, 'little')
+        )
+        self.__path_truncate_flag.touch()
+
+    def _truncation_unset_sync(self, /) -> None:
+        self.__path_truncate_flag.unlink(missing_ok=True)
+        self.__path_truncate.unlink(missing_ok=True)
+
+    def _file_truncate_sync(self, file: IO[str], pos: int, /) -> None:
+        file.seek(pos)
+        self._save_error_from_thread(file.read())
+        file.truncate(pos)
+
+    def _truncation_target_sync(self, /) -> int:
+        return int.from_bytes(self.__path_truncate.read_bytes(), 'little')
+
+    def _truncate_sync(self, /) -> None:
+        with self.__path.open('r+') as file:
+            self._file_truncate_sync(file, self._truncation_target_sync())
+
+    def _assure_truncation_sync(self, /) -> None:
+        if self.__path_truncate_flag.exists():
+            self._truncate_sync()
+        self._truncation_unset_sync()
+
+    def _write_to_disk_sync(self, line: str, /) -> None:
+        self.__file.write(line)
+        self.__file.flush()
+        try:
+            os.fsync(self.__file.fileno())
+        except UnsupportedOperation:
+            pass
+
+    def _file_write_sync(self, line: str, /) -> None:
+        self._truncation_set_sync()
+        self._write_to_disk_sync(line)
+        self._truncation_unset_sync()
+
+    async def _reload_if_oversized(self, /) -> None:
+        if self.__file.tell() > 2 * self.__initial_size:
+            await self._reload()
+
+    async def _handle_request(self, request: Request, /) -> None:
+        if isinstance(request, KVRequest):
             await self._write(request.line(), request)
         elif isinstance(request, DumpRequest):
             await self._dump_buffer()
             request.set_result(None)
-        elif isinstance(request, ErrorRequest):
-            await self._save_error(request.line)
         elif isinstance(request, TransactionRequest):
             await self._write(request.line(), request)
         else:
-            raise UnkownRequestType
+            raise UnknownRequestType
 
-    async def _background_cycle(self):
+    async def _background_cycle(self, /) -> None:
         request: Request = await self.__queue.get()
         try:
             await self._handle_request(request)
@@ -266,104 +344,115 @@ class DbConnection:
         finally:
             self.__queue.task_done()
 
-    async def _background_task(self):
+    async def _background_task(self, /) -> None:
         while True:
             await self._background_cycle()
 
-    def _copy_sync(self):
-        db = {}
-        with open(self.__path_backup, 'r') as file:
-            self.io2db(file, db)
-        with open(self.__path, 'w') as file:
-            self.db2io(db, file)
+    def _start_task(self, /) -> None:
+        self.__task = self.__loop.create_task(self._background_task())
 
-    def _finish_recovery_sync(self):
-        self._copy_sync()
+    def _recovery_set_sync(self, db: dict, /) -> None:
+        self.__initial_size = self._db2path_sync(db, self.__path_backup)
+        self.__path_recover.touch()
+
+    def _recovery_unset_sync(self, /) -> None:
         self.__path_recover.unlink()
         self.__path_backup.unlink()
 
-    def _build_file_sync(self, db: dict):
-        with open(self.__path_backup, "w") as file:
-            self.__initial_size = self.db2io(db, file)
-        self.__path_recover.touch()
-        self._finish_recovery_sync()
+    def _copy_sync(self, db: Optional[dict], /) -> None:
+        if db is None:
+            db = {}
+        self._path2db_sync(self.__path_backup, db)
+        self._db2path_sync(db, self.__path)
 
-    def _run_in_thread(self, fn, *args, **kwargs) -> asyncio.Future:
-        future = self.__loop.create_future()
+    def _finish_recovery_sync(self, db: Optional[dict], /) -> None:
+        self._copy_sync(db)
+        self._recovery_unset_sync()
 
-        def wrap():
+    def _build_file_sync(self, db: dict, /) -> None:
+        self._recovery_set_sync(db)
+        self._finish_recovery_sync(db)
+
+    def _run_in_thread(self, fn, /, *args, **kwargs) -> asyncio.Future:
+        """we are using our own thread to guarantee as much of autonomy and control as possible.
+intended for heavy tasks."""
+        future = self._create_future()
+
+        def wrap() -> None:
             try:
                 result = fn(*args, **kwargs)
             except Exception as exception:
-                self.__loop.call_soon_threadsafe(future.set_exception, exception)
+                self.__loop.call_soon_threadsafe(
+                    future.set_exception, exception
+                )
             else:
-                self.__loop.call_soon_threadsafe(future.set_result, result)
+                self.__loop.call_soon_threadsafe(
+                    future.set_result, result
+                )
 
         threading.Thread(target=wrap).start()
         return future
 
-    async def _build_file(self, db: dict):
+    async def _build_file(self, db: dict, /) -> None:
        await self._run_in_thread(self._build_file_sync, db)
 
-    def _rebuild_file_sync(self, db: dict):
+    def _rebuild_file_sync(self, db: dict, /) -> None:
         if self.__path_recover.exists():
-            self._finish_recovery_sync()
-        self.__path.touch()
-        with open(self.__path, 'r') as file:
-            self.io2db(file, db)
+            self._finish_recovery_sync(None)
+        self._path2db_sync(self.__path, db)
         self._build_file_sync(db)
 
-    async def _rebuild_file(self, db: dict):
-        await self._run_in_thread(self._rebuild_file_sync, db)
+    def _file_open(self, /) -> None:
+        self.__file = self.__path.open('a')
 
-    async def _reload(self):
+    def _reload_sync(self, /) -> None:
         self.__file.close()
         del self.__file
-        await self._rebuild_file({})
-        self.__file = open(self.__path, "a")
+        self._rebuild_file_sync({})
+        self._file_open()
 
-    async def _load_from_file(self):
-        await self._rebuild_file(self.__mmdb)
+    async def _reload(self, /) -> None:
+        await self._run_in_thread(self._reload_sync)
 
-    def _assure_truncation(self):
-        if self.__path_truncate_flag.exists():
-            pos = int.from_bytes(self.__path_truncate.read_bytes(), 'little')
-            with open(self.__path, 'r+') as file:
-                file.seek(pos)
-                asyncio.run_coroutine_threadsafe(self._save_error(file.read()), self.__loop).result()
-                file.truncate(pos)
-        self.__path_truncate_flag.unlink(missing_ok=True)
-        self.__path_truncate.unlink(missing_ok=True)
+    def _load_mmdb_sync(self, /) -> dict:
+        db = {}
+        self._rebuild_file_sync(db)
+        return db
 
-    async def _initialize_mmdb(self):
-        self.__mmdb = {}
-        await self.__loop.run_in_executor(None, self._assure_truncation)
-        await self._load_from_file()
-        self.__file = open(self.__path, "a")
+    def _initialize_mmdb_sync(self, /) -> None:
+        self.__mmdb = self._load_mmdb_sync()
 
-    def _initialize_queue(self):
+    def _load_from_file_sync(self, /) -> None:
+        self._assure_truncation_sync()
+        self._initialize_mmdb_sync()
+        self._file_open()
+
+    async def _load_from_file(self, /) -> None:
+        await self._run_in_thread(self._load_from_file_sync)
+
+    def _initialize_queue(self, /) -> None:
         self.__queue = asyncio.Queue()
         self._clear_buffer()
 
-    async def _start_task(self):
-        self.__task = self.__loop.create_task(self._background_task())
-
-    async def _initialize(self):
-        assert self.__not_running
-        self.__not_running = False
+    async def _initialize_running(self, /) -> None:
         self.__loop = asyncio.get_running_loop()
         self._initialize_queue()
-        await self._initialize_mmdb()
-        await self._start_task()
+        await self._load_from_file()
+        self._start_task()
+
+    async def _initialize(self, /) -> None:
+        assert self.__not_running
+        self.__not_running = False
+        await self._initialize_running()
 
     @classmethod
-    async def create(cls, factory: 'DbFactory') -> 'DbConnection':
+    async def create(cls, factory: 'DbFactory', /) -> 'DbConnection':
         dbconnection = DbConnection(factory)
         await dbconnection._initialize()
         return dbconnection
 
-    async def aclose(self):
+    async def aclose(self, /) -> None:
         if not self.__task.done():
             await self.__queue.join()
             self.__task.cancel()
@@ -383,27 +472,36 @@ class DbConnection:
         del self.__loop
         self.__not_running = True
 
-    async def complete_transaction(self, delta: dict) -> None:
+    async def complete_transaction(self, delta: dict, /) -> None:
+        if not delta:
+            return
         buffer = StringIO()
         self.db2io(delta, buffer)
         self.__mmdb.update(delta)
-        future = self.__loop.create_future()
-        self.__queue.put_nowait(TransactionRequest(buffer, future))
+        future = self._create_future()
+        self.__queue.put_nowait(TransactionRequest(buffer, future=future))
         await future
-
-    async def commit(self) -> None:
-        future = self.__loop.create_future()
+
+    async def commit(self, /) -> None:
+        future = self._create_future()
         self.__queue.put_nowait(DumpRequest(future))
         await future
-
-    def transaction(self) -> 'Transaction':
+
+    def transaction(self, /) -> 'Transaction':
         return Transaction(self)
 
 
 class DbFactory:
-    def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
-        self.path = pathlib.Path(path)
-        self.kvrequest_type = kvrequest_type
+    __slots__ = (
+        'path',
+        'kvfactory',
+        'buffersize',
+        'db',
+    )
+
+    def __init__(self, path: pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576) -> None:
+        self.path = path
+        self.kvfactory = kvfactory
         self.buffersize = buffersize
 
     async def __aenter__(self) -> DbConnection:
@@ -414,10 +512,16 @@ class DbFactory:
         await self.db.aclose()
 
 
-class Db(DbFactory, DbConnection):
-    def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
-        DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize)
-        DbConnection.__init__(self, self)
+class Db(DbConnection):
+    __slots__ = ()
+
+    def __init__(self, path: str | pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576):
+        DbConnection.__init__(
+            self,
+            DbFactory(
+                pathlib.Path(path), kvfactory=kvfactory, buffersize=buffersize
+            )
+        )
 
     async def __aenter__(self) -> DbConnection:
         await self._initialize()
@@ -428,22 +532,28 @@
 
 
 class FallbackMapping:
-    def __init__(self, delta: dict, connection: DbConnection) -> None:
+    __slots__ = (
+        '__delta',
+        '__shadow',
+        '__connection',
+    )
+
+    def __init__(self, delta: dict, connection: DbConnection, /) -> None:
         self.__delta = delta
         self.__shadow = {}
         self.__connection = connection
-
-    def get(self, key: Any, default: Any):
+
+    def get(self, key: Any, default: Any, /):
         if key in self.__delta:
             return self.__delta[key]
         if key in self.__shadow:
             return self.__shadow[key]
         return self.__connection.get(key, default)
-
-    def set_nowait(self, key: Any, value: Any) -> None:
+
+    def set_nowait(self, key: Any, value: Any, /) -> None:
         self.__delta[key] = value
-
-    async def commit(self) -> None:
+
+    async def commit(self, /) -> None:
         delta = self.__delta.copy()
         self.__shadow |= delta
         self.__delta.clear()
@@ -451,16 +561,21 @@ class FallbackMapping:
 
 
 class Transaction:
-    delta: dict
+    __slots__ = (
+        '__connection',
+        '__delta',
+    )
 
-    def __init__(self, connection: DbConnection) -> None:
-        self.connection = connection
+    __delta: dict
+
+    def __init__(self, connection: DbConnection, /) -> None:
+        self.__connection = connection
 
     async def __aenter__(self) -> FallbackMapping:
-        self.delta = {}
-        return FallbackMapping(self.delta, self.connection)
+        self.__delta = {}
+        return FallbackMapping(self.__delta, self.__connection)
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         if exc_type is None:
-            await self.connection.complete_transaction(self.delta)
-        del self.delta
+            await self.__connection.complete_transaction(self.__delta)
+        del self.__delta
diff --git a/setup.py b/setup.py
index 700de63..01a1372 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@ from setuptools import setup
 
 setup(
     name='ptvp35',
-    version='1.0rc0',
+    version='1.0rc1',
     packages=['ptvp35'],
     url='https://gitea.ongoteam.net/PTV/ptvp35',
     license='',
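A minimal usage sketch of the public API after this change (illustrative only: the database file name, keys, and values below are made up, and Db is assumed to still be used as an async context manager as before). KVJson is now passed as a factory instance via the kvfactory keyword argument, replacing the former kvrequest_type class argument, and Db wraps a DbFactory internally:

import asyncio

from ptvp35 import Db, KVJson


async def main() -> None:
    # The database loads its file on enter and is closed on exit.
    async with Db('example.db', kvfactory=KVJson()) as db:
        await db.set('answer', 42)           # resolves once the write is flushed to disk
        db.set_nowait('greeting', 'hello')   # queue a write without waiting
        async with db.transaction() as txn:  # changes are collected in a FallbackMapping
            txn.set_nowait('counter', txn.get('counter', 0) + 1)
        # the transaction delta is applied together on successful exit
        await db.commit()                    # force a flush of the in-memory buffer
        print(db.get('answer', None))


asyncio.run(main())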