From 0e6d58201eed117fdd53e2783e68906e1184e963 Mon Sep 17 00:00:00 2001 From: timotheyca Date: Wed, 22 Dec 2021 19:03:33 +0300 Subject: [PATCH] RoleRequest --- setup.py | 4 +- v6d0auth/app.py | 146 ++++++++++++++++++++++++-------- v6d0auth/cdb.py | 178 +++++++++++++++++++++++++++++++++------ v6d0auth/client.py | 34 +++++++- v6d0auth/remove-role.py | 28 ++++++ v6d0auth/run-server.py | 14 ++- v6d0auth/sign-request.py | 12 ++- v6d0auth/test-request.py | 10 ++- 8 files changed, 349 insertions(+), 77 deletions(-) create mode 100644 v6d0auth/remove-role.py diff --git a/setup.py b/setup.py index 2c2f56a..a905dfc 100644 --- a/setup.py +++ b/setup.py @@ -10,8 +10,8 @@ setup( author_email='', description='', install_requires=[ - 'setuptools~=57.0.0', 'aiohttp', - 'PyNaCl~=1.4.0' + 'PyNaCl~=1.4.0', + 'ptvp35 @ git+https://gitea.ongoteam.net/PTV/ptvp35.git@25727aabd7afd69f66051c806190480302e67260' ] ) diff --git a/v6d0auth/app.py b/v6d0auth/app.py index 8d19575..c73b6ab 100644 --- a/v6d0auth/app.py +++ b/v6d0auth/app.py @@ -1,76 +1,148 @@ import asyncio +import json +from typing import Optional from aiohttp import web, http_websocket -from nacl.exceptions import BadSignatureError from nacl.signing import VerifyKey +from nacl.utils import random -from v6d0auth import certs, cdb +from v6d0auth import certs __all__ = ('V6D0AuthAppFactory',) from v6d0auth.appfactory import AppFactory +from v6d0auth.cdb import CDB, Role, AbstractRequest class V6D0AuthAppFactory(AppFactory): - def __init__(self, loop: asyncio.AbstractEventLoop): - self.loop = loop + def __init__(self, cdb: CDB): + self.cdb = cdb def define_routes(self, routes: web.RouteTableDef): print(certs.vkey.encode().hex()) - mycdb = cdb.CDB(self.loop) - self.loop.create_task(mycdb.job()) + self.cdb.start() @routes.get('/') async def home(_request: web.Request): return web.Response(body='v6d0auth\n') - @routes.post('/approve') + async def ws_approve(ws: web.WebSocketResponse): + nonce = random(16) + await ws.send_bytes(nonce) + hhandle, hnonce = json.loads(certs.verify(await ws.receive_bytes())) + assert hnonce == nonce.hex() + approved = self.cdb.approve(bytes.fromhex(hhandle)) + await ws.send_bytes(approved) + + @routes.get('/approve') async def approve(request: web.Request): - try: - cert = mycdb.approve(await request.read()) - except BadSignatureError: - raise web.HTTPUnauthorized - except KeyError: - raise web.HTTPNotFound + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws_approve(ws) + return ws + + async def requester_for_request(request: web.Request) -> VerifyKey: + return VerifyKey(await request.read()) + + def role_for_request(request: web.Request) -> Optional[str]: + return request.headers.get('v6role') + + def pushed_for_role(requester: VerifyKey, role: Optional[str]) -> AbstractRequest: + if role is None: + return self.cdb.push_requester(requester) else: - return web.Response(body=cert) + return self.cdb.push_role(Role(requester, role)) + + async def pushed_for_request(request: web.Request) -> AbstractRequest: + return pushed_for_role(await requester_for_request(request), role_for_request(request)) @routes.post('/push') async def push(request: web.Request): - try: - timeout = mycdb.push(VerifyKey(await request.read())).timeout - except KeyError: - raise web.HTTPTooManyRequests + pushed = await pushed_for_request(request) + timeout = pushed.timeout + return web.Response(body=str(timeout)) + + def pulled_for_role(requester: VerifyKey, role: Optional[str]) -> Optional[bytes]: + if role is None: + return self.cdb.pull_requester(requester) else: - return web.Response(body=str(timeout)) + return self.cdb.pull_role(Role(requester, role)) + + async def pulled_for_request(request: web.Request) -> Optional[bytes]: + return pulled_for_role(await requester_for_request(request), role_for_request(request)) @routes.post('/pull') async def pull(request: web.Request): try: - cert = mycdb.pull(VerifyKey(await request.read())) + pulled = await pulled_for_request(request) except KeyError: raise web.HTTPNotFound else: - return web.Response(body=cert) + return web.Response(body=pulled) + + @routes.post('/has_role') + async def has_role(request: web.Request): + role = role_for_request(request) + if role is None: + raise web.HTTPBadRequest + return web.Response( + body=(b'1' if self.cdb.has_role(Role(await requester_for_request(request), role)) else b'') + ) + + async def ws_remove(ws: web.WebSocketResponse): + nonce = random(16) + await ws.send_bytes(nonce) + [hrequester, role], hnonce = json.loads(certs.verify(await ws.receive_bytes())) + assert hnonce == nonce.hex() + self.cdb.remove_role(Role(VerifyKey(bytes.fromhex(hrequester)), role)) + await ws.send_bytes(b'0') + + @routes.get('/remove_role') + async def remove_role(request: web.Request): + ws = web.WebSocketResponse() + await ws.prepare(request) + await ws_remove(ws) + return ws + + def srq_for_role(requester: VerifyKey, role: Optional[str]) -> AbstractRequest: + if role is None: + return self.cdb.requester_mapping[requester] + else: + return self.cdb.role_mapping[Role(requester, role)] + + async def srq_for_ws(request: web.Request, ws: web.WebSocketResponse) -> AbstractRequest: + return srq_for_role(VerifyKey(await ws.receive_bytes()), role_for_request(request)) + + async def sqr_fail(ws: web.WebSocketResponse, srq: AbstractRequest) -> None: + srq.force_repair() + await ws.close(code=http_websocket.WSCloseCode.TRY_AGAIN_LATER) + + async def srq_success(ws: web.WebSocketResponse, approved: bytes) -> None: + await ws.send_bytes(approved) + await ws.close() + + async def srq_process(ws: web.WebSocketResponse, srq: AbstractRequest) -> None: + try: + approved = await srq.awaitable() + except asyncio.CancelledError: + await sqr_fail(ws, srq) + else: + await srq_success(ws, approved) + + async def ws_fail(ws: web.WebSocketResponse) -> None: + await ws.close(code=http_websocket.WSCloseCode.POLICY_VIOLATION) + + async def ws_process(request: web.Request, ws: web.WebSocketResponse) -> None: + try: + srq = await srq_for_ws(request, ws) + except TypeError: + await ws_fail(ws) + else: + await srq_process(ws, srq) @routes.get('/pullws') async def pullws(request: web.Request): ws = web.WebSocketResponse() await ws.prepare(request) - try: - srq = mycdb.requester_mapping[VerifyKey(await ws.receive_bytes())] - except TypeError: - await ws.close(code=http_websocket.WSCloseCode.POLICY_VIOLATION) - else: - try: - cert = await srq.future - except asyncio.CancelledError: - if not srq.future.cancelled(): - srq.future.cancel() - if not srq.cancelled: - srq.future = asyncio.get_event_loop().create_future() - await ws.close(code=http_websocket.WSCloseCode.TRY_AGAIN_LATER) - else: - await ws.send_bytes(cert) - await ws.close() + await ws_process(request, ws) return ws diff --git a/v6d0auth/cdb.py b/v6d0auth/cdb.py index 96a0d0f..218649c 100644 --- a/v6d0auth/cdb.py +++ b/v6d0auth/cdb.py @@ -3,27 +3,30 @@ import functools import heapq import time import weakref -from typing import MutableMapping, Optional +from typing import MutableMapping, Optional, Hashable +from nacl.exceptions import BadSignatureError from nacl.signing import VerifyKey from nacl.utils import random +from ptvp35 import Db, KVJson from v6d0auth import certs +from v6d0auth.config import myroot -__all__ = ('CDB',) +__all__ = ('CDB', 'Role', 'AbstractRequest',) TIMEOUT = 300 @functools.total_ordering -class SignatureRequest: - def __init__(self, requester: VerifyKey, loop: asyncio.AbstractEventLoop): - self._requester = requester +class AbstractRequest: + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop self.timeout = time.time() + TIMEOUT self.handle: bytes = random(12) - self.approved: Optional[bytes] = None + self._approved: Optional[bytes] = None self.cancelled = False - self.future: asyncio.Future[bytes] = loop.create_future() + self.future: asyncio.Future[bytes] = self._loop.create_future() def __le__(self, other): if isinstance(other, SignatureRequest): @@ -34,25 +37,130 @@ class SignatureRequest: def timed_out(self) -> bool: return time.time() > self.timeout + def _approve(self) -> bytes: + raise NotImplementedError + + def _validate(self) -> None: + raise NotImplementedError + + def validate(self) -> None: + assert self._approved is not None + self._validate() + + def valid(self) -> bool: + try: + self.validate() + return True + except (AssertionError, ValueError, BadSignatureError): + return False + def approve(self) -> bytes: - if self.approved is None: - self.approved = certs.sign(bytes(self._requester)) - self.future.set_result(self.approved) - print('approved', self.handle.hex()) - return self.approved + approved = self.approved() + if approved is not None: + return approved + self._approved = self._approve() + self.future.set_result(self._approved) + print('validating', self.handle.hex()) + self.validate() + print('approved', self.handle.hex()) + return self._approved + + def repair(self): + if self.future.done(): + self.future: asyncio.Future[bytes] = self._loop.create_future() + print('repaired', self.handle.hex(), self.display()) + + def force_repair(self): + if not self.future.done(): + self.future.cancel() + self.repair() + + def approved(self) -> Optional[bytes]: + if self.valid(): + return self._approved + else: + self.repair() + return None + + async def awaitable(self) -> bytes: + approved = self.approved() + if approved is None: + approved = await self.future + return approved def cancel(self): if not self.future.done(): self.future.cancel() self.cancelled = True + def display(self) -> str: + raise NotImplementedError + + +class SignatureRequest(AbstractRequest): + def __init__(self, loop: asyncio.AbstractEventLoop, requester: VerifyKey): + super().__init__(loop) + self._requester = requester + + def _approve(self) -> bytes: + return certs.sign(bytes(self._requester)) + + def _validate(self) -> None: + assert certs.verify(self._approved) == bytes(self._requester) + + def display(self) -> str: + return self._requester.encode().hex() + + +class Role(Hashable): + def __init__(self, requester: VerifyKey, role: str): + self._requester = requester + self._role = role + + def key(self): + return self._requester.encode().hex(), self._role + + def __hash__(self): + return hash(self.key()) + + def __eq__(self, other): + if isinstance(other, Role): + return self.key() == other.key() + else: + return NotImplemented + + def display(self): + return f'{self._requester.encode().hex()}@{self._role}' + + +class RoleRequest(AbstractRequest): + def __init__(self, loop: asyncio.AbstractEventLoop, rdb: Db, role: Role): + super().__init__(loop) + self._rdb = rdb + self._role = role + + def _approve(self) -> bytes: + self._rdb.set_nowait(self._role.key(), True) + return b'1' + + def _validate(self) -> None: + assert self._rdb.get(self._role.key(), False) + + def display(self) -> str: + return self._role.display() + + +_rdbfile = myroot / 'roles.db' + class CDB: def __init__(self, loop: asyncio.AbstractEventLoop): - self.handle_mapping: MutableMapping[bytes, SignatureRequest] = weakref.WeakValueDictionary() + self.handle_mapping: MutableMapping[bytes, AbstractRequest] = weakref.WeakValueDictionary() self.requester_mapping: MutableMapping[VerifyKey, SignatureRequest] = weakref.WeakValueDictionary() - self.heap: list[SignatureRequest] = [] + self.role_mapping: MutableMapping[Role, RoleRequest] = weakref.WeakValueDictionary() + self.heap: list[AbstractRequest] = [] self._loop = loop + self.rdb = Db(_rdbfile, kvrequest_type=KVJson) def _cleanup(self): while self.heap and self.heap[0].timed_out(): @@ -66,27 +174,49 @@ class CDB: for request in self._cleanup(): print('cleaned', request.handle.hex()) - def push(self, requester: VerifyKey) -> SignatureRequest: - if requester in self.requester_mapping: - raise KeyError - request = SignatureRequest(requester, self._loop) - self.requester_mapping[requester] = request + def push_abstract(self, request: AbstractRequest): heapq.heappush(self.heap, request) self.handle_mapping[request.handle] = request - print('requested', request.handle.hex(), requester.encode().hex()) + print('requested', request.handle.hex(), request.display()) + + def push_requester(self, requester: VerifyKey) -> SignatureRequest: + if requester in self.requester_mapping: + return self.requester_mapping[requester] + request = SignatureRequest(self._loop, requester) + self.requester_mapping[requester] = request + self.push_abstract(request) + return request + + def push_role(self, role: Role) -> RoleRequest: + if role in self.role_mapping: + return self.role_mapping[role] + request = RoleRequest(self._loop, self.rdb, role) + self.role_mapping[role] = request + self.push_abstract(request) return request def _approve(self, handle: bytes) -> bytes: return self.handle_mapping[handle].approve() - def approve(self, data: bytes) -> bytes: - handle = certs.verify(data) + def approve(self, handle: bytes) -> bytes: return self._approve(handle) - def pull(self, vkey: VerifyKey) -> Optional[bytes]: - return self.requester_mapping[vkey].approved + def pull_requester(self, vkey: VerifyKey) -> Optional[bytes]: + return self.requester_mapping[vkey].approved() + + def pull_role(self, role: Role) -> Optional[bytes]: + return self.role_mapping[role].approved() + + def has_role(self, role: Role) -> bool: + return self.rdb.get(role.key(), False) + + def remove_role(self, role: Role): + return self.rdb.set_nowait(role.key(), False) async def job(self): while True: await asyncio.sleep(TIMEOUT) self.cleanup() + + def start(self): + self._loop.create_task(self.job()) diff --git a/v6d0auth/client.py b/v6d0auth/client.py index 951e163..4d37fc2 100644 --- a/v6d0auth/client.py +++ b/v6d0auth/client.py @@ -1,11 +1,11 @@ import aiohttp from nacl.exceptions import BadSignatureError +from nacl.signing import VerifyKey from v6d0auth import certs -from v6d0auth.certs import averify from v6d0auth.config import myroot, caurl -__all__ = ('request_signature', 'mycert') +__all__ = ('request_signature', 'mycert', 'has_role', 'request_role', 'with_role',) async def request_signature() -> bytes: @@ -18,7 +18,7 @@ async def request_signature() -> bytes: try: return await ws.receive_bytes() except TypeError: - raise TimeoutError + raise RuntimeError("signature request failed") _certfile = myroot / 'cert' @@ -27,8 +27,34 @@ _certfile = myroot / 'cert' async def mycert() -> bytes: try: cert = _certfile.read_bytes() - averify(cert) + certs.averify(cert) except (FileNotFoundError, BadSignatureError): cert = await request_signature() _certfile.write_bytes(cert) return cert + + +async def has_role(vkey: VerifyKey, role: str): + async with aiohttp.ClientSession() as session: + async with session.post(f'{caurl}/has_role', data=vkey.encode(), headers={'v6role': role}) as response: + return (await response.read()) == b'1' + + +async def request_role(role: str) -> bytes: + async with aiohttp.ClientSession() as session: + async with session.post(f'{caurl}/push', data=certs.vkey.encode(), headers={'v6role': role}) as response: + if response.status not in [200, 429]: + raise RuntimeError(response.status) + async with session.ws_connect(f'{caurl}/pullws', headers={'v6role': role}) as ws: + await ws.send_bytes(certs.vkey.encode()) + try: + return await ws.receive_bytes() + except TypeError: + raise RuntimeError("role request failed") + + +async def with_role(role: str): + if not await has_role(certs.vkey, role): + await request_role(role) + if not await has_role(certs.vkey, role): + raise RuntimeError("role request failed") diff --git a/v6d0auth/remove-role.py b/v6d0auth/remove-role.py new file mode 100644 index 0000000..b5b24c5 --- /dev/null +++ b/v6d0auth/remove-role.py @@ -0,0 +1,28 @@ +import argparse +import asyncio +import json + +import aiohttp + +from v6d0auth import certs +from v6d0auth.config import host, port + +parser = argparse.ArgumentParser() +parser.add_argument('requester', type=str) +parser.add_argument('role', type=str) + + +async def main(): + requester = bytes.fromhex(args.requester) + role = args.role + async with aiohttp.ClientSession() as session: + # noinspection HttpUrlsUsage + async with session.ws_connect(f'http://{host}:{port}/remove_role') as ws: + nonce = await ws.receive_bytes() + await ws.send_bytes(certs.sign(json.dumps([[requester.hex(), role], nonce.hex()]).encode())) + print((await ws.receive_bytes()).hex()) + + +if __name__ == '__main__': + args = parser.parse_args() + asyncio.run(main()) diff --git a/v6d0auth/run-server.py b/v6d0auth/run-server.py index 29b293d..15eb216 100644 --- a/v6d0auth/run-server.py +++ b/v6d0auth/run-server.py @@ -1,8 +1,18 @@ import asyncio from v6d0auth.app import V6D0AuthAppFactory +from v6d0auth.cdb import CDB from v6d0auth.run_app import run_app + +async def main(): + cdb = CDB(asyncio.get_running_loop()) + async with cdb.rdb: + await run_app(V6D0AuthAppFactory(cdb).app()) + + if __name__ == '__main__': - loop = asyncio.get_event_loop() - loop.run_until_complete(run_app(V6D0AuthAppFactory(loop).app())) + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/v6d0auth/sign-request.py b/v6d0auth/sign-request.py index 5ce366e..c6a7709 100644 --- a/v6d0auth/sign-request.py +++ b/v6d0auth/sign-request.py @@ -1,5 +1,6 @@ import argparse import asyncio +import json import aiohttp @@ -12,15 +13,12 @@ parser.add_argument('handle', type=str) async def main(): handle = bytes.fromhex(args.handle) - request = certs.sign(handle) async with aiohttp.ClientSession() as session: # noinspection HttpUrlsUsage - async with session.post(f'http://{host}:{port}/approve', data=request) as response: - print(response.status) - if response.status == 200: - print((await response.read()).hex()) - else: - print(await response.text()) + async with session.ws_connect(f'http://{host}:{port}/approve') as ws: + nonce = await ws.receive_bytes() + await ws.send_bytes(certs.sign(json.dumps([handle.hex(), nonce.hex()]).encode())) + print((await ws.receive_bytes()).hex()) if __name__ == '__main__': diff --git a/v6d0auth/test-request.py b/v6d0auth/test-request.py index 01caec5..892f12a 100644 --- a/v6d0auth/test-request.py +++ b/v6d0auth/test-request.py @@ -1,10 +1,18 @@ import asyncio +from subprocess import call +from sys import executable -from v6d0auth.client import request_signature +from v6d0auth import certs +from v6d0auth.client import request_signature, has_role, request_role async def main(): + print(certs.vkey.encode().hex()) print((await request_signature()).hex()) + call([executable, '-m', 'v6d0auth.remove-role', certs.vkey.encode().hex(), 'test']) + print(await has_role(certs.vkey, 'test')) + print(await request_role('test')) + print(await has_role(certs.vkey, 'test')) if __name__ == '__main__':