291 lines
8.8 KiB
Python
291 lines
8.8 KiB
Python
import asyncio
|
|
import json
|
|
import pathlib
|
|
import pickle
|
|
import shutil
|
|
import traceback
|
|
from io import StringIO
|
|
from typing import Any, Optional, IO, Type, Hashable
|
|
|
|
|
|
class Request:
    """Base class for items placed on a ``Db`` work queue.

    Carries an optional future that is resolved once the request has been
    handled; ``None`` means nobody is waiting for the outcome.
    """

    def __init__(self, future: Optional[asyncio.Future]):
        self.future = future

    def set_result(self, result):
        """Resolve the attached future with *result* if it is still pending."""
        # The waiter may already be cancelled (e.g. ``await db.set(...)`` under a
        # timeout); setting a result on a done future raises InvalidStateError.
        if self.future is not None and not self.future.done():
            self.future.set_result(result)

    def set_exception(self, exception):
        """Fail the attached future with *exception* if it is still pending."""
        if self.future is not None and not self.future.done():
            self.future.set_exception(exception)
|
|
|
|
|
|
class KVRequest(Request):
    """A key/value write request that can be encoded as a single text line.

    Concrete subclasses choose the encoding by implementing ``line`` and
    ``fromline``.
    """

    def __init__(self, key: Any, value: Any, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.key, self.value = key, value

    def free(self):
        """Return a copy of this request of the same type with no future attached."""
        cls = type(self)
        return cls(self.key, self.value, None)

    def line(self) -> str:
        """Encode this request as a newline-terminated string (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def fromline(cls, line: str) -> 'KVRequest':
        """Decode a line produced by ``line`` (subclass hook)."""
        raise NotImplementedError
|
|
|
|
|
|
class DumpRequest(Request):
    """Queue marker: write the pending buffer to the data file, then resolve the future."""
|
|
|
|
|
|
class UnkownRequestType(TypeError):
    """Raised when a queued object is not a request type ``Db`` can dispatch.

    Note: the name is misspelled ("Unkown"); it is kept unchanged because other
    code may import or catch it under this name.
    """
|
|
|
|
|
|
class KVPickle(KVRequest):
    """KV request encoded as the hex dump of its pickled, future-less copy.

    SECURITY NOTE(review): ``fromline`` unpickles whatever is in the data file;
    unpickling can execute arbitrary code, so only use this encoding with files
    you trust.
    """

    def line(self) -> str:
        payload = pickle.dumps(self.free())
        return f"{payload.hex()}\n"

    @classmethod
    def fromline(cls, line: str) -> 'KVPickle':
        raw = bytes.fromhex(line.strip())
        return pickle.loads(raw)
|
|
|
|
|
|
class KVJson(KVRequest):
    """KV request encoded as a one-line JSON object ``{"key": ..., "value": ...}``.

    JSON has no tuples, so container keys are decoded as lists/objects;
    ``_load_key`` converts them back into hashable tuples so they can be used
    as dict keys again.
    """

    def line(self) -> str:
        return json.dumps({'key': self.key, 'value': self.value}) + "\n"

    @classmethod
    def _load_key(cls, key: Any) -> Hashable:
        """Convert a decoded JSON key into a hashable value.

        Lists become tuples, and objects become tuples of ``(key, value)``
        pairs, recursively; already-hashable values are returned unchanged.

        Raises:
            TypeError: if *key* is of any other unhashable type.
        """
        if isinstance(key, Hashable):
            return key
        if isinstance(key, list):
            return tuple(map(cls._load_key, key))
        if isinstance(key, dict):
            # Must use .items(): iterating a dict yields only its keys, which
            # either fails to unpack or silently splits two-character keys.
            return tuple((cls._load_key(k), cls._load_key(v)) for k, v in key.items())
        raise TypeError("unknown KVJson key type, cannot convert to hashable")

    @classmethod
    def fromline(cls, line: str) -> 'KVJson':
        """Decode one JSON line; returns an instance of *cls* with no future."""
        d = json.loads(line)
        # cls (not KVJson) so subclasses round-trip to their own type.
        return cls(cls._load_key(d['key']), d['value'], None)
|
|
|
|
|
|
class ErrorRequest(Request):
    """Queue item holding a raw line that could not be parsed.

    ``Db`` appends ``line`` (stripped and newline-terminated) to its ``.error``
    file when it handles this request.
    """

    def __init__(self, line: str, future: Optional[asyncio.Future]):
        super().__init__(future)
        self.line = line

    async def wait(self):
        """Wait for the attached future (``future`` must not be None here)."""
        pending = self.future
        await pending
|
|
|
|
|
|
class Db:
    """Key/value store kept in memory and persisted to a line-oriented text file.

    Reads are served from an in-memory dict. Writes update the dict
    immediately and are queued; a background task serializes each one as a
    line (via ``kvrequest_type``) and appends it to the data file. When the
    file grows past twice its size after the last rewrite, it is rewritten in
    compacted form. Rewrites go through a ``.backup`` copy and a ``.recover``
    marker so that an interrupted rewrite is completed on the next open.
    Lines that fail to parse are copied to a ``.error`` file.

    Use as ``async with Db(path, kvrequest_type=KVJson) as db: ...``.
    """

    __mmdb: Optional[dict]
    __loop: Optional[asyncio.AbstractEventLoop]
    __queue: Optional[asyncio.Queue[Request]]
    __file: Optional[IO[str]]
    __buffer: Optional[StringIO]
    __task: Optional[asyncio.Future]

    def __init__(self, path: str, *, kvrequest_type: Type[KVRequest], buffersize=1048576):
        """
        Args:
            path: data file path; sibling files with ``.backup``, ``.recover``
                and ``.error`` suffixes are also used.
            kvrequest_type: KVRequest subclass defining the on-disk line format.
            buffersize: once the write buffer position reaches this value the
                buffer is written to the file immediately.
        """
        path = str(path)
        self.kvrequest_type = kvrequest_type
        self.buffersize = buffersize
        self.__path = pathlib.Path(path)
        self.__path_backup = pathlib.Path(path + '.backup')
        self.__path_recover = pathlib.Path(path + '.recover')
        self.__path_error = pathlib.Path(path + '.error')
        self.__task = None

    # ------------------------------------------------------------------ parsing

    def _queue_error(self, line: str):
        """Queue *line* to be appended to the ``.error`` file.

        ``io2db`` may run inside ``run_in_executor`` worker threads, and asyncio
        queues are not thread-safe, so off-loop calls are handed to the loop
        thread with ``call_soon_threadsafe``. No future is attached because
        nothing awaits the outcome (failures are reported by
        ``_background_cycle``).
        """
        request = ErrorRequest(line, None)
        try:
            on_loop_thread = asyncio.get_running_loop() is self.__loop
        except RuntimeError:
            # No running loop in this thread, e.g. io2db under run_in_executor.
            on_loop_thread = False
        if on_loop_thread:
            self.__queue.put_nowait(request)
        else:
            self.__loop.call_soon_threadsafe(self.__queue.put_nowait, request)

    def io2db(self, io: IO[str], db: dict) -> int:
        """Parse every line of *io* with ``kvrequest_type`` into *db*.

        Later lines overwrite earlier ones for the same key. A line that cannot
        be parsed is reported (traceback on stderr) and queued for the
        ``.error`` file instead of aborting the load.

        Returns:
            Total length of the successfully parsed lines.
        """
        size = 0
        for line in io:
            try:
                request = self.kvrequest_type.fromline(line)
            except (ValueError, KeyError, TypeError, pickle.UnpicklingError, EOFError):
                # ValueError covers json.JSONDecodeError and bad hex in
                # bytes.fromhex; KeyError/TypeError cover JSON that is not a
                # {"key": ..., "value": ...} object.
                traceback.print_exc()
                self._queue_error(line)
                continue
            db[request.key] = request.value
            size += len(line)
        return size

    def db2io(self, db: dict, io: IO[str]) -> int:
        """Write every entry of *db* to *io*, one line each; return characters written."""
        size = 0
        for key, value in db.items():
            size += io.write(self.kvrequest_type(key, value, None).line())
        return size

    # --------------------------------------------------------------- public API

    def get(self, key: Any, default: Any = None):
        """Return the in-memory value for *key*, or *default* if absent."""
        return self.__mmdb.get(key, default)

    async def set(self, key: Any, value: Any):
        """Store *value* and wait until the background task has written it out."""
        self.__mmdb[key] = value
        future = self.__loop.create_future()
        self.__queue.put_nowait(self.kvrequest_type(key, value, future))
        await future

    def set_nowait(self, key: Any, value: Any):
        """Store *value* and queue the write without waiting for it."""
        self.__mmdb[key] = value
        self.__queue.put_nowait(self.kvrequest_type(key, value, None))

    # ------------------------------------------------------------ write buffer

    async def _dump_buffer_or_request_so(self, request: KVRequest):
        """Write the buffer now if it is full, else defer via a DumpRequest.

        The deferred DumpRequest carries the caller's future, so the waiter is
        resolved when the queue reaches it; writes queued in the meantime share
        that single dump.
        """
        if self.__buffer.tell() >= self.buffersize:
            await self._dump_buffer()
            request.set_result(None)
        else:
            await self.__queue.put(DumpRequest(request.future))

    async def _write(self, line: str, request: KVRequest):
        self.__buffer.write(line)
        await self._dump_buffer_or_request_so(request)

    def _clear_buffer(self):
        self.__buffer = StringIO()

    def _compress_buffer(self) -> StringIO:
        """Return a copy of the buffer keeping only the last line per key."""
        self.__buffer.seek(0)
        bufferdb = {}
        self.io2db(self.__buffer, bufferdb)
        buffer = StringIO()
        self.db2io(bufferdb, buffer)
        return buffer

    def _append_to_file(self, text: str):
        """Write *text* to the open data file and flush it (runs in an executor)."""
        self.__file.write(text)
        # Flush so that a resolved set()/dump future means the data was handed
        # to the OS, not merely left in Python's internal file buffer.
        self.__file.flush()

    async def _dump_compressed_buffer(self):
        buffer = self._compress_buffer()
        await self.__loop.run_in_executor(None, self._append_to_file, buffer.getvalue())

    async def _do_dump_buffer(self):
        await self._dump_compressed_buffer()
        self._clear_buffer()

    async def _reload_if_oversized(self):
        """Compact the file once it exceeds twice its size after the last rewrite."""
        if self.__file.tell() > 2 * self.__initial_size:
            await self._reload()

    async def _dump_buffer(self):
        if self.__buffer.tell():
            await self._do_dump_buffer()
            await self._reload_if_oversized()

    async def _save_error(self, line: str):
        with open(self.__path_error, 'a') as file:
            await self.__loop.run_in_executor(None, file.write, line.strip() + '\n')

    # ---------------------------------------------------------- background task

    async def _handle_request(self, request: Request):
        """Dispatch one queued request by type."""
        if isinstance(request, self.kvrequest_type):
            await self._write(request.line(), request)
        elif isinstance(request, DumpRequest):
            await self._dump_buffer()
            request.set_result(None)
        elif isinstance(request, ErrorRequest):
            await self._save_error(request.line)
            request.set_result(None)
        else:
            raise UnkownRequestType(f"unknown request type: {type(request).__name__}")

    async def _background_cycle(self):
        """Take one request off the queue and handle it without letting errors escape."""
        request: Request = await self.__queue.get()
        try:
            await self._handle_request(request)
        except Exception as e:
            future = request.future
            if future is None:
                # Fire-and-forget request (set_nowait, parse errors): nobody is
                # waiting, so report the failure instead of dropping it.
                traceback.print_exc()
            elif not future.done():
                # A cancelled/done waiter would make set_exception raise
                # InvalidStateError and kill the background task.
                future.set_exception(e)
        finally:
            self.__queue.task_done()

    async def _background_task(self):
        while True:
            await self._background_cycle()

    # ------------------------------------------------------ rewrite / recovery

    async def _finish_recovery(self):
        """Copy ``.backup`` over the data file, then remove the marker and backup."""
        await self.__loop.run_in_executor(None, shutil.copy, self.__path_backup, self.__path)
        self.__path_recover.unlink()
        self.__path_backup.unlink()

    async def _rebuild_file(self, db: dict):
        """Load the data file into *db* and rewrite the file in compacted form.

        Compacted data is first written to ``.backup``; only after that file is
        closed is the ``.recover`` marker created and the backup copied over
        the data file. If the marker exists on entry, a previous rewrite was
        interrupted after its backup was complete, so it is finished first.
        """
        if self.__path_recover.exists():
            await self._finish_recovery()
        self.__path.touch()
        with open(self.__path) as file:
            await self.__loop.run_in_executor(None, self.io2db, file, db)
        with open(self.__path_backup, "w") as file:
            self.__initial_size = await self.__loop.run_in_executor(None, self.db2io, db, file)
        self.__path_recover.touch()
        await self._finish_recovery()

    async def _reload(self):
        self.__file.close()
        await self._rebuild_file({})
        self.__file = open(self.__path, "a")

    # ---------------------------------------------------------------- lifecycle

    async def _load_from_file(self):
        await self._rebuild_file(self.__mmdb)

    async def _initialize_mmdb(self):
        self.__mmdb = {}
        await self._load_from_file()
        self.__file = open(self.__path, "a")

    async def _initialize_queue(self):
        self.__queue = asyncio.Queue()
        self.__buffer = StringIO()

    async def _start_task(self):
        self.__task = self.__loop.create_task(self._background_task())

    async def _initialize(self):
        assert self.__task is None
        self.__loop = asyncio.get_running_loop()
        # The queue must exist before loading: io2db queues unparsable lines.
        await self._initialize_queue()
        await self._initialize_mmdb()
        await self._start_task()

    async def __aenter__(self):
        await self._initialize()
        return self

    async def _aclose(self):
        if not self.__task.done():
            # Let the background task drain everything already queued, then stop it.
            await self.__queue.join()
            self.__task.cancel()
        await self._dump_buffer()
        self.__file.close()
        # Snapshot: db2io iterates in a worker thread, while coroutines on the
        # loop thread could still mutate the live dict during the await.
        snapshot = dict(self.__mmdb)
        with open(self.__path_backup, "w") as file:
            self.__initial_size = await self.__loop.run_in_executor(None, self.db2io, snapshot, file)
        self.__path_recover.touch()
        await self._finish_recovery()

    def _uninitialize(self):
        if self.__file is not None:
            # Normally already closed by _aclose; close() is idempotent, and this
            # avoids leaking the handle if _aclose failed part-way.
            self.__file.close()
        self.__mmdb = None
        self.__loop = None
        self.__queue = None
        self.__file = None
        self.__buffer = None
        self.__task = None

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            await self._aclose()
        finally:
            # Always reset state so the Db can be reopened even if closing failed.
            self._uninitialize()

    def cursor(self, **kwargs):
        """Return a dict-style ``Cursor`` over this database."""
        return Cursor(self, **kwargs)
|
|
|
|
|
|
class Cursor:
    """Dict-style access to a ``Db``.

    Reads fall back to ``default`` for missing keys; writes go through
    ``Db.set_nowait`` and are not awaited.
    """

    def __init__(self, db: Db, default=None):
        self.db = db
        self.default = default

    def __getitem__(self, item):
        db = self.db
        return db.get(item, self.default)

    def __setitem__(self, key, value):
        self.db.set_nowait(key, value)
|