# ptvp35/ptvp35/__init__.py
#
# (source listing metadata: 582 lines, 18 KiB, Python)

import asyncio
import concurrent.futures
import json
import os
import pathlib
import threading
import traceback
from io import StringIO, UnsupportedOperation
from typing import Any, Optional, IO, Hashable
# public API of the package
__all__ = (
    'KVRequest',
    'KVJson',
    'DbConnection',
    'DbFactory',
    'Db',
)
class Request:
    """Base class for items placed on a DbConnection's background queue.

    Wraps an optional asyncio future through which the background task reports
    completion back to whoever submitted the request. A request built with
    ``None`` is fire-and-forget: every completion call is then a no-op.
    """
    __slots__ = (
        '__future',
    )

    def __init__(self, future: Optional[asyncio.Future], /) -> None:
        self.__future = future

    def waiting(self, /) -> bool:
        """Return True if a future is attached to this request."""
        return self.__future is not None

    def set_result(self, result, /) -> None:
        """Resolve the attached future with *result*.

        No-op when there is no future or it is already done. The already-done
        case happens when the awaiting task was cancelled: setting a result on
        a cancelled future raises InvalidStateError, which would otherwise
        escape into (and kill) the background task.
        """
        if self.__future is not None and not self.__future.done():
            self.__future.set_result(result)

    def set_exception(self, exception, /) -> None:
        """Fail the attached future with *exception*.

        No-op when there is no future or it is already done (see set_result).
        """
        if self.__future is not None and not self.__future.done():
            self.__future.set_exception(exception)

    async def wait(self, /) -> None:
        """Wait for the attached future, if any; re-raises its exception."""
        if self.__future is not None:
            await self.__future
class KVFactory:
    """Abstract (de)serializer turning key/value pairs into text lines and back.

    Subclasses must provide :meth:`line` and :meth:`fromline`; the request
    helpers build :class:`KVRequest` objects bound to this factory.
    """
    __slots__ = ()

    def line(self, key: Any, value: Any, /) -> str:
        """Serialize *key*/*value* into one line of text, including its terminator."""
        raise NotImplementedError

    def fromline(self, line: str, /) -> 'KVRequest':
        """Parse a line produced by :meth:`line` back into a request."""
        raise NotImplementedError

    def request(self, key: Any, value: Any, /, *, future: Optional[asyncio.Future]) -> 'KVRequest':
        """Build a write request for *key*/*value*, optionally tied to *future*."""
        built = KVRequest(key, value, future=future, factory=self)
        return built

    def free(self, key: Any, value: Any, /) -> 'KVRequest':
        """Build a fire-and-forget write request (no future attached)."""
        no_future = None
        return self.request(key, value, future=no_future)
class KVRequest(Request):
    """Queue request carrying a single key/value write, serialized by its factory."""
    __slots__ = (
        '__factory',
        'key',
        'value',
    )

    def __init__(self, key: Any, value: Any, /, *, future: Optional[asyncio.Future], factory: KVFactory):
        super().__init__(future)
        self.__factory = factory
        self.key, self.value = key, value

    def free(self, /):
        """Return an equivalent request with no future, built by the same factory."""
        factory = self.__factory
        return factory.free(self.key, self.value)

    def line(self, /) -> str:
        """Serialize this request's key/value pair with its factory."""
        factory = self.__factory
        return factory.line(self.key, self.value)
class DumpRequest(Request):
    """Queue request asking the background task to flush the write buffer to disk."""
    __slots__ = ()
class UnknownRequestType(TypeError):
    """Raised by the background task when a queued object is not a recognized request kind."""
class KVJson(KVFactory):
    """KVFactory storing each record as a JSON object ``{"key": ..., "value": ...}`` on its own line.

    JSON arrays and objects used as keys are turned back into (nested) tuples
    on load so they remain usable as dict keys.
    """
    __slots__ = ()

    def line(self, key: Any, value: Any, /) -> str:
        """Serialize *key*/*value* as one newline-terminated JSON line."""
        record = {'key': key, 'value': value}
        return f'{json.dumps(record)}\n'

    @classmethod
    def _load_key(cls, key: Any, /) -> Hashable:
        """Convert a decoded JSON key into a hashable value.

        Lists become tuples of converted items; dicts become tuples of
        converted ``(key, value)`` pairs.

        Raises:
            TypeError: if the value is of any other unhashable type.
        """
        if isinstance(key, Hashable):
            return key
        if isinstance(key, list):
            return tuple(cls._load_key(item) for item in key)
        if isinstance(key, dict):
            return tuple((cls._load_key(k), cls._load_key(v)) for k, v in key.items())
        raise TypeError('unknown KVJson key type, cannot convert to hashable')

    def fromline(self, line: str, /) -> 'KVRequest':
        """Decode one JSON line into a fire-and-forget request."""
        record = json.loads(line)
        key = self._load_key(record['key'])
        return self.free(key, record['value'])
class TransactionRequest(Request):
    """Queue request carrying an already-serialized batch of records.

    The whole batch is handed to the write buffer as one piece of text.
    """
    __slots__ = ('buffer',)

    def __init__(self, buffer: StringIO, /, *, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.buffer = buffer

    def line(self, /) -> str:
        """Return the full serialized batch."""
        text = self.buffer.getvalue()
        return text
class DbConnection:
    """Persistent key-value store: an in-memory dict mirrored to an append-only line file.

    Reads (:meth:`get`) are served from the in-memory dict. Writes update that
    dict immediately and are queued for a single background task. The task
    collects them in a StringIO buffer and appends a compressed copy of the
    buffer (last value per key) to the file at ``factory.path``. Once the file
    exceeds twice the size it had after the last rewrite, it is rewritten from
    its own contents.

    Auxiliary files live next to the database, named by appending a suffix:

    * ``.backup`` / ``.recover``: used while rewriting the main file. A
      ``.recover`` marker found at load time means a rewrite was interrupted,
      and the main file is restored from ``.backup``.
    * ``.truncate`` / ``.truncate_flag``: the file position recorded before
      each append. A flag found at load time means an append was interrupted,
      and the tail past that position is cut off.
    * ``.error``: lines that failed to parse, and tails removed by truncation.

    Use :meth:`create` to obtain a running connection and :meth:`aclose` to
    shut it down.
    """
    __slots__ = (
        '__factory',
        '__path',
        '__path_backup',
        '__path_recover',
        '__path_error',
        '__path_truncate',
        '__path_truncate_flag',
        '__not_running',
        '__mmdb',
        '__loop',
        '__queue',
        '__file',
        '__buffer',
        '__buffer_future',
        '__buffer_requested',
        '__task',
        '__initial_size',
    )
    # the attributes below only exist while the connection is running
    __mmdb: dict  # in-memory mirror of the database
    __loop: asyncio.AbstractEventLoop
    __queue: asyncio.Queue[Request]  # requests consumed by the background task
    __file: IO[str]  # main database file, opened for appending
    __buffer: StringIO  # serialized writes not yet flushed to the file
    __buffer_future: asyncio.Future  # resolved when the current buffer is flushed
    __buffer_requested: bool  # whether a DumpRequest for the current buffer is queued
    __task: asyncio.Future  # background task running _background_task
    __initial_size: int  # size written by the last rewrite; drives _reload_if_oversized

    def __init__(
            self,
            factory: 'DbFactory',
            /
    ) -> None:
        self.__factory = factory
        self.__path = path = self.__factory.path
        name = self.__path.name
        self.__path_backup = path.with_name(name + '.backup')
        self.__path_recover = path.with_name(name + '.recover')
        self.__path_error = path.with_name(name + '.error')
        self.__path_truncate = path.with_name(name + '.truncate')
        self.__path_truncate_flag = path.with_name(name + '.truncate_flag')
        self.__not_running = True

    def _create_future(self, /) -> asyncio.Future:
        return self.__loop.create_future()

    def _save_error_sync(self, line: str, /) -> None:
        """Append *line* (stripped, newline-terminated) to the ``.error`` file."""
        with self.__path_error.open('a') as file:
            file.write(line.strip() + '\n')

    async def _save_error(self, line: str, /) -> None:
        await self.__loop.run_in_executor(None, self._save_error_sync, line)

    def _queue_error(self, line: str, /) -> concurrent.futures.Future:
        """Schedule :meth:`_save_error` on the event loop; callable from worker threads."""
        return asyncio.run_coroutine_threadsafe(
            self._save_error(line), self.__loop
        )

    def _save_error_from_thread(self, line: str, /) -> None:
        # Blocks until the loop has written the line, so this must not run on
        # the loop thread itself (it would wait on work that can never start).
        self._queue_error(line).result()

    def io2db(self, io: IO[str], db: dict, /) -> int:
        """Parse every line of *io* with the kvfactory and store the records in *db*.

        Later lines overwrite earlier ones for the same key. Returns the total
        length (in characters) of the successfully parsed lines.

        Lines that fail to decode are reported with a traceback and queued for
        the ``.error`` file. There are no guarantees about the ``.error`` file
        if an error occurs here.
        """
        size = 0
        for line in io:
            try:
                request = self.__factory.kvfactory.fromline(line)
                db[request.key] = request.value
                size += len(line)
            except (json.JSONDecodeError, EOFError):
                traceback.print_exc()
                # this condition should never occur, but we should be able to handle this UB as best as we can
                self._queue_error(line)
        return size

    def db2io(self, db: dict, io: IO[str], /) -> int:
        """Serialize every item of *db* as one line into *io*; return the number of characters written."""
        size = 0
        for key, value in db.items():
            size += io.write(self.__factory.kvfactory.free(key, value).line())
        return size

    def _path2db_sync(self, path: pathlib.Path, db: dict, /) -> int:
        path.touch()  # a missing file is treated as an empty database
        with path.open('r') as file:
            return self.io2db(file, db)

    def _db2path_sync(self, db: dict, path: pathlib.Path, /) -> int:
        with path.open('w') as file:
            return self.db2io(db, file)

    def get(self, key: Any, default: Any, /):
        """Return the in-memory value for *key*, or *default* if absent."""
        return self.__mmdb.get(key, default)

    async def set(self, key: Any, value: Any, /) -> None:
        """Update *key* in memory now and wait until the background task has written it out.

        Re-raises any error the background task hit while handling the write.
        """
        self.__mmdb[key] = value
        future = self._create_future()
        self.__queue.put_nowait(
            self.__factory.kvfactory.request(key, value, future=future))
        await future

    def set_nowait(self, key: Any, value: Any, /) -> None:
        """Update *key* in memory and queue the write without waiting for it."""
        self.__mmdb[key] = value
        self.__queue.put_nowait(self.__factory.kvfactory.free(key, value))

    def _clear_buffer(self, /) -> None:
        """Start a fresh write buffer with a new flush future and no pending flush request."""
        self.__buffer = StringIO()
        self.__buffer_future = self._create_future()
        self.__buffer_requested = False

    def _compress_buffer(self, /) -> StringIO:
        """Return a new StringIO holding only the last value per key found in the buffer."""
        self.__buffer.seek(0)
        bufferdb = {}
        self.io2db(self.__buffer, bufferdb)
        buffer = StringIO()
        self.db2io(bufferdb, buffer)
        return buffer

    def _dump_compressed_buffer_sync(self, /) -> None:
        self._file_write_sync(self._compress_buffer().getvalue())

    async def _dump_compressed_buffer(self, /) -> None:
        await self.__loop.run_in_executor(None, self._dump_compressed_buffer_sync)

    def _satisfy_buffer_future(self, /) -> None:
        """Resolve the current buffer's flush future and start a new buffer."""
        self.__buffer_future.set_result(None)
        self._clear_buffer()

    async def _do_dump_buffer(self, /) -> None:
        await self._dump_compressed_buffer()
        self._satisfy_buffer_future()

    async def _dump_buffer(self, /) -> None:
        """Flush a non-empty buffer (rewriting the file if it grew too large).

        For an empty buffer, only resolve a pending flush request.
        """
        if self.__buffer.tell():
            await self._do_dump_buffer()
            await self._reload_if_oversized()
        elif self.__buffer_requested:
            self._satisfy_buffer_future()

    def _request_buffer(self, request: Request, /) -> None:
        """Tie *request* to the next flush of the buffer and make sure one is queued."""
        if request.waiting():
            def propagate(buffer_future: asyncio.Future, /) -> None:
                # exception() itself raises CancelledError on a cancelled future
                # (aclose cancels an unused buffer future), so check that first
                # and pass the cancellation on instead of leaving the waiter hanging.
                if buffer_future.cancelled():
                    request.set_exception(asyncio.CancelledError())
                elif (exception := buffer_future.exception()) is not None:
                    request.set_exception(exception)
                else:
                    request.set_result(None)

            self.__buffer_future.add_done_callback(propagate)
        if not self.__buffer_requested:
            self.__buffer_requested = True
            self.__queue.put_nowait(DumpRequest(None))

    async def _dump_buffer_or_request_so(self, request: Request, /) -> None:
        """Flush now once the buffer reached ``buffersize``; otherwise defer *request* to the next flush."""
        if self.__buffer.tell() >= self.__factory.buffersize:
            await self._dump_buffer()
            request.set_result(None)
        else:
            self._request_buffer(request)

    async def _write(self, line: str, request: Request, /) -> None:
        self.__buffer.write(line)
        await self._dump_buffer_or_request_so(request)

    def _truncation_set_sync(self, /) -> None:
        """Record the current file position, then raise the truncation flag (before an append)."""
        self.__path_truncate.write_bytes(
            self.__file.tell().to_bytes(16, 'little')
        )
        self.__path_truncate_flag.touch()

    def _truncation_unset_sync(self, /) -> None:
        # the flag goes first so a crash in between never leaves a flag without its position file
        self.__path_truncate_flag.unlink(missing_ok=True)
        self.__path_truncate.unlink(missing_ok=True)

    def _file_truncate_sync(self, file: IO[str], pos: int, /) -> None:
        """Move everything after *pos* in *file* to the ``.error`` file and cut the file at *pos*."""
        file.seek(pos)
        self._save_error_from_thread(file.read())
        file.truncate(pos)

    def _truncation_target_sync(self, /) -> int:
        return int.from_bytes(self.__path_truncate.read_bytes(), 'little')

    def _truncate_sync(self, /) -> None:
        with self.__path.open('r+') as file:
            self._file_truncate_sync(file, self._truncation_target_sync())

    def _assure_truncation_sync(self, /) -> None:
        """Roll back an interrupted append if the truncation flag is still present."""
        if self.__path_truncate_flag.exists():
            self._truncate_sync()
        self._truncation_unset_sync()

    def _write_to_disk_sync(self, line: str, /) -> None:
        self.__file.write(line)
        self.__file.flush()
        try:
            os.fsync(self.__file.fileno())
        except UnsupportedOperation:
            pass

    def _file_write_sync(self, line: str, /) -> None:
        """Append *line* between truncation markers so an interrupted write can be rolled back on load."""
        self._truncation_set_sync()
        self._write_to_disk_sync(line)
        self._truncation_unset_sync()

    async def _reload_if_oversized(self, /) -> None:
        """Rewrite the file once it exceeds twice the size written by the last rewrite."""
        if self.__file.tell() > 2 * self.__initial_size:
            await self._reload()

    async def _handle_request(self, request: Request, /) -> None:
        """Dispatch one queued request by type."""
        if isinstance(request, (KVRequest, TransactionRequest)):
            # both kinds carry serialized line(s) destined for the buffer
            await self._write(request.line(), request)
        elif isinstance(request, DumpRequest):
            await self._dump_buffer()
            request.set_result(None)
        else:
            raise UnknownRequestType

    async def _background_cycle(self, /) -> None:
        """Take one request off the queue and process it, reporting failures to its waiter."""
        request: Request = await self.__queue.get()
        try:
            await self._handle_request(request)
        except Exception as e:
            request.set_exception(e)
        finally:
            self.__queue.task_done()

    async def _background_task(self, /) -> None:
        while True:
            await self._background_cycle()

    def _start_task(self, /) -> None:
        self.__task = self.__loop.create_task(self._background_task())

    def _recovery_set_sync(self, db: dict, /) -> None:
        """Write *db* to ``.backup`` (remembering its size), then create the ``.recover`` marker."""
        self.__initial_size = self._db2path_sync(db, self.__path_backup)
        self.__path_recover.touch()

    def _recovery_unset_sync(self, /) -> None:
        self.__path_recover.unlink()
        self.__path_backup.unlink()

    def _copy_sync(self, db: Optional[dict], /) -> None:
        """Load ``.backup`` into *db* (a fresh dict when None) and overwrite the main file with it."""
        if db is None:
            db = {}
        self._path2db_sync(self.__path_backup, db)
        self._db2path_sync(db, self.__path)

    def _finish_recovery_sync(self, db: Optional[dict], /) -> None:
        self._copy_sync(db)
        self._recovery_unset_sync()

    def _build_file_sync(self, db: dict, /) -> None:
        """Rewrite the main file from *db* through the backup/recover protocol."""
        self._recovery_set_sync(db)
        self._finish_recovery_sync(db)

    def _run_in_thread(self, fn, /, *args, **kwargs) -> asyncio.Future:
        """Run ``fn(*args, **kwargs)`` in a new dedicated thread; return a loop future for its outcome.

        We are using our own thread to guarantee as much autonomy and control
        as possible. Intended for heavy tasks.
        """
        future = self._create_future()

        def resolve(setter, value, /) -> None:
            # The awaiting coroutine may have been cancelled while the thread ran.
            # Resolving a done future would raise InvalidStateError in the loop.
            if not future.done():
                setter(value)

        def wrap() -> None:
            try:
                result = fn(*args, **kwargs)
            except Exception as exception:
                self.__loop.call_soon_threadsafe(
                    resolve, future.set_exception, exception
                )
            else:
                self.__loop.call_soon_threadsafe(
                    resolve, future.set_result, result
                )

        threading.Thread(target=wrap).start()
        return future

    async def _build_file(self, db: dict, /) -> None:
        await self._run_in_thread(self._build_file_sync, db)

    def _rebuild_file_sync(self, db: dict, /) -> None:
        """Finish any interrupted rewrite, load the main file into *db*, then rewrite it compactly."""
        if self.__path_recover.exists():
            self._finish_recovery_sync(None)
        self._path2db_sync(self.__path, db)
        self._build_file_sync(db)

    def _file_open(self, /) -> None:
        self.__file = self.__path.open('a')

    def _reload_sync(self, /) -> None:
        """Close the file, rewrite it from its own contents, and reopen it."""
        self.__file.close()
        del self.__file
        self._rebuild_file_sync({})
        self._file_open()

    async def _reload(self, /) -> None:
        await self._run_in_thread(self._reload_sync)

    def _load_mmdb_sync(self, /) -> dict:
        db = {}
        self._rebuild_file_sync(db)
        return db

    def _initialize_mmdb_sync(self, /) -> None:
        self.__mmdb = self._load_mmdb_sync()

    def _load_from_file_sync(self, /) -> None:
        """Repair the file if needed, load it into memory, and open it for appending."""
        self._assure_truncation_sync()
        self._initialize_mmdb_sync()
        self._file_open()

    async def _load_from_file(self, /) -> None:
        await self._run_in_thread(self._load_from_file_sync)

    def _initialize_queue(self, /) -> None:
        self.__queue = asyncio.Queue()
        self._clear_buffer()

    async def _initialize_running(self, /) -> None:
        self.__loop = asyncio.get_running_loop()
        self._initialize_queue()
        await self._load_from_file()
        self._start_task()

    async def _initialize(self, /) -> None:
        assert self.__not_running
        self.__not_running = False
        await self._initialize_running()

    @classmethod
    async def create(cls, factory: 'DbFactory', /) -> 'DbConnection':
        """Open, load and start a connection described by *factory*."""
        dbconnection = DbConnection(factory)
        await dbconnection._initialize()
        return dbconnection

    async def aclose(self, /) -> None:
        """Drain pending requests, stop the background task, flush, and rewrite the file from memory.

        Afterwards the per-run state is released and the connection is marked
        as not running.
        """
        if not self.__task.done():
            await self.__queue.join()
            self.__task.cancel()
        await self._dump_buffer()
        if not self.__buffer_future.done():
            self.__buffer_future.cancel()
        self.__file.close()
        await self._build_file(self.__mmdb)
        del self.__task
        del self.__file
        del self.__initial_size
        del self.__mmdb
        del self.__buffer_requested
        del self.__buffer_future
        del self.__buffer
        del self.__queue
        del self.__loop
        self.__not_running = True

    async def complete_transaction(self, delta: dict, /) -> None:
        """Apply *delta* to memory at once and wait until it is written out as one batch."""
        if not delta:
            return
        buffer = StringIO()
        self.db2io(delta, buffer)
        self.__mmdb.update(delta)
        future = self._create_future()
        self.__queue.put_nowait(TransactionRequest(buffer, future=future))
        await future

    async def commit(self, /) -> None:
        """Wait until every write queued so far has been flushed to the file."""
        future = self._create_future()
        self.__queue.put_nowait(DumpRequest(future))
        await future

    def transaction(self, /) -> 'Transaction':
        """Return an async context manager that batches writes into one transaction."""
        return Transaction(self)
class DbFactory:
    """Connection settings plus an async context manager that opens and closes a DbConnection.

    The connection opened by ``async with`` is also kept on the ``db`` attribute.
    """
    __slots__ = (
        'path',
        'kvfactory',
        'buffersize',
        'db',
    )

    def __init__(self, path: pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576) -> None:
        self.path, self.kvfactory, self.buffersize = path, kvfactory, buffersize

    async def __aenter__(self) -> DbConnection:
        connection = await DbConnection.create(self)
        self.db = connection
        return connection

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        connection = self.db
        await connection.aclose()
class Db(DbConnection):
    """DbConnection that builds its own DbFactory and is itself the async context manager."""
    __slots__ = ()

    def __init__(self, path: str | pathlib.Path, /, *, kvfactory: KVFactory, buffersize=1048576):
        factory = DbFactory(pathlib.Path(path), kvfactory=kvfactory, buffersize=buffersize)
        super().__init__(factory)

    async def __aenter__(self) -> DbConnection:
        await self._initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()
class FallbackMapping:
    """Read/write view used inside a transaction.

    Lookups consult, in order: pending writes, writes already committed
    through this view, then the underlying connection.
    """
    __slots__ = (
        '__delta',
        '__shadow',
        '__connection',
    )

    def __init__(self, delta: dict, connection: DbConnection, /) -> None:
        self.__delta = delta
        self.__shadow = {}
        self.__connection = connection

    def get(self, key: Any, default: Any, /):
        """Return the most recent value visible in this view, or *default*."""
        for layer in (self.__delta, self.__shadow):
            if key in layer:
                return layer[key]
        return self.__connection.get(key, default)

    def set_nowait(self, key: Any, value: Any, /) -> None:
        """Stage a write; nothing reaches the connection until commit."""
        self.__delta[key] = value

    async def commit(self, /) -> None:
        """Move pending writes into the shadow layer and hand them to the connection as one batch."""
        pending = dict(self.__delta)
        self.__shadow.update(pending)
        self.__delta.clear()
        await self.__connection.complete_transaction(pending)
class Transaction:
    """Async context manager collecting writes and completing them as one batch on clean exit.

    If the body raises, the collected writes are dropped.
    """
    __slots__ = (
        '__connection',
        '__delta',
    )
    __delta: dict  # writes collected during the current ``async with``

    def __init__(self, connection: DbConnection, /) -> None:
        self.__connection = connection

    async def __aenter__(self) -> FallbackMapping:
        delta = {}
        self.__delta = delta
        return FallbackMapping(delta, self.__connection)

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        delta = self.__delta
        if exc_type is None:
            await self.__connection.complete_transaction(delta)
        del self.__delta