ptvp35/ptvp35/__init__.py

395 lines
12 KiB
Python

import asyncio
import concurrent.futures
import json
import pathlib
import pickle
import threading
import traceback
from io import StringIO
from typing import Any, Optional, IO, Type, Hashable
# Names exported by `from <this module> import *`.
__all__ = ('KVRequest', 'KVJson', 'DbConnection', 'DbFactory', 'Db',)
class Request:
    """Base class for work items handled by the connection's background task.

    A request optionally carries an ``asyncio.Future`` that is resolved once
    the request has been processed. Requests created without a future are
    fire-and-forget.
    """

    def __init__(self, future: Optional[asyncio.Future]):
        """Store the future (or ``None``) that reports the outcome."""
        self.future = future

    def set_result(self, result):
        """Resolve the attached future with *result*, if it is still pending.

        A future that is missing, already resolved, or cancelled by its
        awaiter is left untouched: calling ``Future.set_result`` on it would
        raise ``InvalidStateError`` inside whoever is processing the request.
        """
        future = self.future
        if future is not None and not future.done():
            future.set_result(result)

    def set_exception(self, exception):
        """Fail the attached future with *exception*, if it is still pending.

        Same guard as :meth:`set_result`: done or cancelled futures are
        skipped instead of raising ``InvalidStateError``.
        """
        future = self.future
        if future is not None and not future.done():
            future.set_exception(exception)
class KVRequest(Request):
    """Request to store one ``key -> value`` pair.

    Subclasses define the on-disk text format by implementing :meth:`line`
    and :meth:`fromline`.
    """

    def __init__(self, key: Any, value: Any, future: Optional[asyncio.Future]):
        """Remember the pair and hand the optional future to the base class."""
        super().__init__(future)
        self.key, self.value = key, value

    def free(self):
        """Return a copy of this request (same concrete type) with no future."""
        same_type = type(self)
        return same_type(self.key, self.value, None)

    def line(self) -> str:
        """Serialize this pair to a line of text; must be overridden."""
        raise NotImplementedError

    @classmethod
    def fromline(cls, line: str) -> 'KVRequest':
        """Parse a line produced by :meth:`line`; must be overridden."""
        raise NotImplementedError
class DumpRequest(Request):
    """Request asking the background task to dump the in-memory buffer."""
class UnkownRequestType(TypeError):
    """Raised when a queued request is of a type the handler does not know."""
class KVPickle(KVRequest):
    """Key-value request stored as a hex-encoded pickle, one per line.

    NOTE(review): ``fromline`` unpickles whatever the line contains, and
    unpickling can execute arbitrary code -- only use this format with files
    from a trusted source.
    """

    def line(self) -> str:
        """Pickle a future-less copy of this request and hex-encode it."""
        payload = pickle.dumps(self.free())
        return payload.hex() + "\n"

    @classmethod
    def fromline(cls, line: str) -> 'KVPickle':
        """Decode the hex text of *line* and unpickle the stored request."""
        raw = bytes.fromhex(line.strip())
        return pickle.loads(raw)
class KVJson(KVRequest):
    """Key-value request stored as one JSON object per line.

    Each line has the shape ``{"key": ..., "value": ...}``.
    """

    def line(self) -> str:
        """Serialize the pair as a JSON object followed by a newline."""
        return json.dumps({'key': self.key, 'value': self.value}) + "\n"

    @classmethod
    def _load_key(cls, key: Any) -> Hashable:
        """Convert a JSON-decoded key into something usable as a dict key.

        Hashable values are returned unchanged, lists become tuples
        (recursively), and dicts become tuples of ``(key, value)`` pairs with
        both members converted recursively.

        Raises:
            TypeError: if *key* is unhashable and neither a list nor a dict.
        """
        if isinstance(key, Hashable):
            return key
        if isinstance(key, list):
            return tuple(map(cls._load_key, key))
        if isinstance(key, dict):
            # Iterate the items: iterating the dict itself yields only keys,
            # which cannot be unpacked into (k, v) pairs.
            return tuple(
                (cls._load_key(k), cls._load_key(v)) for k, v in key.items()
            )
        raise TypeError("unknown KVJson key type, cannot convert to hashable")

    @classmethod
    def fromline(cls, line: str) -> 'KVJson':
        """Parse one JSON line into a future-less ``KVJson`` request."""
        d = json.loads(line)
        return KVJson(cls._load_key(d['key']), d['value'], None)
class ErrorRequest(Request):
    """Request carrying a line of text that could not be parsed."""

    def __init__(self, line: str, future: Optional[asyncio.Future]):
        """Keep the offending *line* alongside the optional future."""
        super().__init__(future)
        self.line = line

    async def wait(self):
        """Await the attached future, when there is one."""
        pending = self.future
        if pending:
            await pending
class TransactionRequest(Request):
    """Request carrying the pre-serialized lines of a whole transaction."""

    def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]):
        """Keep the text *buffer* alongside the optional future."""
        super().__init__(future)
        self.buffer = buffer

    def line(self) -> str:
        """Return everything written to the transaction buffer."""
        text = self.buffer.getvalue()
        return text
class DbConnection:
    """Asynchronous key-value store held in memory and mirrored to a text file.

    Reads are served from an in-memory dict. Writes update the dict
    immediately and are queued for a single background task, which collects
    serialized lines in a buffer and appends them to the file. The line format
    is defined by ``factory.kvrequest_type``.

    Files used, all derived from ``factory.path``:
      * ``<path>``          -- the data file
      * ``<path>.backup``   -- full snapshot written during a rebuild
      * ``<path>.recover``  -- marker: the backup is complete and must be
                               copied over the data file
      * ``<path>.error``    -- lines that could not be parsed
    """

    # Per-session state; created by `_initialize` and deleted by `aclose`.
    __mmdb: dict
    __loop: asyncio.AbstractEventLoop
    __queue: asyncio.Queue[Request]
    __file: IO[str]
    __buffer: StringIO
    __task: asyncio.Future
    __initial_size: int

    def __init__(
            self,
            factory: 'DbFactory',
    ) -> None:
        """Derive the file paths from *factory*; nothing is opened yet."""
        self.factory = factory
        path = self.factory.path
        self.__path = pathlib.Path(path)
        self.__path_backup = pathlib.Path(path + '.backup')
        self.__path_recover = pathlib.Path(path + '.recover')
        self.__path_error = pathlib.Path(path + '.error')
        self.not_running = True

    def _queue_error_now(self, line: str):
        """Queue an ``ErrorRequest`` for *line*; must run on the loop thread."""
        request = ErrorRequest(line, self.__loop.create_future())
        self.__queue.put_nowait(request)
        self.__loop.create_task(request.wait())

    def _queue_error(self, line: str):
        """Schedule an unparsable *line* to be appended to the error file.

        ``io2db`` is called both from the event loop and from the worker
        thread started by ``_run_in_thread``. The queue and task creation are
        not thread-safe, so off-loop calls are marshalled onto the loop.
        """
        try:
            running = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: we are in a worker thread.
            running = None
        if running is self.__loop:
            self._queue_error_now(line)
        else:
            self.__loop.call_soon_threadsafe(self._queue_error_now, line)

    def io2db(self, io: IO[str], db: dict) -> int:
        """Load every parsable line of *io* into *db*.

        Lines that fail to parse have their traceback printed and are queued
        for the error file; loading continues with the next line.

        Returns:
            The total length of the lines that were loaded successfully.
        """
        size = 0
        for line in io:
            try:
                request = self.factory.kvrequest_type.fromline(line)
                db[request.key] = request.value
                size += len(line)
            except (
                    json.JSONDecodeError,
                    pickle.UnpicklingError,
                    EOFError,
                    # bytes.fromhex raises plain ValueError on a torn or
                    # otherwise non-hex line.
                    ValueError,
            ):
                traceback.print_exc()
                self._queue_error(line)
        return size

    def db2io(self, db: dict, io: IO[str]) -> int:
        """Serialize every pair of *db* to *io*; return the sum of write results."""
        size = 0
        for key, value in db.items():
            size += io.write(self.factory.kvrequest_type(key, value, None).line())
        return size

    def get(self, key: Any, default: Any):
        """Return the in-memory value for *key*, or *default* if absent."""
        return self.__mmdb.get(key, default)

    async def set(self, key: Any, value: Any):
        """Store the pair in memory, queue it, and wait for its future.

        The future is resolved by the background task once the buffer holding
        this write has been dumped.
        """
        self.__mmdb[key] = value
        future = self.__loop.create_future()
        self.__queue.put_nowait(self.factory.kvrequest_type(key, value, future))
        await future

    def set_nowait(self, key: Any, value: Any):
        """Store the pair in memory and queue it without waiting."""
        self.__mmdb[key] = value
        self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None))

    async def _dump_buffer_or_request_so(self, request: Request):
        """Dump now if the buffer reached ``buffersize``, else defer.

        When deferring, a ``DumpRequest`` carrying the same future is put at
        the back of the queue, so requests already queued are buffered before
        the dump happens.
        """
        if self.__buffer.tell() >= self.factory.buffersize:
            await self._dump_buffer()
            request.set_result(None)
        else:
            await self.__queue.put(DumpRequest(request.future))

    async def _write(self, line: str, request: Request):
        """Append *line* to the buffer, then dump or schedule a dump."""
        self.__buffer.write(line)
        await self._dump_buffer_or_request_so(request)

    def _clear_buffer(self):
        """Replace the buffer with a fresh, empty one."""
        self.__buffer = StringIO()

    def _compress_buffer(self) -> StringIO:
        """Return a new buffer with one line per distinct key of the current one.

        The buffer is replayed into a dict (later writes win) and the dict is
        serialized again.
        """
        self.__buffer.seek(0)
        bufferdb = {}
        self.io2db(self.__buffer, bufferdb)
        buffer = StringIO()
        self.db2io(bufferdb, buffer)
        return buffer

    async def _dump_compressed_buffer(self):
        """Write the compressed buffer to the data file in the default executor."""
        buffer = self._compress_buffer()
        await self.__loop.run_in_executor(None, self.__file.write, buffer.getvalue())

    async def _do_dump_buffer(self):
        """Write the buffer out and start a new empty one."""
        await self._dump_compressed_buffer()
        self._clear_buffer()

    async def _reload_if_oversized(self):
        """Rebuild the file once its position exceeds twice the initial size."""
        if self.__file.tell() > 2 * self.__initial_size:
            await self._reload()

    async def _dump_buffer(self):
        """Dump the buffer if it is non-empty, then check the file size."""
        if self.__buffer.tell():
            await self._do_dump_buffer()
            await self._reload_if_oversized()

    async def _save_error(self, line: str):
        """Append the stripped *line* to the error file."""
        with open(self.__path_error, 'a') as file:
            await self.__loop.run_in_executor(None, file.write, line.strip() + '\n')

    async def _handle_request(self, request: Request):
        """Dispatch one queued request by its type.

        Raises:
            UnkownRequestType: if *request* is of none of the known types.
        """
        if isinstance(request, self.factory.kvrequest_type):
            await self._write(request.line(), request)
        elif isinstance(request, DumpRequest):
            await self._dump_buffer()
            request.set_result(None)
        elif isinstance(request, ErrorRequest):
            await self._save_error(request.line)
        elif isinstance(request, TransactionRequest):
            await self._write(request.line(), request)
        else:
            raise UnkownRequestType

    async def _background_cycle(self):
        """Take one request off the queue and handle it.

        A failure is reported through the request's future; the queue item is
        marked done in every case so ``queue.join()`` can complete.
        """
        request: Request = await self.__queue.get()
        try:
            await self._handle_request(request)
        except Exception as e:
            request.set_exception(e)
        finally:
            self.__queue.task_done()

    async def _background_task(self):
        """Process queued requests forever (until the task is cancelled)."""
        while True:
            await self._background_cycle()

    def _copy_sync(self):
        """Rewrite the data file from the contents of the backup file."""
        db = {}
        with open(self.__path_backup) as file:
            self.io2db(file, db)
        with open(self.__path, 'w') as file:
            self.db2io(db, file)

    def _finish_recovery_sync(self):
        """Copy backup over data, then remove the marker and the backup."""
        self._copy_sync()
        self.__path_recover.unlink()
        self.__path_backup.unlink()

    def _build_file_sync(self, db: dict):
        """Write *db* as a fresh data file, going through the backup file.

        The recover marker is created only after the backup has been written
        in full, so a marker on disk means the backup can be trusted.
        """
        with open(self.__path_backup, "w") as file:
            self.__initial_size = self.db2io(db, file)
        self.__path_recover.touch()
        self._finish_recovery_sync()

    def _run_in_thread(self, fn, *args, **kwargs) -> asyncio.Future:
        """Run ``fn(*args, **kwargs)`` in a new thread.

        Returns:
            A future on this connection's loop that receives the call's
            result or exception.
        """
        future = self.__loop.create_future()

        def wrap():
            try:
                result = fn(*args, **kwargs)
            except Exception as exception:
                self.__loop.call_soon_threadsafe(future.set_exception, exception)
            else:
                self.__loop.call_soon_threadsafe(future.set_result, result)

        threading.Thread(target=wrap).start()
        return future

    async def _build_file(self, db: dict):
        """Run ``_build_file_sync`` in a worker thread and wait for it."""
        await self._run_in_thread(self._build_file_sync, db)

    def _rebuild_file_sync(self, db: dict):
        """Finish any pending recovery, load the data file into *db*, rewrite it."""
        if self.__path_recover.exists():
            self._finish_recovery_sync()
        # Make sure the data file exists before it is opened for reading.
        self.__path.touch()
        with open(self.__path) as file:
            self.io2db(file, db)
        self._build_file_sync(db)

    async def _rebuild_file(self, db: dict):
        """Run ``_rebuild_file_sync`` in a worker thread and wait for it."""
        await self._run_in_thread(self._rebuild_file_sync, db)

    async def _reload(self):
        """Close the data file, rebuild it from its own contents, reopen it."""
        self.__file.close()
        await self._rebuild_file({})
        self.__file = open(self.__path, "a")

    async def _load_from_file(self):
        """Populate the in-memory dict from the data file (and rewrite the file)."""
        await self._rebuild_file(self.__mmdb)

    async def _initialize_mmdb(self):
        """Create the in-memory dict, load it, and open the file for appending."""
        self.__mmdb = {}
        await self._load_from_file()
        self.__file = open(self.__path, "a")

    async def _initialize_queue(self):
        """Create the request queue and the write buffer."""
        self.__queue = asyncio.Queue()
        self.__buffer = StringIO()

    async def _start_task(self):
        """Start the background task that processes the queue."""
        self.__task = self.__loop.create_task(self._background_task())

    async def _initialize(self):
        """Bring the connection up on the currently running event loop."""
        assert self.not_running
        # We are inside a coroutine, so a running loop always exists.
        self.__loop = asyncio.get_running_loop()
        await self._initialize_queue()
        await self._initialize_mmdb()
        await self._start_task()
        self.not_running = False

    @classmethod
    async def create(cls, factory: 'DbFactory') -> 'DbConnection':
        """Create and initialize a ``DbConnection`` for *factory*."""
        dbconnection = DbConnection(factory)
        await dbconnection._initialize()
        return dbconnection

    async def aclose(self):
        """Flush everything to disk and tear the connection down.

        If the background task is still running, the queue is drained and the
        task cancelled. The buffer is then dumped, the file closed, and the
        data file rewritten from memory. Finally all per-session attributes
        are deleted.
        """
        if not self.__task.done():
            await self.__queue.join()
            self.__task.cancel()
        await self._dump_buffer()
        self.__file.close()
        await self._build_file(self.__mmdb)
        self.not_running = True
        del self.__mmdb
        del self.__loop
        del self.__queue
        del self.__file
        del self.__buffer
        del self.__task
        del self.__initial_size

    async def complete_transaction(self, delta: dict):
        """Apply *delta* to memory and queue it as one unit; wait for its future."""
        buffer = StringIO()
        self.db2io(delta, buffer)
        self.__mmdb.update(delta)
        future = self.__loop.create_future()
        self.__queue.put_nowait(TransactionRequest(buffer, future))
        await future

    async def commit(self):
        """Queue a dump request and wait until it has been processed."""
        future = self.__loop.create_future()
        self.__queue.put_nowait(DumpRequest(future))
        await future

    def transaction(self) -> 'Transaction':
        """Return a new ``Transaction`` bound to this connection."""
        return Transaction(self)
class DbFactory:
    """Connection settings plus an async context manager yielding a connection.

    Attributes set by the constructor: ``path`` (always a ``str``),
    ``kvrequest_type`` (the line format), and ``buffersize`` (buffer size at
    which a write triggers an immediate dump).
    """

    def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
        """Record the settings; *path* is normalized to ``str``."""
        self.path = str(path)
        self.kvrequest_type = kvrequest_type
        self.buffersize = buffersize

    async def __aenter__(self) -> DbConnection:
        """Open a connection, remember it as ``self.db``, and return it."""
        connection = await DbConnection.create(self)
        self.db = connection
        return connection

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the connection opened by ``__aenter__``."""
        await self.db.aclose()
class Db(DbFactory, DbConnection):
    """Factory and connection in one object: it acts as its own factory."""

    def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
        """Initialize both bases, passing this object as the connection's factory."""
        DbFactory.__init__(
            self,
            path,
            kvrequest_type=kvrequest_type,
            buffersize=buffersize,
        )
        DbConnection.__init__(self, self)

    async def __aenter__(self) -> DbConnection:
        """Initialize the connection in place and return this object."""
        await self._initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the connection."""
        await self.aclose()
class FallbackMapping:
    """View that writes into a local delta and reads through to a connection."""

    def __init__(self, delta: dict, connection: DbConnection) -> None:
        """Keep references to the *delta* dict and the backing *connection*."""
        self.__delta = delta
        self.__connection = connection

    def get(self, key: Any, default: Any):
        """Return the delta's value for *key*, else the connection's, else *default*."""
        # The connection is always consulted first; its answer serves as the
        # fallback for the delta lookup.
        fallback = self.__connection.get(key, default)
        return self.__delta.get(key, fallback)

    def set_nowait(self, key: Any, value: Any):
        """Record the pair in the delta only; the connection is not touched."""
        self.__delta[key] = value
class Transaction:
    """Async context manager that batches writes and applies them on success."""

    # Pending changes; exists only between `__aenter__` and `__aexit__`.
    delta: dict

    def __init__(self, connection: DbConnection) -> None:
        """Bind the transaction to *connection*."""
        self.connection = connection

    async def __aenter__(self) -> FallbackMapping:
        """Start with an empty delta and return a mapping that writes into it."""
        pending: dict = {}
        self.delta = pending
        return FallbackMapping(pending, self.connection)

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Hand the delta to the connection unless the body raised; then drop it."""
        body_succeeded = exc_type is None
        if body_succeeded:
            await self.connection.complete_transaction(self.delta)
        del self.delta