rainbowadn/rainbowadn/encryption/encrypted.py

187 lines
7.1 KiB
Python

from typing import Generic, Iterable, TypeVar
from nacl.bindings import crypto_hash_sha256
from nacl.secret import SecretBox
from rainbowadn.hashing.hashmentionable import HashMentionable
from rainbowadn.hashing.hashpoint import HashPoint
from rainbowadn.hashing.hashresolver import HashResolver
from rainbowadn.hashing.origin import Origin
from rainbowadn.hashing.rainbow_factory import RainbowFactory
from rainbowadn.hashing.recursivementionable import RecursiveMentionable
from rainbowadn.hashing.resolverorigin import ResolverOrigin
# Public names re-exported by ``from ... import *``.
__all__ = (
    'Encrypted',
    'EncryptedFactory',
)
# Type variable for the plaintext object type wrapped by ``Encrypted``.
EncryptedType = TypeVar('EncryptedType')
class Encrypted(RecursiveMentionable, Generic[EncryptedType]):
    """A plaintext mentionable paired with its key and its encrypted children.

    Attributes set by ``__init__``:

    * ``key`` -- bytes used as the ``SecretBox`` key when serializing;
    * ``decrypted`` -- the plaintext object; ``factory`` is its factory;
    * ``hashpoints`` -- the hash points the plaintext mentions (empty when the
      plaintext is not a ``RecursiveMentionable``);
    * ``resolution`` -- one encrypted hash point per entry of ``hashpoints``,
      in the same order;
    * ``mapping`` -- plaintext point bytes -> matching encrypted hash point.
    """

    # Debug counter bumped by every ``encrypt`` call; ``cls.ecc += 1`` stores
    # the incremented value on the class that ``encrypt`` was invoked on.
    ecc = 0

    @staticmethod
    def _mentioned_points(mentionable) -> tuple:
        """Return ``mentionable``'s hash points, or ``()`` if it is not recursive."""
        if isinstance(mentionable, RecursiveMentionable):
            return tuple(mentionable.points())
        return ()

    def __init__(
            self,
            key: bytes,
            resolution: tuple[HashPoint['Encrypted'], ...],
            decrypted: EncryptedType
    ):
        """Bind ``decrypted`` to ``key`` and its encrypted children.

        :param key: key bytes used for encryption of this object.
        :param resolution: encrypted hash points, one per hash point mentioned
            by ``decrypted`` and in the same order (checked by assertion).
        :param decrypted: the plaintext object; must be ``HashMentionable``.
        """
        assert isinstance(key, bytes)
        assert isinstance(resolution, tuple)
        assert isinstance(decrypted, HashMentionable)
        self.factory: RainbowFactory[EncryptedType] = decrypted.__factory__()
        self.key = key
        self.resolution = resolution
        self.decrypted: EncryptedType = decrypted
        self.hashpoints = self._mentioned_points(decrypted)
        assert len(self.hashpoints) == len(self.resolution)
        # Pair each plaintext point with the encrypted point at the same index.
        self.mapping: dict[bytes, HashPoint[Encrypted]] = dict(
            (plain.point, secret) for plain, secret in zip(self.hashpoints, resolution)
        )

    def points(self) -> Iterable[HashPoint]:
        """Return the encrypted children's hash points (``resolution``)."""
        return self.resolution

    @classmethod
    def encrypt(cls, decrypted: EncryptedType, key: bytes) -> 'Encrypted[EncryptedType]':
        """Encrypt ``decrypted`` and, recursively, everything it mentions.

        Each mentioned hash point goes through ``encrypt_hashpoint``. The
        plaintext is then re-parsed from its own bytes with an
        ``EncryptedResolver`` bound to the new object, so its children resolve
        through the encrypted copies.
        """
        cls.ecc += 1
        assert isinstance(key, bytes)
        children = cls._mentioned_points(decrypted)
        # Allocate first: the resolver below must reference the instance
        # before ``__init__`` has run.
        result: Encrypted[EncryptedType] = object.__new__(cls)
        encrypted_children = tuple(
            [cls.encrypt_hashpoint(child, key) for child in children]
        )
        plaintext_copy = decrypted.__factory__().from_bytes(
            bytes(decrypted),
            EncryptedResolver(result)
        )
        result.__init__(key, encrypted_children, plaintext_copy)
        return result

    @classmethod
    def encrypt_hashpoint(
            cls, hashpoint: HashPoint[EncryptedType], key: bytes
    ) -> HashPoint['Encrypted[EncryptedType]']:
        """Return an encrypted hash point corresponding to ``hashpoint``.

        If ``hashpoint`` was produced by an ``EncryptedResolver`` whose
        ``Encrypted`` uses the same ``key``, the existing encrypted counterpart
        is reused through a ``ShortcutOrigin``. Otherwise the target is
        resolved and encrypted from scratch.
        """
        assert isinstance(hashpoint, HashPoint)
        assert isinstance(key, bytes)
        origin = hashpoint.origin
        if isinstance(origin, ResolverOrigin):
            resolver: HashResolver = origin.resolver
            assert isinstance(resolver, HashResolver)
            same_key_source = (
                isinstance(resolver, EncryptedResolver)
                and resolver.encrypted.key == key
            )
            if same_key_source:
                already_encrypted = resolver.encrypted.mapping[hashpoint.point]
                shortcut = ShortcutOrigin(hashpoint.factory, already_encrypted, key)
                return shortcut.hash_point()
        return HashPoint.of(cls.encrypt(hashpoint.resolve(), key))

    def __bytes__(self):
        """Serialize as a ``SecretBox`` message under ``self.key``.

        The plaintext layout is: the child count as 8 little-endian bytes, the
        children's encrypted points concatenated, then ``bytes(decrypted)``.
        """
        header = len(self.resolution).to_bytes(8, 'little')
        child_points = b''.join([child.point for child in self.resolution])
        source: bytes = header + child_points + bytes(self.decrypted)
        # The nonce is derived from key + plaintext rather than drawn at
        # random, so the same input always serializes to the same bytes.
        nonce: bytes = crypto_hash_sha256(self.key + source)[:24]
        box = SecretBox(self.key)
        return box.encrypt(source, nonce=nonce)

    def __factory__(self) -> RainbowFactory['Encrypted[EncryptedType]']:
        """Return an ``EncryptedFactory`` for the plaintext factory and key."""
        return EncryptedFactory(self.factory, self.key)
class EncryptedFactory(RainbowFactory[Encrypted[EncryptedType]], Generic[EncryptedType]):
    """Parse serialized ``Encrypted`` objects under a fixed key.

    ``factory`` parses the decrypted plaintext payload; ``key`` is the
    ``SecretBox`` key used to decrypt.
    """

    def __init__(self, factory: RainbowFactory[EncryptedType], key: bytes):
        assert isinstance(factory, RainbowFactory)
        assert isinstance(key, bytes)
        self.factory = factory
        self.key = key

    def from_bytes(self, source: bytes, resolver: HashResolver) -> Encrypted[EncryptedType]:
        """Decrypt ``source`` and rebuild the ``Encrypted`` it describes.

        This mirrors ``Encrypted.__bytes__``. The 8-byte little-endian child
        count and the child points form the header; the remainder is parsed
        by ``self.factory``. Each child point becomes a ``ResolverOrigin``
        that is looked up through the outer ``resolver``.

        ``SecretBox.decrypt`` raises if ``source`` fails authentication.
        """
        assert isinstance(source, bytes)
        assert isinstance(resolver, HashResolver)
        plain: bytes = SecretBox(self.key).decrypt(source)
        count: int = int.from_bytes(plain[:8], 'little')
        width = HashPoint.HASH_LENGTH
        payload_start = 8 + count * width
        # Allocate first so the plaintext's resolver can point at it before
        # ``__init__`` runs.
        result: Encrypted[EncryptedType] = object.__new__(Encrypted)
        decrypted: EncryptedType = self.factory.from_bytes(
            plain[payload_start:],
            EncryptedResolver(result)
        )
        if isinstance(decrypted, RecursiveMentionable):
            children = tuple(decrypted.points())
        else:
            children = ()
        assert len(children) == count
        resolution_items = []
        for index, child in enumerate(children):
            start = 8 + index * width
            child_origin = ResolverOrigin(
                EncryptedFactory(child.factory, self.key),
                plain[start:start + width],
                resolver
            )
            resolution_items.append(child_origin.hash_point())
        result.__init__(self.key, tuple(resolution_items), decrypted)
        return result

    def loose(self) -> RainbowFactory[Encrypted[EncryptedType]]:
        """Return this factory unchanged."""
        return self
class EncryptedResolver(HashResolver):
    """Resolve plaintext points of ``encrypted.decrypted`` via its encrypted children."""

    def __init__(self, encrypted: Encrypted):
        assert isinstance(encrypted, Encrypted)
        self.encrypted = encrypted

    def resolve(self, point: bytes) -> tuple[bytes, 'HashResolver']:
        """Return the child's plaintext bytes and a resolver for its own children.

        ``point`` must be a key of ``self.encrypted.mapping``; otherwise a
        ``KeyError`` propagates.
        """
        assert isinstance(point, bytes)
        child: Encrypted = self.encrypted.mapping[point].resolve()
        plaintext_bytes = HashPoint.bytes_of_mentioned(child.decrypted)
        return plaintext_bytes, EncryptedResolver(child)
class ShortcutOrigin(Origin[Encrypted[EncryptedType]], Generic[EncryptedType]):
    """Origin that reuses an existing encrypted hash point instead of re-encrypting.

    ``factory`` is the plaintext factory; it is wrapped in an
    ``EncryptedFactory`` under ``key``. ``hashpoint`` is the already-encrypted
    hash point whose ``point`` this origin reports.
    """

    def __init__(self, factory: RainbowFactory[EncryptedType], hashpoint: HashPoint[Encrypted], key: bytes):
        assert isinstance(factory, RainbowFactory)
        assert isinstance(hashpoint, HashPoint)
        assert isinstance(key, bytes)
        wrapped_factory: RainbowFactory[Encrypted[EncryptedType]] = EncryptedFactory(factory, key)
        self.factory = wrapped_factory
        assert isinstance(self.factory, RainbowFactory)
        self.hashpoint = hashpoint
        super().__init__(wrapped_factory)

    def resolve(self) -> Encrypted[EncryptedType]:
        """Load the existing object and re-parse it with this origin's factory.

        The re-parse uses a ``ShortcutResolver`` over the loaded object's
        children. The result is asserted to hash to ``self.hashpoint``.
        """
        existing = self.hashpoint.resolve()
        rebuilt = self.factory.from_bytes(bytes(existing), ShortcutResolver(existing))
        assert HashPoint.of(rebuilt) == self.hashpoint
        return rebuilt

    def hash_point(self) -> HashPoint[Encrypted[EncryptedType]]:
        """Return a hash point with the existing point bytes and this origin."""
        return HashPoint(self.hashpoint.point, self)
class ShortcutResolver(HashResolver):
    """Resolve encrypted child points of an ``Encrypted`` through its ``resolution``.

    The lookup table is keyed by the point bytes of each hash point in
    ``encrypted.resolution``.
    """

    def __init__(self, encrypted: Encrypted):
        assert isinstance(encrypted, Encrypted)
        self.mapping: dict[bytes, HashPoint[Encrypted]] = {
            hashpoint.point: hashpoint for hashpoint in encrypted.resolution
        }

    def resolve(self, point: bytes) -> tuple[bytes, 'HashResolver']:
        """Return the mentioned bytes of the child at ``point`` and a resolver for it.

        ``point`` must be one of the child points; otherwise a ``KeyError``
        propagates.
        """
        assert isinstance(point, bytes)
        # Resolve exactly once. Resolving twice repeats the lookup/decryption
        # work, and the returned bytes and nested resolver could then come
        # from two different objects.
        child = self.mapping[point].resolve()
        return HashPoint.bytes_of_mentioned(child), ShortcutResolver(child)