ptvp35/ptvp35/__init__.py
2022-11-05 03:35:50 +00:00

394 lines
12 KiB
Python

import asyncio
import json
import pathlib
import pickle
import threading
import traceback
from io import StringIO
from typing import Any, Optional, IO, Type, Hashable
# Public names exported by a star-import of this module.
__all__ = (
    'KVRequest',
    'KVJson',
    'DbConnection',
    'DbFactory',
    'Db',
)
class Request:
    """Base class for the items handled by the connection's work queue.

    A request optionally carries an ``asyncio.Future`` that is resolved once
    the request has been handled; ``None`` means nobody is waiting for it.
    """

    def __init__(self, future: Optional[asyncio.Future]):
        self.future = future

    def set_result(self, result):
        """Resolve the attached future with *result* if it is still pending.

        Does nothing when there is no future or when the future is already
        done (for example cancelled by the caller that was awaiting it);
        calling ``Future.set_result`` in that state would raise
        ``InvalidStateError`` inside the code that handles the request.
        """
        if self.future is not None and not self.future.done():
            self.future.set_result(result)

    def set_exception(self, exception):
        """Fail the attached future with *exception* if it is still pending.

        Same no-op rules as :meth:`set_result`.
        """
        if self.future is not None and not self.future.done():
            self.future.set_exception(exception)
class KVRequest(Request):
    """Request to store ``value`` under ``key``.

    Concrete subclasses provide the text serialisation by overriding
    :meth:`line` and :meth:`fromline`.
    """

    def __init__(self, key: Any, value: Any, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.key, self.value = key, value

    def free(self):
        """Return a copy of this request (same type) with no future attached."""
        request_type = type(self)
        return request_type(self.key, self.value, None)

    def line(self) -> str:
        """Serialise this request to text; not implemented in the base class."""
        raise NotImplementedError

    @classmethod
    def fromline(cls, line: str) -> 'KVRequest':
        """Parse a request from *line*; not implemented in the base class."""
        raise NotImplementedError
class DumpRequest(Request):
    """Request that makes the connection dump its write buffer."""
class UnkownRequestType(TypeError):
    """Raised when a queued request is of a type the connection cannot handle."""
class KVPickle(KVRequest):
    """Key/value request serialised as one hex-encoded pickle per line."""

    def line(self) -> str:
        """Return the hex form of the pickled future-less copy, newline-terminated."""
        payload = pickle.dumps(self.free())
        return payload.hex() + "\n"

    @classmethod
    def fromline(cls, line: str) -> 'KVPickle':
        """Unpickle and return the object hex-encoded in *line*.

        NOTE(review): ``pickle.loads`` can execute arbitrary code while
        loading, so this must only be used on database files from a trusted
        source.
        """
        encoded = line.strip()
        return pickle.loads(bytes.fromhex(encoded))
class KVJson(KVRequest):
    """Key/value request serialised as one JSON object per line.

    The object has the two members ``"key"`` and ``"value"``.
    """

    def line(self) -> str:
        """Return this request as a newline-terminated JSON object."""
        return json.dumps({'key': self.key, 'value': self.value}) + "\n"

    @classmethod
    def _load_key(cls, key: Any) -> Hashable:
        """Convert a decoded JSON value into something usable as a dict key.

        Hashable values are returned unchanged, lists become tuples and
        dicts become tuples of ``(key, value)`` pairs, all recursively.

        Raises:
            TypeError: if *key* is of any other unhashable type.
        """
        if isinstance(key, Hashable):
            return key
        elif isinstance(key, list):
            return tuple(map(cls._load_key, key))
        elif isinstance(key, dict):
            # Iterate the items: iterating the dict itself yields only its
            # keys, which cannot be unpacked into a (key, value) pair.
            return tuple((cls._load_key(k), cls._load_key(v)) for k, v in key.items())
        else:
            raise TypeError("unknown KVJson key type, cannot convert to hashable")

    @classmethod
    def fromline(cls, line: str) -> 'KVJson':
        """Parse one JSON line into a future-less ``KVJson`` request."""
        d = json.loads(line)
        return KVJson(cls._load_key(d['key']), d['value'], None)
class ErrorRequest(Request):
    """Request carrying a line of text that is to be saved as an error."""

    def __init__(self, line: str, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.line = line

    async def wait(self):
        """Await the attached future, if there is one."""
        if not self.future:
            return
        await self.future
class TransactionRequest(Request):
    """Request carrying the already serialised lines of a whole transaction."""

    def __init__(self, buffer: StringIO, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.buffer = buffer

    def line(self) -> str:
        """Return the complete text held in the buffer."""
        text = self.buffer.getvalue()
        return text
class DbConnection:
    """Connection to a line-oriented key/value file with an in-memory mirror.

    Reads are served from an in-memory dict. Writes update that dict
    immediately and are queued for a background task, which collects their
    serialised lines in a buffer and appends the (de-duplicated) buffer to
    the file. The file is rewritten from scratch on open, on close, and
    whenever it has grown to more than twice the size recorded at the last
    rewrite.

    Use :meth:`create` to obtain an initialised connection and
    :meth:`aclose` to shut it down.
    """

    __mmdb: dict
    __loop: asyncio.AbstractEventLoop
    __queue: asyncio.Queue[Request]
    __file: IO[str]
    __buffer: StringIO
    __task: asyncio.Future
    __initial_size: int

    def __init__(
            self,
            factory: 'DbFactory',
    ) -> None:
        """Remember *factory* and derive the auxiliary file paths from its path.

        No file is touched here; that happens in ``_initialize``.
        """
        self.factory = factory
        path = self.factory.path
        self.__path = pathlib.Path(path)
        self.__path_backup = pathlib.Path(path + '.backup')
        self.__path_recover = pathlib.Path(path + '.recover')
        self.__path_error = pathlib.Path(path + '.error')
        self.not_running = True

    def _queue_error(self, line: str):
        """Schedule *line* to be appended to the ``.error`` file.

        ``io2db`` also runs inside worker threads (see ``_run_in_thread``),
        and neither ``asyncio.Queue`` nor ``loop.create_task`` may be used
        from outside the loop's thread, so the actual work is handed to the
        loop with ``call_soon_threadsafe``.
        """

        def enqueue() -> None:
            request = ErrorRequest(line, self.__loop.create_future())
            self.__queue.put_nowait(request)
            # Keep a task awaiting the request's future.
            self.__loop.create_task(request.wait())

        self.__loop.call_soon_threadsafe(enqueue)

    def io2db(self, io: IO[str], db: dict) -> int:
        """Replay every line of *io* into *db*.

        Lines that fail to decode are printed with their traceback and
        queued for the ``.error`` file instead of aborting the load.

        Returns:
            The summed length of the lines that were applied.
        """
        size = 0
        for line in io:
            try:
                request = self.factory.kvrequest_type.fromline(line)
                db[request.key] = request.value
                size += len(line)
            except (json.JSONDecodeError, pickle.UnpicklingError, EOFError):
                traceback.print_exc()
                self._queue_error(line)
        return size

    def db2io(self, db: dict, io: IO[str]) -> int:
        """Write one serialised line per item of *db* to *io*.

        Returns:
            The sum of the values returned by ``io.write``.
        """
        size = 0
        for key, value in db.items():
            size += io.write(self.factory.kvrequest_type(key, value, None).line())
        return size

    def get(self, key: Any, default: Any):
        """Return the in-memory value for *key*, or *default* if absent."""
        return self.__mmdb.get(key, default)

    async def set(self, key: Any, value: Any):
        """Store *value* in memory, queue the write and wait for its future.

        The future is resolved by the background task once the buffer that
        contains this write has been dumped.
        """
        self.__mmdb[key] = value
        future = self.__loop.create_future()
        self.__queue.put_nowait(self.factory.kvrequest_type(key, value, future))
        await future

    def set_nowait(self, key: Any, value: Any):
        """Store *value* in memory and queue the write without waiting."""
        self.__mmdb[key] = value
        self.__queue.put_nowait(self.factory.kvrequest_type(key, value, None))

    async def _dump_buffer_or_request_so(self, request: Request):
        """Dump now if the buffer is full, else queue a dump for *request*.

        When the buffer has not reached ``factory.buffersize`` yet, a
        ``DumpRequest`` sharing *request*'s future is put on the queue, so
        the future is resolved by that later dump.
        """
        if self.__buffer.tell() >= self.factory.buffersize:
            await self._dump_buffer()
            request.set_result(None)
        else:
            await self.__queue.put(DumpRequest(request.future))

    async def _write(self, line: str, request: Request):
        """Append *line* to the buffer, then dump or schedule a dump."""
        self.__buffer.write(line)
        await self._dump_buffer_or_request_so(request)

    def _clear_buffer(self):
        """Replace the write buffer with a fresh, empty one."""
        self.__buffer = StringIO()

    def _compress_buffer(self) -> StringIO:
        """Return a new buffer holding one line per key of the write buffer.

        The buffered lines are replayed into a temporary dict, so only the
        last value written for each key is kept.
        """
        self.__buffer.seek(0)
        bufferdb = {}
        self.io2db(self.__buffer, bufferdb)
        buffer = StringIO()
        self.db2io(bufferdb, buffer)
        return buffer

    async def _dump_compressed_buffer(self):
        """Write the compressed buffer to the file in the default executor."""
        buffer = self._compress_buffer()
        await self.__loop.run_in_executor(None, self.__file.write, buffer.getvalue())

    async def _do_dump_buffer(self):
        """Write the buffer out and start a new empty one."""
        await self._dump_compressed_buffer()
        self._clear_buffer()

    async def _reload_if_oversized(self):
        """Rebuild the file once it exceeds twice its size at the last rebuild."""
        if self.__file.tell() > 2 * self.__initial_size:
            await self._reload()

    async def _dump_buffer(self):
        """Dump the buffer if it holds anything, then check the file size."""
        if self.__buffer.tell():
            await self._do_dump_buffer()
            await self._reload_if_oversized()

    async def _save_error(self, line: str):
        """Append the stripped *line* to the ``.error`` file."""
        with open(self.__path_error, 'a') as file:
            await self.__loop.run_in_executor(None, file.write, line.strip() + '\n')

    async def _handle_request(self, request: Request):
        """Dispatch *request* according to its type.

        Raises:
            UnkownRequestType: if *request* is of none of the handled types.
        """
        if isinstance(request, self.factory.kvrequest_type):
            await self._write(request.line(), request)
        elif isinstance(request, DumpRequest):
            await self._dump_buffer()
            request.set_result(None)
        elif isinstance(request, ErrorRequest):
            await self._save_error(request.line)
            # Resolve the future so the task awaiting ``request.wait()``
            # (created in ``_queue_error``) can finish instead of staying
            # pending forever.
            request.set_result(None)
        elif isinstance(request, TransactionRequest):
            await self._write(request.line(), request)
        else:
            raise UnkownRequestType

    async def _background_cycle(self):
        """Take one request off the queue and handle it.

        An exception raised while handling is forwarded to the request's
        future; the queue item is marked done in every case.
        """
        request: Request = await self.__queue.get()
        try:
            await self._handle_request(request)
        except Exception as e:
            request.set_exception(e)
        finally:
            self.__queue.task_done()

    async def _background_task(self):
        """Handle queued requests forever (until the task is cancelled)."""
        while True:
            await self._background_cycle()

    def _copy_sync(self):
        """Rewrite the main file from the contents of the backup file."""
        db = {}
        with open(self.__path_backup) as file:
            self.io2db(file, db)
        with open(self.__path, 'w') as file:
            self.db2io(db, file)

    def _finish_recovery_sync(self):
        """Copy backup to main file, then remove the recover marker and backup."""
        self._copy_sync()
        self.__path_recover.unlink()
        self.__path_backup.unlink()

    def _build_file_sync(self, db: dict):
        """Write *db* to the backup file and promote it to the main file.

        The size written is recorded as the new initial size. The recover
        marker is created only after the backup has been written completely.
        """
        with open(self.__path_backup, "w") as file:
            self.__initial_size = self.db2io(db, file)
        self.__path_recover.touch()
        self._finish_recovery_sync()

    def _run_in_thread(self, fn, *args, **kwargs) -> asyncio.Future:
        """Run ``fn(*args, **kwargs)`` in a new thread.

        Returns:
            A future of this connection's loop that receives the call's
            result or exception.
        """
        future = self.__loop.create_future()

        def wrap():
            try:
                result = fn(*args, **kwargs)
            except Exception as exception:
                self.__loop.call_soon_threadsafe(future.set_exception, exception)
            else:
                self.__loop.call_soon_threadsafe(future.set_result, result)

        threading.Thread(target=wrap).start()
        return future

    async def _build_file(self, db: dict):
        """Run ``_build_file_sync`` in a thread and wait for it."""
        await self._run_in_thread(self._build_file_sync, db)

    def _rebuild_file_sync(self, db: dict):
        """Load the main file into *db* and rewrite the file from *db*.

        If a recover marker exists, the interrupted promotion of the backup
        is completed first. The main file is created if it does not exist.
        """
        if self.__path_recover.exists():
            self._finish_recovery_sync()
        self.__path.touch()
        with open(self.__path) as file:
            self.io2db(file, db)
        self._build_file_sync(db)

    async def _rebuild_file(self, db: dict):
        """Run ``_rebuild_file_sync`` in a thread and wait for it."""
        await self._run_in_thread(self._rebuild_file_sync, db)

    async def _reload(self):
        """Close the file, rebuild it via a scratch dict and reopen it for append."""
        self.__file.close()
        await self._rebuild_file({})
        self.__file = open(self.__path, "a")

    async def _load_from_file(self):
        """Fill the in-memory dict from the file (rebuilding the file as well)."""
        await self._rebuild_file(self.__mmdb)

    async def _initialize_mmdb(self):
        """Create and load the in-memory dict, then open the file for append."""
        self.__mmdb = {}
        await self._load_from_file()
        self.__file = open(self.__path, "a")

    async def _initialize_queue(self):
        """Create the request queue and the empty write buffer."""
        self.__queue = asyncio.Queue()
        self.__buffer = StringIO()

    async def _start_task(self):
        """Start the background task that consumes the queue."""
        self.__task = self.__loop.create_task(self._background_task())

    async def _initialize(self):
        """Bind to the event loop, load the data and start the worker.

        The queue is created before loading because load errors are
        reported through it.
        """
        assert self.not_running
        self.__loop = asyncio.get_event_loop()
        await self._initialize_queue()
        await self._initialize_mmdb()
        await self._start_task()
        self.not_running = False

    @classmethod
    async def create(cls, factory: 'DbFactory') -> 'DbConnection':
        """Create and initialise a ``DbConnection`` for *factory*."""
        dbconnection = DbConnection(factory)
        await dbconnection._initialize()
        return dbconnection

    async def aclose(self):
        """Flush all pending work, rewrite the file and release all state.

        If the worker is still running, the queue is drained before the
        worker is cancelled. Afterwards the per-run attributes are deleted,
        so the connection cannot be used until it is initialised again.
        """
        if not self.__task.done():
            await self.__queue.join()
            self.__task.cancel()
        await self._dump_buffer()
        self.__file.close()
        await self._build_file(self.__mmdb)
        self.not_running = True
        del self.__mmdb
        del self.__loop
        del self.__queue
        del self.__file
        del self.__buffer
        del self.__task
        del self.__initial_size

    async def complete_transaction(self, delta: dict):
        """Apply *delta* to memory and queue it as a single write request.

        Waits for the request's future, like :meth:`set`.
        """
        buffer = StringIO()
        self.db2io(delta, buffer)
        self.__mmdb.update(delta)
        future = self.__loop.create_future()
        self.__queue.put_nowait(TransactionRequest(buffer, future))
        await future

    async def commit(self):
        """Queue a dump request and wait until it has been handled."""
        future = self.__loop.create_future()
        self.__queue.put_nowait(DumpRequest(future))
        await future

    def transaction(self) -> 'Transaction':
        """Return a new ``Transaction`` bound to this connection."""
        return Transaction(self)
class DbFactory:
    """Holds the connection settings and yields a connection as a context manager.

    ``async with DbFactory(...) as db`` creates a ``DbConnection`` on entry
    and closes it on exit.
    """

    def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
        # The path is always kept as a plain string.
        path = str(path)
        self.path = path
        self.kvrequest_type = kvrequest_type
        self.buffersize = buffersize

    async def __aenter__(self) -> DbConnection:
        """Create the connection, keep it as ``self.db`` and return it."""
        connection = await DbConnection.create(self)
        self.db = connection
        return connection

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the connection created on entry."""
        connection = self.db
        await connection.aclose()
class Db(DbFactory, DbConnection):
    """Factory and connection in one object: the connection is its own factory."""

    def __init__(self, path: str | pathlib.Path, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
        DbFactory.__init__(
            self,
            path,
            kvrequest_type=kvrequest_type,
            buffersize=buffersize,
        )
        # The connection part reads its settings from this same object.
        DbConnection.__init__(self, factory=self)

    async def __aenter__(self) -> DbConnection:
        """Initialise this connection and return it."""
        await self._initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close this connection."""
        await self.aclose()
class FallbackMapping:
    """View that reads from a delta dict first and from a connection second.

    Writes go to the delta dict only.
    """

    def __init__(self, delta: dict, connection: DbConnection) -> None:
        self.__delta = delta
        self.__connection = connection

    def get(self, key: Any, default: Any):
        """Return the delta's value for *key*, else the connection's, else *default*."""
        # The connection is consulted in every case, also when the delta
        # already holds the key.
        fallback = self.__connection.get(key, default)
        return self.__delta.get(key, fallback)

    def set_nowait(self, key: Any, value: Any):
        """Record *value* for *key* in the delta."""
        self.__delta[key] = value
class Transaction:
    """Async context manager that collects writes and applies them on success.

    Entering yields a ``FallbackMapping`` over a fresh delta dict; leaving
    without an exception hands that delta to the connection.
    """

    delta: dict

    def __init__(self, connection: DbConnection) -> None:
        self.connection = connection

    async def __aenter__(self) -> FallbackMapping:
        """Start with an empty delta and return the mapping that fills it."""
        self.delta = staged = {}
        return FallbackMapping(staged, self.connection)

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Complete the transaction unless the body raised, then drop the delta."""
        succeeded = exc_type is None
        if succeeded:
            await self.connection.complete_transaction(self.delta)
        del self.delta