diff --git a/ptvp35/__init__.py b/ptvp35/__init__.py index bb623d2..908b221 100644 --- a/ptvp35/__init__.py +++ b/ptvp35/__init__.py @@ -14,15 +14,22 @@ __all__ = ('KVRequest', 'KVJson', 'DbConnection', 'DbFactory', 'Db',) class Request: def __init__(self, future: Optional[asyncio.Future]): - self.future = future + self.__future = future + + def waiting(self) -> bool: + return self.__future is not None def set_result(self, result): - if self.future: - self.future.set_result(result) + if self.__future is not None: + self.__future.set_result(result) def set_exception(self, exception): - if self.future: - self.future.set_exception(exception) + if self.__future is not None: + self.__future.set_exception(exception) + + async def wait(self): + if self.__future is not None: + await self.__future class KVRequest(Request): @@ -85,10 +92,6 @@ class ErrorRequest(Request): super().__init__(future) self.line = line - async def wait(self): - if self.future: - await self.future - class TransactionRequest(Request): def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]): @@ -105,6 +108,8 @@ class DbConnection: __queue: asyncio.Queue[Request] __file: IO[str] __buffer: StringIO + __buffer_future: asyncio.Future + __buffer_requested: bool __task: asyncio.Future __initial_size: int @@ -119,7 +124,8 @@ class DbConnection: self.__path_recover = path.with_name(name + '.recover') self.__path_error = path.with_name(name + '.error') self.__path_truncate = path.with_name(name + '.truncate') - self.not_running = True + self.__path_truncate_flag = path.with_name(name + '.truncate_flag') + self.__not_running = True def _queue_error(self, line: str): request = ErrorRequest(line, self.__loop.create_future()) @@ -157,12 +163,19 @@ class DbConnection: self.__mmdb[key] = value self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None)) + def _request_buffer(self, request: Request): + if request.waiting(): + self.__buffer_future.add_done_callback(lambda _: request.set_result(None)) + if not self.__buffer_requested: + self.__buffer_requested = True + self.__queue.put_nowait(DumpRequest(None)) + async def _dump_buffer_or_request_so(self, request: Request): if self.__buffer.tell() >= self.factory.buffersize: await self._dump_buffer() request.set_result(None) else: - self.__queue.put_nowait(DumpRequest(request.future)) + self._request_buffer(request) async def _write(self, line: str, request: Request): self.__buffer.write(line) @@ -170,6 +183,8 @@ class DbConnection: def _clear_buffer(self): self.__buffer = StringIO() + self.__buffer_future = self.__loop.create_future() + self.__buffer_requested = False def _compress_buffer(self) -> StringIO: self.__buffer.seek(0) @@ -179,12 +194,23 @@ class DbConnection: self.db2io(bufferdb, buffer) return buffer - def _file_write(self, line: str): + def _file_truncate_set(self): self.__path_truncate.write_bytes(self.__file.tell().to_bytes(16, 'little')) + self.__path_truncate_flag.touch() + + def _file_truncate_unset(self): + self.__path_truncate_flag.unlink(missing_ok=True) + self.__path_truncate.unlink(missing_ok=True) + + def _write_to_disk(self, line: str): self.__file.write(line) self.__file.flush() os.fsync(self.__file.fileno()) - self.__path_truncate.unlink(missing_ok=True) + + def _file_write(self, line: str): + self._file_truncate_set() + self._write_to_disk(line) + self._file_truncate_unset() def _dump_compressed_buffer_sync(self): self._file_write(self._compress_buffer().getvalue()) @@ -192,9 +218,13 @@ class DbConnection: async def _dump_compressed_buffer(self): await self.__loop.run_in_executor(None, self._dump_compressed_buffer_sync) + def _satisfy_buffer_future(self): + self.__buffer_future.set_result(None) + self._clear_buffer() + async def _do_dump_buffer(self): await self._dump_compressed_buffer() - self._clear_buffer() + self._satisfy_buffer_future() async def _reload_if_oversized(self): if self.__file.tell() > 2 * self.__initial_size: @@ -204,6 +234,8 @@ class DbConnection: if self.__buffer.tell(): await self._do_dump_buffer() await self._reload_if_oversized() + elif self.__buffer_requested: + self._satisfy_buffer_future() def _save_error_sync(self, line: str): with open(self.__path_error, 'a') as file: @@ -240,7 +272,7 @@ class DbConnection: def _copy_sync(self): db = {} - with open(self.__path_backup) as file: + with open(self.__path_backup, 'r') as file: self.io2db(file, db) with open(self.__path, 'w') as file: self.db2io(db, file) @@ -278,7 +310,7 @@ class DbConnection: if self.__path_recover.exists(): self._finish_recovery_sync() self.__path.touch() - with open(self.__path) as file: + with open(self.__path, 'r') as file: self.io2db(file, db) self._build_file_sync(db) @@ -287,6 +319,7 @@ class DbConnection: async def _reload(self): self.__file.close() + del self.__file await self._rebuild_file({}) self.__file = open(self.__path, "a") @@ -294,12 +327,13 @@ class DbConnection: await self._rebuild_file(self.__mmdb) def _assure_truncation(self): - if self.__path_truncate.exists(): + if self.__path_truncate_flag.exists(): pos = int.from_bytes(self.__path_truncate.read_bytes(), 'little') with open(self.__path, 'r+') as file: file.seek(pos) asyncio.run_coroutine_threadsafe(self._save_error(file.read()), self.__loop).result() file.truncate(pos) + self.__path_truncate_flag.unlink(missing_ok=True) self.__path_truncate.unlink(missing_ok=True) async def _initialize_mmdb(self): @@ -308,20 +342,20 @@ class DbConnection: await self._load_from_file() self.__file = open(self.__path, "a") - async def _initialize_queue(self): + def _initialize_queue(self): self.__queue = asyncio.Queue() - self.__buffer = StringIO() + self._clear_buffer() async def _start_task(self): self.__task = self.__loop.create_task(self._background_task()) async def _initialize(self): - assert self.not_running - self.__loop = asyncio.get_event_loop() - await self._initialize_queue() + assert self.__not_running + self.__not_running = False + self.__loop = asyncio.get_running_loop() + self._initialize_queue() await self._initialize_mmdb() await self._start_task() - self.not_running = False @classmethod async def create(cls, factory: 'DbFactory') -> 'DbConnection': @@ -334,18 +368,22 @@ class DbConnection: await self.__queue.join() self.__task.cancel() await self._dump_buffer() + if not self.__buffer_future.done(): + self.__buffer_future.cancel() self.__file.close() await self._build_file(self.__mmdb) - self.not_running = True - del self.__mmdb - del self.__loop - del self.__queue - del self.__file - del self.__buffer del self.__task + del self.__file del self.__initial_size + del self.__mmdb + del self.__buffer_requested + del self.__buffer_future + del self.__buffer + del self.__queue + del self.__loop + self.__not_running = True - async def complete_transaction(self, delta: dict): + async def complete_transaction(self, delta: dict) -> None: buffer = StringIO() self.db2io(delta, buffer) self.__mmdb.update(delta) @@ -353,7 +391,7 @@ class DbConnection: self.__queue.put_nowait(TransactionRequest(buffer, future)) await future - async def commit(self): + async def commit(self) -> None: future = self.__loop.create_future() self.__queue.put_nowait(DumpRequest(future)) await future @@ -392,13 +430,24 @@ class Db(DbFactory, DbConnection): class FallbackMapping: def __init__(self, delta: dict, connection: DbConnection) -> None: self.__delta = delta + self.__shadow = {} self.__connection = connection def get(self, key: Any, default: Any): - return self.__delta.get(key, self.__connection.get(key, default)) + if key in self.__delta: + return self.__delta[key] + if key in self.__shadow: + return self.__shadow[key] + return self.__connection.get(key, default) - def set_nowait(self, key: Any, value: Any): + def set_nowait(self, key: Any, value: Any) -> None: self.__delta[key] = value + + async def commit(self) -> None: + delta = self.__delta.copy() + self.__shadow |= delta + self.__delta.clear() + await self.__connection.complete_transaction(delta) class Transaction: diff --git a/setup.py b/setup.py index 863c09c..700de63 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup setup( name='ptvp35', - version='', + version='1.0rc0', packages=['ptvp35'], url='https://gitea.ongoteam.net/PTV/ptvp35', license='',