1.0rc1: call restrictions + kvfactory + slots
This commit is contained in:
parent
85a6bc0301
commit
f52bad680c
@ -1,77 +1,93 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import pickle
|
||||
import threading
|
||||
import traceback
|
||||
from io import StringIO
|
||||
from typing import Any, Optional, IO, Type, Hashable
|
||||
from io import StringIO, UnsupportedOperation
|
||||
from typing import Any, Optional, IO, Hashable
|
||||
|
||||
|
||||
__all__ = ('KVRequest', 'KVJson', 'DbConnection', 'DbFactory', 'Db',)
|
||||
|
||||
|
||||
class Request:
|
||||
def __init__(self, future: Optional[asyncio.Future]):
|
||||
__slots__ = (
|
||||
'__future',
|
||||
)
|
||||
|
||||
def __init__(self, future: Optional[asyncio.Future], /) -> None:
|
||||
self.__future = future
|
||||
|
||||
def waiting(self) -> bool:
|
||||
def waiting(self, /) -> bool:
|
||||
return self.__future is not None
|
||||
|
||||
def set_result(self, result):
|
||||
def set_result(self, result, /) -> None:
|
||||
if self.__future is not None:
|
||||
self.__future.set_result(result)
|
||||
|
||||
def set_exception(self, exception):
|
||||
def set_exception(self, exception, /):
|
||||
if self.__future is not None:
|
||||
self.__future.set_exception(exception)
|
||||
|
||||
async def wait(self):
|
||||
async def wait(self, /):
|
||||
if self.__future is not None:
|
||||
await self.__future
|
||||
|
||||
|
||||
class KVFactory:
|
||||
__slots__ = ()
|
||||
|
||||
def line(self, key: Any, value: Any, /) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def fromline(self, line: str, /) -> 'KVRequest':
|
||||
raise NotImplementedError
|
||||
|
||||
def request(self, key: Any, value: Any, /, *, future: Optional[asyncio.Future]) -> 'KVRequest':
|
||||
return KVRequest(key, value, future=future, factory=self)
|
||||
|
||||
def free(self, key: Any, value: Any, /) -> 'KVRequest':
|
||||
return self.request(key, value, future=None)
|
||||
|
||||
|
||||
class KVRequest(Request):
|
||||
def __init__(self, key: Any, value: Any, future: Optional[asyncio.Future]):
|
||||
__slots__ = (
|
||||
'__factory',
|
||||
'key',
|
||||
'value',
|
||||
)
|
||||
|
||||
def __init__(self, key: Any, value: Any, /, *, future: Optional[asyncio.Future], factory: KVFactory):
|
||||
super().__init__(future)
|
||||
self.__factory = factory
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
def free(self):
|
||||
return type(self)(self.key, self.value, None)
|
||||
def free(self, /):
|
||||
return self.__factory.free(self.key, self.value)
|
||||
|
||||
def line(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def fromline(cls, line: str) -> 'KVRequest':
|
||||
raise NotImplementedError
|
||||
def line(self, /) -> str:
|
||||
return self.__factory.line(self.key, self.value)
|
||||
|
||||
|
||||
class DumpRequest(Request):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class UnknownRequestType(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class UnkownRequestType(TypeError):
|
||||
pass
|
||||
class KVJson(KVFactory):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class KVPickle(KVRequest):
|
||||
def line(self) -> str:
|
||||
return pickle.dumps(self.free()).hex() + "\n"
|
||||
def line(self, key: Any, value: Any, /) -> str:
|
||||
return json.dumps({'key': key, 'value': value}) + '\n'
|
||||
|
||||
@classmethod
|
||||
def fromline(cls, line: str) -> 'KVPickle':
|
||||
return pickle.loads(bytes.fromhex(line.strip()))
|
||||
|
||||
|
||||
class KVJson(KVRequest):
|
||||
def line(self) -> str:
|
||||
return json.dumps({'key': self.key, 'value': self.value}) + "\n"
|
||||
|
||||
@classmethod
|
||||
def _load_key(cls, key: Any) -> Hashable:
|
||||
def _load_key(cls, key: Any, /) -> Hashable:
|
||||
if isinstance(key, Hashable):
|
||||
return key
|
||||
elif isinstance(key, list):
|
||||
@ -79,30 +95,49 @@ class KVJson(KVRequest):
|
||||
elif isinstance(key, dict):
|
||||
return tuple((cls._load_key(k), cls._load_key(v)) for k, v in key.items())
|
||||
else:
|
||||
raise TypeError("unknown KVJson key type, cannot convert to hashable")
|
||||
raise TypeError(
|
||||
'unknown KVJson key type, cannot convert to hashable'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fromline(cls, line: str) -> 'KVJson':
|
||||
def fromline(self, line: str, /) -> 'KVRequest':
|
||||
d = json.loads(line)
|
||||
return KVJson(cls._load_key(d['key']), d['value'], None)
|
||||
|
||||
|
||||
class ErrorRequest(Request):
|
||||
def __init__(self, line: str, future: Optional[asyncio.Future]):
|
||||
super().__init__(future)
|
||||
self.line = line
|
||||
return self.free(self._load_key(d['key']), d['value'])
|
||||
|
||||
|
||||
class TransactionRequest(Request):
|
||||
def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]):
|
||||
__slots__ = (
|
||||
'buffer',
|
||||
)
|
||||
|
||||
def __init__(self, buffer: StringIO, /, *, future: Optional[asyncio.Future]):
|
||||
super().__init__(future)
|
||||
self.buffer = buffer
|
||||
|
||||
def line(self) -> str:
|
||||
def line(self, /) -> str:
|
||||
return self.buffer.getvalue()
|
||||
|
||||
|
||||
class DbConnection:
|
||||
__slots__ = (
|
||||
'__factory',
|
||||
'__path',
|
||||
'__path_backup',
|
||||
'__path_recover',
|
||||
'__path_error',
|
||||
'__path_truncate',
|
||||
'__path_truncate_flag',
|
||||
'__not_running',
|
||||
'__mmdb',
|
||||
'__loop',
|
||||
'__queue',
|
||||
'__file',
|
||||
'__buffer',
|
||||
'__buffer_future',
|
||||
'__buffer_requested',
|
||||
'__task',
|
||||
'__initial_size',
|
||||
)
|
||||
|
||||
__mmdb: dict
|
||||
__loop: asyncio.AbstractEventLoop
|
||||
__queue: asyncio.Queue[Request]
|
||||
@ -116,9 +151,10 @@ class DbConnection:
|
||||
def __init__(
|
||||
self,
|
||||
factory: 'DbFactory',
|
||||
/
|
||||
) -> None:
|
||||
self.factory = factory
|
||||
self.__path = path = self.factory.path
|
||||
self.__factory = factory
|
||||
self.__path = path = self.__factory.path
|
||||
name = self.__path.name
|
||||
self.__path_backup = path.with_name(name + '.backup')
|
||||
self.__path_recover = path.with_name(name + '.recover')
|
||||
@ -127,66 +163,73 @@ class DbConnection:
|
||||
self.__path_truncate_flag = path.with_name(name + '.truncate_flag')
|
||||
self.__not_running = True
|
||||
|
||||
def _queue_error(self, line: str):
|
||||
request = ErrorRequest(line, self.__loop.create_future())
|
||||
self.__queue.put_nowait(request)
|
||||
self.__loop.create_task(request.wait())
|
||||
def _create_future(self, /) -> asyncio.Future:
|
||||
return self.__loop.create_future()
|
||||
|
||||
def io2db(self, io: IO[str], db: dict) -> int:
|
||||
def _save_error_sync(self, line: str, /) -> None:
|
||||
with self.__path_error.open('a') as file:
|
||||
file.write(line.strip() + '\n')
|
||||
|
||||
async def _save_error(self, line: str, /) -> None:
|
||||
await self.__loop.run_in_executor(None, self._save_error_sync, line)
|
||||
|
||||
def _queue_error(self, line: str, /) -> concurrent.futures.Future:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self._save_error(line), self.__loop
|
||||
)
|
||||
|
||||
def _save_error_from_thread(self, line: str, /) -> None:
|
||||
self._queue_error(line).result()
|
||||
|
||||
def io2db(self, io: IO[str], db: dict, /) -> int:
|
||||
"""there are no guarantees about .error file if error occurs here"""
|
||||
size = 0
|
||||
for line in io:
|
||||
try:
|
||||
request = self.factory.kvrequest_type.fromline(line)
|
||||
request = self.__factory.kvfactory.fromline(line)
|
||||
db[request.key] = request.value
|
||||
size += len(line)
|
||||
except (json.JSONDecodeError, pickle.UnpicklingError, EOFError):
|
||||
except (json.JSONDecodeError, EOFError):
|
||||
traceback.print_exc()
|
||||
# this condition should never occur, but we should be able to handle this UB as best as we can
|
||||
self._queue_error(line)
|
||||
return size
|
||||
|
||||
def db2io(self, db: dict, io: IO[str]) -> int:
|
||||
def db2io(self, db: dict, io: IO[str], /) -> int:
|
||||
size = 0
|
||||
for key, value in db.items():
|
||||
size += io.write(self.factory.kvrequest_type(key, value, None).line())
|
||||
size += io.write(self.__factory.kvfactory.free(key, value).line())
|
||||
return size
|
||||
|
||||
def get(self, key: Any, default: Any):
|
||||
def _path2db_sync(self, path: pathlib.Path, db: dict, /) -> int:
|
||||
path.touch()
|
||||
with path.open('r') as file:
|
||||
return self.io2db(file, db)
|
||||
|
||||
def _db2path_sync(self, db: dict, path: pathlib.Path, /) -> int:
|
||||
with path.open('w') as file:
|
||||
return self.db2io(db, file)
|
||||
|
||||
def get(self, key: Any, default: Any, /):
|
||||
return self.__mmdb.get(key, default)
|
||||
|
||||
async def set(self, key: Any, value: Any):
|
||||
async def set(self, key: Any, value: Any, /) -> None:
|
||||
self.__mmdb[key] = value
|
||||
future = self.__loop.create_future()
|
||||
self.__queue.put_nowait(self.factory.kvrequest_type(key, value, future))
|
||||
future = self._create_future()
|
||||
self.__queue.put_nowait(
|
||||
self.__factory.kvfactory.request(key, value, future=future))
|
||||
await future
|
||||
|
||||
def set_nowait(self, key: Any, value: Any):
|
||||
def set_nowait(self, key: Any, value: Any, /) -> None:
|
||||
self.__mmdb[key] = value
|
||||
self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None))
|
||||
self.__queue.put_nowait(self.__factory.kvfactory.free(key, value))
|
||||
|
||||
def _request_buffer(self, request: Request):
|
||||
if request.waiting():
|
||||
self.__buffer_future.add_done_callback(lambda _: request.set_result(None))
|
||||
if not self.__buffer_requested:
|
||||
self.__buffer_requested = True
|
||||
self.__queue.put_nowait(DumpRequest(None))
|
||||
|
||||
async def _dump_buffer_or_request_so(self, request: Request):
|
||||
if self.__buffer.tell() >= self.factory.buffersize:
|
||||
await self._dump_buffer()
|
||||
request.set_result(None)
|
||||
else:
|
||||
self._request_buffer(request)
|
||||
|
||||
async def _write(self, line: str, request: Request):
|
||||
self.__buffer.write(line)
|
||||
await self._dump_buffer_or_request_so(request)
|
||||
|
||||
def _clear_buffer(self):
|
||||
def _clear_buffer(self, /) -> None:
|
||||
self.__buffer = StringIO()
|
||||
self.__buffer_future = self.__loop.create_future()
|
||||
self.__buffer_future = self._create_future()
|
||||
self.__buffer_requested = False
|
||||
|
||||
def _compress_buffer(self) -> StringIO:
|
||||
def _compress_buffer(self, /) -> StringIO:
|
||||
self.__buffer.seek(0)
|
||||
bufferdb = {}
|
||||
self.io2db(self.__buffer, bufferdb)
|
||||
@ -194,70 +237,105 @@ class DbConnection:
|
||||
self.db2io(bufferdb, buffer)
|
||||
return buffer
|
||||
|
||||
def _file_truncate_set(self):
|
||||
self.__path_truncate.write_bytes(self.__file.tell().to_bytes(16, 'little'))
|
||||
self.__path_truncate_flag.touch()
|
||||
def _dump_compressed_buffer_sync(self, /) -> None:
|
||||
self._file_write_sync(self._compress_buffer().getvalue())
|
||||
|
||||
def _file_truncate_unset(self):
|
||||
self.__path_truncate_flag.unlink(missing_ok=True)
|
||||
self.__path_truncate.unlink(missing_ok=True)
|
||||
|
||||
def _write_to_disk(self, line: str):
|
||||
self.__file.write(line)
|
||||
self.__file.flush()
|
||||
os.fsync(self.__file.fileno())
|
||||
|
||||
def _file_write(self, line: str):
|
||||
self._file_truncate_set()
|
||||
self._write_to_disk(line)
|
||||
self._file_truncate_unset()
|
||||
|
||||
def _dump_compressed_buffer_sync(self):
|
||||
self._file_write(self._compress_buffer().getvalue())
|
||||
|
||||
async def _dump_compressed_buffer(self):
|
||||
async def _dump_compressed_buffer(self, /) -> None:
|
||||
await self.__loop.run_in_executor(None, self._dump_compressed_buffer_sync)
|
||||
|
||||
def _satisfy_buffer_future(self):
|
||||
def _satisfy_buffer_future(self, /) -> None:
|
||||
self.__buffer_future.set_result(None)
|
||||
self._clear_buffer()
|
||||
|
||||
async def _do_dump_buffer(self):
|
||||
async def _do_dump_buffer(self, /) -> None:
|
||||
await self._dump_compressed_buffer()
|
||||
self._satisfy_buffer_future()
|
||||
|
||||
async def _reload_if_oversized(self):
|
||||
if self.__file.tell() > 2 * self.__initial_size:
|
||||
await self._reload()
|
||||
|
||||
async def _dump_buffer(self):
|
||||
async def _dump_buffer(self, /) -> None:
|
||||
if self.__buffer.tell():
|
||||
await self._do_dump_buffer()
|
||||
await self._reload_if_oversized()
|
||||
elif self.__buffer_requested:
|
||||
self._satisfy_buffer_future()
|
||||
|
||||
def _save_error_sync(self, line: str):
|
||||
with open(self.__path_error, 'a') as file:
|
||||
file.write(line.strip() + '\n')
|
||||
def _request_buffer(self, request: Request, /) -> None:
|
||||
if request.waiting():
|
||||
self.__buffer_future.exception
|
||||
self.__buffer_future.add_done_callback(
|
||||
lambda bf: request.set_exception(e) if (
|
||||
e := bf.exception()) is not None else request.set_result(None)
|
||||
)
|
||||
if not self.__buffer_requested:
|
||||
self.__buffer_requested = True
|
||||
self.__queue.put_nowait(DumpRequest(None))
|
||||
|
||||
async def _save_error(self, line: str):
|
||||
await self.__loop.run_in_executor(None, self._save_error_sync, line)
|
||||
async def _dump_buffer_or_request_so(self, request: Request, /) -> None:
|
||||
if self.__buffer.tell() >= self.__factory.buffersize:
|
||||
await self._dump_buffer()
|
||||
request.set_result(None)
|
||||
else:
|
||||
self._request_buffer(request)
|
||||
|
||||
async def _handle_request(self, request: Request):
|
||||
if isinstance(request, self.factory.kvrequest_type):
|
||||
async def _write(self, line: str, request: Request, /) -> None:
|
||||
self.__buffer.write(line)
|
||||
await self._dump_buffer_or_request_so(request)
|
||||
|
||||
def _truncation_set_sync(self, /) -> None:
|
||||
self.__path_truncate.write_bytes(
|
||||
self.__file.tell().to_bytes(16, 'little')
|
||||
)
|
||||
self.__path_truncate_flag.touch()
|
||||
|
||||
def _truncation_unset_sync(self, /) -> None:
|
||||
self.__path_truncate_flag.unlink(missing_ok=True)
|
||||
self.__path_truncate.unlink(missing_ok=True)
|
||||
|
||||
def _file_truncate_sync(self, file: IO[str], pos: int, /) -> None:
|
||||
file.seek(pos)
|
||||
self._save_error_from_thread(file.read())
|
||||
file.truncate(pos)
|
||||
|
||||
def _truncation_target_sync(self, /) -> int:
|
||||
return int.from_bytes(self.__path_truncate.read_bytes(), 'little')
|
||||
|
||||
def _truncate_sync(self, /) -> None:
|
||||
with self.__path.open('r+') as file:
|
||||
self._file_truncate_sync(file, self._truncation_target_sync())
|
||||
|
||||
def _assure_truncation_sync(self, /) -> None:
|
||||
if self.__path_truncate_flag.exists():
|
||||
self._truncate_sync()
|
||||
self._truncation_unset_sync()
|
||||
|
||||
def _write_to_disk_sync(self, line: str, /) -> None:
|
||||
self.__file.write(line)
|
||||
self.__file.flush()
|
||||
try:
|
||||
os.fsync(self.__file.fileno())
|
||||
except UnsupportedOperation:
|
||||
pass
|
||||
|
||||
def _file_write_sync(self, line: str, /) -> None:
|
||||
self._truncation_set_sync()
|
||||
self._write_to_disk_sync(line)
|
||||
self._truncation_unset_sync()
|
||||
|
||||
async def _reload_if_oversized(self, /) -> None:
|
||||
if self.__file.tell() > 2 * self.__initial_size:
|
||||
await self._reload()
|
||||
|
||||
async def _handle_request(self, request: Request, /) -> None:
|
||||
if isinstance(request, KVRequest):
|
||||
await self._write(request.line(), request)
|
||||
elif isinstance(request, DumpRequest):
|
||||
await self._dump_buffer()
|
||||
request.set_result(None)
|
||||
elif isinstance(request, ErrorRequest):
|
||||
await self._save_error(request.line)
|
||||
elif isinstance(request, TransactionRequest):
|
||||
await self._write(request.line(), request)
|
||||
else:
|
||||
raise UnkownRequestType
|
||||
raise UnknownRequestType
|
||||
|
||||
async def _background_cycle(self):
|
||||
async def _background_cycle(self, /) -> None:
|
||||
request: Request = await self.__queue.get()
|
||||
try:
|
||||
await self._handle_request(request)
|
||||
@ -266,104 +344,115 @@ class DbConnection:
|
||||
finally:
|
||||
self.__queue.task_done()
|
||||
|
||||
async def _background_task(self):
|
||||
async def _background_task(self, /) -> None:
|
||||
while True:
|
||||
await self._background_cycle()
|
||||
|
||||
def _copy_sync(self):
|
||||
db = {}
|
||||
with open(self.__path_backup, 'r') as file:
|
||||
self.io2db(file, db)
|
||||
with open(self.__path, 'w') as file:
|
||||
self.db2io(db, file)
|
||||
def _start_task(self, /) -> None:
|
||||
self.__task = self.__loop.create_task(self._background_task())
|
||||
|
||||
def _finish_recovery_sync(self):
|
||||
self._copy_sync()
|
||||
def _recovery_set_sync(self, db: dict, /) -> None:
|
||||
self.__initial_size = self._db2path_sync(db, self.__path_backup)
|
||||
self.__path_recover.touch()
|
||||
|
||||
def _recovery_unset_sync(self, /) -> None:
|
||||
self.__path_recover.unlink()
|
||||
self.__path_backup.unlink()
|
||||
|
||||
def _build_file_sync(self, db: dict):
|
||||
with open(self.__path_backup, "w") as file:
|
||||
self.__initial_size = self.db2io(db, file)
|
||||
self.__path_recover.touch()
|
||||
self._finish_recovery_sync()
|
||||
def _copy_sync(self, db: Optional[dict], /) -> None:
|
||||
if db is None:
|
||||
db = {}
|
||||
self._path2db_sync(self.__path_backup, db)
|
||||
self._db2path_sync(db, self.__path)
|
||||
|
||||
def _run_in_thread(self, fn, *args, **kwargs) -> asyncio.Future:
|
||||
future = self.__loop.create_future()
|
||||
def _finish_recovery_sync(self, db: Optional[dict], /) -> None:
|
||||
self._copy_sync(db)
|
||||
self._recovery_unset_sync()
|
||||
|
||||
def wrap():
|
||||
def _build_file_sync(self, db: dict, /) -> None:
|
||||
self._recovery_set_sync(db)
|
||||
self._finish_recovery_sync(db)
|
||||
|
||||
def _run_in_thread(self, fn, /, *args, **kwargs) -> asyncio.Future:
|
||||
"""we are using our own thread to guarantee as much of autonomy and control as possible.
|
||||
intended for heavy tasks."""
|
||||
future = self._create_future()
|
||||
|
||||
def wrap() -> None:
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
except Exception as exception:
|
||||
self.__loop.call_soon_threadsafe(future.set_exception, exception)
|
||||
self.__loop.call_soon_threadsafe(
|
||||
future.set_exception, exception
|
||||
)
|
||||
else:
|
||||
self.__loop.call_soon_threadsafe(future.set_result, result)
|
||||
self.__loop.call_soon_threadsafe(
|
||||
future.set_result, result
|
||||
)
|
||||
|
||||
threading.Thread(target=wrap).start()
|
||||
|
||||
return future
|
||||
|
||||
async def _build_file(self, db: dict):
|
||||
async def _build_file(self, db: dict, /) -> None:
|
||||
await self._run_in_thread(self._build_file_sync, db)
|
||||
|
||||
def _rebuild_file_sync(self, db: dict):
|
||||
def _rebuild_file_sync(self, db: dict, /) -> None:
|
||||
if self.__path_recover.exists():
|
||||
self._finish_recovery_sync()
|
||||
self.__path.touch()
|
||||
with open(self.__path, 'r') as file:
|
||||
self.io2db(file, db)
|
||||
self._finish_recovery_sync(None)
|
||||
self._path2db_sync(self.__path, db)
|
||||
self._build_file_sync(db)
|
||||
|
||||
async def _rebuild_file(self, db: dict):
|
||||
await self._run_in_thread(self._rebuild_file_sync, db)
|
||||
def _file_open(self, /) -> None:
|
||||
self.__file = self.__path.open('a')
|
||||
|
||||
async def _reload(self):
|
||||
def _reload_sync(self, /) -> None:
|
||||
self.__file.close()
|
||||
del self.__file
|
||||
await self._rebuild_file({})
|
||||
self.__file = open(self.__path, "a")
|
||||
self._rebuild_file_sync({})
|
||||
self._file_open()
|
||||
|
||||
async def _load_from_file(self):
|
||||
await self._rebuild_file(self.__mmdb)
|
||||
async def _reload(self, /) -> None:
|
||||
await self._run_in_thread(self._reload_sync)
|
||||
|
||||
def _assure_truncation(self):
|
||||
if self.__path_truncate_flag.exists():
|
||||
pos = int.from_bytes(self.__path_truncate.read_bytes(), 'little')
|
||||
with open(self.__path, 'r+') as file:
|
||||
file.seek(pos)
|
||||
asyncio.run_coroutine_threadsafe(self._save_error(file.read()), self.__loop).result()
|
||||
file.truncate(pos)
|
||||
self.__path_truncate_flag.unlink(missing_ok=True)
|
||||
self.__path_truncate.unlink(missing_ok=True)
|
||||
def _load_mmdb_sync(self, /) -> dict:
|
||||
db = {}
|
||||
self._rebuild_file_sync(db)
|
||||
return db
|
||||
|
||||
async def _initialize_mmdb(self):
|
||||
self.__mmdb = {}
|
||||
await self.__loop.run_in_executor(None, self._assure_truncation)
|
||||
await self._load_from_file()
|
||||
self.__file = open(self.__path, "a")
|
||||
def _initialize_mmdb_sync(self, /) -> None:
|
||||
self.__mmdb = self._load_mmdb_sync()
|
||||
|
||||
def _initialize_queue(self):
|
||||
def _load_from_file_sync(self, /) -> None:
|
||||
self._assure_truncation_sync()
|
||||
self._initialize_mmdb_sync()
|
||||
self._file_open()
|
||||
|
||||
async def _load_from_file(self, /) -> None:
|
||||
await self._run_in_thread(self._load_from_file_sync)
|
||||
|
||||
def _initialize_queue(self, /) -> None:
|
||||
self.__queue = asyncio.Queue()
|
||||
self._clear_buffer()
|
||||
|
||||
async def _start_task(self):
|
||||
self.__task = self.__loop.create_task(self._background_task())
|
||||
|
||||
async def _initialize(self):
|
||||
assert self.__not_running
|
||||
self.__not_running = False
|
||||
async def _initialize_running(self, /) -> None:
|
||||
self.__loop = asyncio.get_running_loop()
|
||||
self._initialize_queue()
|
||||
await self._initialize_mmdb()
|
||||
await self._start_task()
|
||||
await self._load_from_file()
|
||||
self._start_task()
|
||||
|
||||
async def _initialize(self, /) -> None:
|
||||
assert self.__not_running
|
||||
self.__not_running = False
|
||||
await self._initialize_running()
|
||||
|
||||
@classmethod
|
||||
async def create(cls, factory: 'DbFactory') -> 'DbConnection':
|
||||
async def create(cls, factory: 'DbFactory', /) -> 'DbConnection':
|
||||
dbconnection = DbConnection(factory)
|
||||
await dbconnection._initialize()
|
||||
return dbconnection
|
||||
|
||||
async def aclose(self):
|
||||
async def aclose(self, /) -> None:
|
||||
if not self.__task.done():
|
||||
await self.__queue.join()
|
||||
self.__task.cancel()
|
||||
@ -383,27 +472,36 @@ class DbConnection:
|
||||
del self.__loop
|
||||
self.__not_running = True
|
||||
|
||||
async def complete_transaction(self, delta: dict) -> None:
|
||||
async def complete_transaction(self, delta: dict, /) -> None:
|
||||
if not delta:
|
||||
return
|
||||
buffer = StringIO()
|
||||
self.db2io(delta, buffer)
|
||||
self.__mmdb.update(delta)
|
||||
future = self.__loop.create_future()
|
||||
self.__queue.put_nowait(TransactionRequest(buffer, future))
|
||||
future = self._create_future()
|
||||
self.__queue.put_nowait(TransactionRequest(buffer, future=future))
|
||||
await future
|
||||
|
||||
async def commit(self) -> None:
|
||||
future = self.__loop.create_future()
|
||||
async def commit(self, /) -> None:
|
||||
future = self._create_future()
|
||||
self.__queue.put_nowait(DumpRequest(future))
|
||||
await future
|
||||
|
||||
def transaction(self) -> 'Transaction':
|
||||
def transaction(self, /) -> 'Transaction':
|
||||
return Transaction(self)
|
||||
|
||||
|
||||
class DbFactory:
|
||||
def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
|
||||
self.path = pathlib.Path(path)
|
||||
self.kvrequest_type = kvrequest_type
|
||||
__slots__ = (
|
||||
'path',
|
||||
'kvfactory',
|
||||
'buffersize',
|
||||
'db',
|
||||
)
|
||||
|
||||
def __init__(self, path: pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576) -> None:
|
||||
self.path = path
|
||||
self.kvfactory = kvfactory
|
||||
self.buffersize = buffersize
|
||||
|
||||
async def __aenter__(self) -> DbConnection:
|
||||
@ -414,10 +512,16 @@ class DbFactory:
|
||||
await self.db.aclose()
|
||||
|
||||
|
||||
class Db(DbFactory, DbConnection):
|
||||
def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
|
||||
DbFactory.__init__(self, path, kvrequest_type=kvrequest_type, buffersize=buffersize)
|
||||
DbConnection.__init__(self, self)
|
||||
class Db(DbConnection):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, path: str | pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576):
|
||||
DbConnection.__init__(
|
||||
self,
|
||||
DbFactory(
|
||||
pathlib.Path(path), kvfactory=kvfactory, buffersize=buffersize
|
||||
)
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> DbConnection:
|
||||
await self._initialize()
|
||||
@ -428,22 +532,28 @@ class Db(DbFactory, DbConnection):
|
||||
|
||||
|
||||
class FallbackMapping:
|
||||
def __init__(self, delta: dict, connection: DbConnection) -> None:
|
||||
__slots__ = (
|
||||
'__delta',
|
||||
'__shadow',
|
||||
'__connection',
|
||||
)
|
||||
|
||||
def __init__(self, delta: dict, connection: DbConnection, /) -> None:
|
||||
self.__delta = delta
|
||||
self.__shadow = {}
|
||||
self.__connection = connection
|
||||
|
||||
def get(self, key: Any, default: Any):
|
||||
def get(self, key: Any, default: Any, /):
|
||||
if key in self.__delta:
|
||||
return self.__delta[key]
|
||||
if key in self.__shadow:
|
||||
return self.__shadow[key]
|
||||
return self.__connection.get(key, default)
|
||||
|
||||
def set_nowait(self, key: Any, value: Any) -> None:
|
||||
def set_nowait(self, key: Any, value: Any, /) -> None:
|
||||
self.__delta[key] = value
|
||||
|
||||
async def commit(self) -> None:
|
||||
async def commit(self, /) -> None:
|
||||
delta = self.__delta.copy()
|
||||
self.__shadow |= delta
|
||||
self.__delta.clear()
|
||||
@ -451,16 +561,21 @@ class FallbackMapping:
|
||||
|
||||
|
||||
class Transaction:
|
||||
delta: dict
|
||||
__slots__ = (
|
||||
'__connection',
|
||||
'__delta',
|
||||
)
|
||||
|
||||
def __init__(self, connection: DbConnection) -> None:
|
||||
self.connection = connection
|
||||
__delta: dict
|
||||
|
||||
def __init__(self, connection: DbConnection, /) -> None:
|
||||
self.__connection = connection
|
||||
|
||||
async def __aenter__(self) -> FallbackMapping:
|
||||
self.delta = {}
|
||||
return FallbackMapping(self.delta, self.connection)
|
||||
self.__delta = {}
|
||||
return FallbackMapping(self.__delta, self.__connection)
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
await self.connection.complete_transaction(self.delta)
|
||||
del self.delta
|
||||
await self.__connection.complete_transaction(self.__delta)
|
||||
del self.__delta
|
||||
|
Loading…
Reference in New Issue
Block a user