# ptvp35/ptvp35/__init__.py
import asyncio
import json
import pathlib
import pickle
import shutil
import traceback
from io import StringIO
from typing import Any, Optional, IO, Type, Hashable
class Request:
    """Base unit of work handed to the Db background queue.

    An optional future lets the submitter await completion; when the
    future is absent, outcomes are silently dropped.
    """

    def __init__(self, future: Optional[asyncio.Future]):
        self.future = future

    def set_result(self, result):
        """Resolve the attached future, if there is one."""
        if self.future is not None:
            self.future.set_result(result)

    def set_exception(self, exception):
        """Fail the attached future, if there is one."""
        if self.future is not None:
            self.future.set_exception(exception)
class KVRequest(Request):
    """A key/value write destined for the on-disk log.

    Concrete subclasses define the wire format through the
    line()/fromline() pair.
    """

    def __init__(self, key: Any, value: Any, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.key = key
        self.value = value

    def free(self):
        """Return a copy of this request detached from its future."""
        return type(self)(self.key, self.value, None)

    def line(self) -> str:
        """Serialize this request as one newline-terminated line."""
        raise NotImplementedError

    @classmethod
    def fromline(cls, line: str) -> 'KVRequest':
        """Parse a single log line back into a request."""
        raise NotImplementedError
class DumpRequest(Request):
    """Request that the background task flush its write buffer to disk."""
    pass
class UnkownRequestType(TypeError):
    """Raised when the background task receives a request it cannot handle.

    NOTE(review): the class name has a typo ("Unkown"); kept as-is since
    renaming would break external code that catches it.
    """
    pass
class KVPickle(KVRequest):
    """KVRequest serialized as one hex-encoded pickle per line."""

    def line(self) -> str:
        """Serialize this request (detached from its future) as a hex line."""
        payload = pickle.dumps(self.free())
        return payload.hex() + "\n"

    @classmethod
    def fromline(cls, line: str) -> 'KVPickle':
        """Decode a hex line back into a request.

        NOTE: pickle.loads executes arbitrary code — the log file must be
        a trusted, locally produced file.
        """
        raw = bytes.fromhex(line.strip())
        return pickle.loads(raw)
class KVJson(KVRequest):
    """KVRequest serialized as one JSON object per line.

    JSON round-trips tuples as lists (and cannot represent tuple keys
    directly), so fromline() converts decoded list/dict keys back into
    hashable tuples via _load_key().
    """

    def line(self) -> str:
        """Serialize this request as a single JSON line."""
        return json.dumps({'key': self.key, 'value': self.value}) + "\n"

    @classmethod
    def _load_key(cls, key: Any) -> Hashable:
        """Recursively convert a decoded JSON key into a hashable value.

        :raises TypeError: if the key cannot be made hashable.
        """
        if isinstance(key, Hashable):
            return key
        elif isinstance(key, list):
            return tuple(map(cls._load_key, key))
        elif isinstance(key, dict):
            # BUG FIX: iterate key.items() — iterating the dict directly
            # yields only keys, and the (k, v) unpacking then raises at
            # runtime for any real mapping.
            return tuple((cls._load_key(k), cls._load_key(v)) for k, v in key.items())
        else:
            raise TypeError("unknown KVJson key type, cannot convert to hashable")

    @classmethod
    def fromline(cls, line: str) -> 'KVJson':
        """Parse a JSON log line back into a request with no future."""
        d = json.loads(line)
        # Use cls(...) rather than the hard-coded class so subclasses
        # round-trip to their own type.
        return cls(cls._load_key(d['key']), d['value'], None)
class ErrorRequest(Request):
    """Carries an unparseable log line so it can be appended to the .error file."""

    def __init__(self, line: str, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.line = line

    async def wait(self):
        """Block until the background task has handled this request."""
        await self.future
class Db:
    """Append-only key/value database with a full in-memory mirror.

    All mutations are serialized through an asyncio queue drained by a
    single background task, which appends serialized requests (one per
    line) to the log file.  When the log grows beyond twice its size
    after the last rebuild, it is compacted from the in-memory state; a
    .backup/.recover marker pair makes that rewrite crash-recoverable.
    Lines that fail to parse are preserved in a .error sidecar file.
    Intended for use as an async context manager.
    """

    # Name-mangled state, all populated by _initialize() and cleared by
    # _uninitialize():
    __mmdb: Optional[dict]                       # in-memory mirror of the data
    __loop: Optional[asyncio.AbstractEventLoop]
    __queue: Optional[asyncio.Queue[Request]]    # pending requests
    __file: Optional[IO[str]]                    # append handle on the live log
    __buffer: Optional[StringIO]                 # lines not yet flushed to disk
    __task: Optional[asyncio.Future]             # background consumer task

    def __init__(self, path: str, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
        """
        :param path: base path of the database file.
        :param kvrequest_type: KVRequest subclass defining the line format.
        :param buffersize: bytes buffered in memory before a disk flush.
        """
        path = str(path)
        self.kvrequest_type = kvrequest_type
        self.buffersize = buffersize
        self.__path = pathlib.Path(path)
        self.__path_backup = pathlib.Path(path + '.backup')
        self.__path_recover = pathlib.Path(path + '.recover')
        self.__path_error = pathlib.Path(path + '.error')
        self.__task = None

    def _queue_error(self, line: str):
        # Hand the bad line to the background task; the spawned task keeps
        # the request alive until it has been written to the .error file.
        request = ErrorRequest(line, self.__loop.create_future())
        self.__queue.put_nowait(request)
        self.__loop.create_task(request.wait())

    def io2db(self, io: IO[str], db: dict) -> int:
        """Replay serialized lines from *io* into *db*; return bytes consumed.

        Undecodable lines are reported on stderr and queued for the
        .error file instead of aborting the load.
        """
        size = 0
        for line in io:
            try:
                request = self.kvrequest_type.fromline(line)
                db[request.key] = request.value
                size += len(line)
            except (json.JSONDecodeError, pickle.UnpicklingError, EOFError):
                traceback.print_exc()
                self._queue_error(line)
        return size

    def db2io(self, db: dict, io: IO[str]) -> int:
        """Serialize *db* into *io*, one line per entry; return bytes written."""
        size = 0
        for key, value in db.items():
            size += io.write(self.kvrequest_type(key, value, None).line())
        return size

    def get(self, key: Any, default: Any = None):
        """Return the value for *key* from the in-memory mirror.

        Backward-compatible generalization: *default* now defaults to None.
        """
        return self.__mmdb.get(key, default)

    async def set(self, key: Any, value: Any):
        """Store key=value and wait until it has been flushed to disk."""
        self.__mmdb[key] = value
        future = self.__loop.create_future()
        self.__queue.put_nowait(self.kvrequest_type(key, value, future))
        await future

    def set_nowait(self, key: Any, value: Any):
        """Store key=value without waiting for durability."""
        self.__mmdb[key] = value
        self.__queue.put_nowait(self.kvrequest_type(key, value, None))

    async def _dump_buffer_or_request_so(self, request: KVRequest):
        # Flush now if the buffer is full; otherwise defer the caller's
        # acknowledgement by re-queueing its future behind a DumpRequest,
        # so it resolves when the next flush actually happens.
        if self.__buffer.tell() >= self.buffersize:
            await self._dump_buffer()
            request.set_result(None)
        else:
            await self.__queue.put(DumpRequest(request.future))

    async def _write(self, line: str, request: KVRequest):
        # Buffer the serialized line, then decide whether to flush.
        self.__buffer.write(line)
        await self._dump_buffer_or_request_so(request)

    def _clear_buffer(self):
        self.__buffer = StringIO()

    def _compress_buffer(self) -> StringIO:
        """Deduplicate buffered writes, keeping only the last value per key."""
        self.__buffer.seek(0)
        bufferdb = {}
        self.io2db(self.__buffer, bufferdb)
        buffer = StringIO()
        self.db2io(bufferdb, buffer)
        return buffer

    async def _dump_compressed_buffer(self):
        # Disk write goes through the executor to keep the loop responsive.
        buffer = self._compress_buffer()
        await self.__loop.run_in_executor(None, self.__file.write, buffer.getvalue())

    async def _do_dump_buffer(self):
        await self._dump_compressed_buffer()
        self._clear_buffer()

    async def _reload_if_oversized(self):
        # Compact once the log exceeds twice its size after the last rebuild.
        if self.__file.tell() > 2 * self.__initial_size:
            await self._reload()

    async def _dump_buffer(self):
        if self.__buffer.tell():
            await self._do_dump_buffer()
        await self._reload_if_oversized()

    async def _save_error(self, line: str):
        # Append the unparseable line to the .error sidecar file.
        with open(self.__path_error, 'a') as file:
            await self.__loop.run_in_executor(None, file.write, line.strip() + '\n')

    async def _handle_request(self, request: Request):
        """Dispatch one queued request.

        :raises UnkownRequestType: for request types this Db cannot handle.
        """
        if isinstance(request, self.kvrequest_type):
            await self._write(request.line(), request)
        elif isinstance(request, DumpRequest):
            await self._dump_buffer()
            request.set_result(None)
        elif isinstance(request, ErrorRequest):
            await self._save_error(request.line)
            # BUG FIX: resolve the future so the task spawned by
            # _queue_error() completes instead of pending forever (which
            # produced "Task was destroyed but it is pending" at shutdown).
            request.set_result(None)
        else:
            raise UnkownRequestType

    async def _background_cycle(self):
        request: Request = await self.__queue.get()
        try:
            await self._handle_request(request)
        except Exception as e:
            # Surface the failure to the waiting caller rather than
            # killing the background task.
            request.set_exception(e)
        finally:
            self.__queue.task_done()

    async def _background_task(self):
        while True:
            await self._background_cycle()

    async def _finish_recovery(self):
        # The backup is complete: promote it to the live file, then drop
        # the recover marker and the backup itself.
        await self.__loop.run_in_executor(None, shutil.copy, self.__path_backup, self.__path)
        self.__path_recover.unlink()
        self.__path_backup.unlink()

    async def _rebuild_file(self, db: dict):
        """Load the log into *db*, then rewrite it compacted (crash-safe)."""
        if self.__path_recover.exists():
            # A previous rebuild was interrupted after the backup was fully
            # written; finish promoting it before reading anything.
            await self._finish_recovery()
        self.__path.touch()
        with open(self.__path) as file:
            await self.__loop.run_in_executor(None, self.io2db, file, db)
        with open(self.__path_backup, "w") as file:
            self.__initial_size = await self.__loop.run_in_executor(None, self.db2io, db, file)
        self.__path_recover.touch()
        await self._finish_recovery()

    async def _reload(self):
        # Compact the on-disk log and reopen it for appending.
        self.__file.close()
        await self._rebuild_file({})
        self.__file = open(self.__path, "a")

    async def _load_from_file(self):
        await self._rebuild_file(self.__mmdb)

    async def _initialize_mmdb(self):
        self.__mmdb = {}
        await self._load_from_file()
        self.__file = open(self.__path, "a")

    async def _initialize_queue(self):
        self.__queue = asyncio.Queue()
        self.__buffer = StringIO()

    async def _start_task(self):
        self.__task = self.__loop.create_task(self._background_task())

    async def _initialize(self):
        assert self.__task is None
        self.__loop = asyncio.get_event_loop()
        await self._initialize_queue()
        await self._initialize_mmdb()
        await self._start_task()

    async def __aenter__(self):
        await self._initialize()
        return self

    async def _aclose(self):
        # Drain pending requests only if the consumer is still alive —
        # joining a queue whose consumer has died would hang forever.
        if not self.__task.done():
            await self.__queue.join()
        self.__task.cancel()
        await self._dump_buffer()
        self.__file.close()
        # Snapshot the full in-memory state to disk via the crash-safe
        # backup/recover dance.
        with open(self.__path_backup, "w") as file:
            self.__initial_size = await self.__loop.run_in_executor(None, self.db2io, self.__mmdb, file)
        self.__path_recover.touch()
        await self._finish_recovery()

    def _uninitialize(self):
        self.__mmdb = None
        self.__loop = None
        self.__queue = None
        self.__file = None
        self.__buffer = None
        self.__task = None

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._aclose()
        self._uninitialize()

    def cursor(self, **kwargs):
        """Return a Cursor view over this database."""
        return Cursor(self, **kwargs)
class Cursor:
    """Dict-like synchronous facade over a Db.

    Reads come from the in-memory mirror; writes are fire-and-forget
    (via set_nowait), so they do not wait for durability.
    """

    def __init__(self, db: Db, default=None):
        self.db = db
        self.default = default

    def __getitem__(self, item):
        return self.db.get(item, self.default)

    def __setitem__(self, key, value):
        self.db.set_nowait(key, value)