diff --git a/ptvp35/__init__.py b/ptvp35/__init__.py index 4efb0a1..e9aba02 100644 --- a/ptvp35/__init__.py +++ b/ptvp35/__init__.py @@ -8,10 +8,10 @@ from typing import Any, Optional, IO, Type class Request: def set_result(self, result): - pass + raise NotImplementedError def set_exception(self, exception): - pass + raise NotImplementedError class KVRequest(Request): @@ -38,7 +38,14 @@ class KVRequest(Request): class DumpRequest(Request): - pass + def __init__(self, future: Optional[asyncio.Future]): + self.future = future + + def set_result(self, result): + self.future and self.future.set_result(result) + + def set_exception(self, exception): + self.future and self.future.set_exception(exception) class UnkownRequestType(TypeError): @@ -105,15 +112,16 @@ class Db: self.__mmdb[key] = value self.__queue.put_nowait(self.kvrequest_type(key, value, None)) - async def _dump_buffer_or_request_so(self): + async def _dump_buffer_or_request_so(self, request: KVRequest): if self.__buffer.tell() >= self.buffersize: await self._dump_buffer() + request.set_result(None) else: - await self.__queue.put(DumpRequest()) + await self.__queue.put(DumpRequest(request.future)) - async def _write(self, line: str): + async def _write(self, line: str, request: KVRequest): self.__buffer.write(line) - await self._dump_buffer_or_request_so() + await self._dump_buffer_or_request_so(request) def _clear_buffer(self): self.__buffer = StringIO() @@ -145,12 +153,12 @@ class Db: async def _handle_request(self, request: Request): if isinstance(request, self.kvrequest_type): - await self._write(request.line()) + await self._write(request.line(), request) elif isinstance(request, DumpRequest): await self._dump_buffer() + request.set_result(None) else: raise UnkownRequestType - request.set_result(None) async def _background_cycle(self): request: Request = await self.__queue.get()