From d5d40266301b6281e8eda225f2c441086850aaea Mon Sep 17 00:00:00 2001 From: timotheyca Date: Thu, 26 May 2022 10:20:24 +0300 Subject: [PATCH] re-encryption fix + factory moved into origin --- rainbowadn/encryption/encrypted.py | 4 ++-- rainbowadn/hashing/hashpoint.py | 7 +++---- rainbowadn/hashing/localorigin.py | 1 + rainbowadn/hashing/metaorigin.py | 2 +- rainbowadn/hashing/origin.py | 5 +++++ rainbowadn/hashing/resolverorigin.py | 3 ++- rainbowadn/testing/test_all.py | 7 ++++++- 7 files changed, 20 insertions(+), 9 deletions(-) diff --git a/rainbowadn/encryption/encrypted.py b/rainbowadn/encryption/encrypted.py index 8f3f33a..ddd0d2a 100644 --- a/rainbowadn/encryption/encrypted.py +++ b/rainbowadn/encryption/encrypted.py @@ -75,7 +75,7 @@ class Encrypted(RecursiveMentionable, Generic[EncryptedType]): if isinstance(hashpoint.origin, ResolverOrigin): resolver: HashResolver = hashpoint.origin.resolver assert isinstance(resolver, HashResolver) - if isinstance(resolver, EncryptedResolver): + if isinstance(resolver, EncryptedResolver) and resolver.encrypted.key == key: return ShortcutOrigin( hashpoint.factory, resolver.encrypted.mapping[hashpoint.point], @@ -156,6 +156,7 @@ class ShortcutOrigin(Origin[Encrypted[EncryptedType]], Generic[EncryptedType]): self.factory: RainbowFactory[Encrypted[EncryptedType]] = EncryptedFactory(factory, key) assert isinstance(self.factory, RainbowFactory) self.hashpoint = hashpoint + super().__init__(self.factory) def resolve(self) -> Encrypted[EncryptedType]: encrypted = self.hashpoint.resolve() @@ -165,7 +166,6 @@ class ShortcutOrigin(Origin[Encrypted[EncryptedType]], Generic[EncryptedType]): def hash_point(self) -> HashPoint[Encrypted[EncryptedType]]: return HashPoint( - self.factory, self.hashpoint.point, self ) diff --git a/rainbowadn/hashing/hashpoint.py b/rainbowadn/hashing/hashpoint.py index 524e433..56d034f 100644 --- a/rainbowadn/hashing/hashpoint.py +++ b/rainbowadn/hashing/hashpoint.py @@ -20,17 +20,16 @@ def _hash(source: bytes) -> bytes: class HashPoint(Generic[HashMentioned]): def __init__( self, - factory: RainbowFactory[HashMentioned], point: bytes, origin: Origin[HashMentioned] ): - assert isinstance(factory, RainbowFactory) assert isinstance(point, bytes) assert isinstance(origin, Origin) assert len(point) == self.HASH_LENGTH - self.factory = factory self.point = point self.origin = origin + self.factory = origin.factory + assert isinstance(self.factory, RainbowFactory) def __bytes__(self): return self.point @@ -55,7 +54,7 @@ class HashPoint(Generic[HashMentioned]): def of(cls, mentioned: HashMentioned) -> 'HashPoint[HashMentioned]': assert isinstance(mentioned, HashMentionable) return cls( - mentioned.__factory__(), cls.hash(cls.bytes_of_mentioned(mentioned)), LocalOrigin(mentioned) + cls.hash(cls.bytes_of_mentioned(mentioned)), LocalOrigin(mentioned) ) def resolve(self) -> HashMentioned: diff --git a/rainbowadn/hashing/localorigin.py b/rainbowadn/hashing/localorigin.py index b6ad665..1fb9495 100644 --- a/rainbowadn/hashing/localorigin.py +++ b/rainbowadn/hashing/localorigin.py @@ -11,6 +11,7 @@ OriginType = TypeVar('OriginType') class LocalOrigin(Origin[OriginType], Generic[OriginType]): def __init__(self, value: OriginType): assert isinstance(value, HashMentionable) + super().__init__(value.__factory__()) self.value: OriginType = value def resolve(self) -> OriginType: diff --git a/rainbowadn/hashing/metaorigin.py b/rainbowadn/hashing/metaorigin.py index 0674f51..0c93805 100644 --- a/rainbowadn/hashing/metaorigin.py +++ b/rainbowadn/hashing/metaorigin.py @@ -17,4 +17,4 @@ class MetaOrigin(Generic[OriginType]): assert isinstance(factory, RainbowFactory) assert isinstance(point, bytes) assert len(point) == HashPoint.HASH_LENGTH - return HashPoint(factory, point, self.origin(factory, point)) + return HashPoint(point, self.origin(factory, point)) diff --git a/rainbowadn/hashing/origin.py b/rainbowadn/hashing/origin.py index 9c6fd3a..fe013ec 100644 --- a/rainbowadn/hashing/origin.py +++ b/rainbowadn/hashing/origin.py @@ -1,10 +1,15 @@ from typing import Generic, TypeVar +from rainbowadn.hashing.rainbow_factory import RainbowFactory + __all__ = ('Origin',) OriginType = TypeVar('OriginType') class Origin(Generic[OriginType]): + def __init__(self, factory: RainbowFactory[OriginType]): + self.factory = factory + def resolve(self) -> OriginType: raise NotImplementedError diff --git a/rainbowadn/hashing/resolverorigin.py b/rainbowadn/hashing/resolverorigin.py index 77809ac..8645c98 100644 --- a/rainbowadn/hashing/resolverorigin.py +++ b/rainbowadn/hashing/resolverorigin.py @@ -24,6 +24,7 @@ class ResolverOrigin(Origin[OriginType], Generic[OriginType]): self.factory = factory self.point = point self.resolver = resolver + super().__init__(factory) def resolve(self) -> OriginType: resolved, resolver = self.resolver.resolve(self.point) @@ -36,4 +37,4 @@ class ResolverOrigin(Origin[OriginType], Generic[OriginType]): return mentioned def hash_point(self) -> HashPoint[OriginType]: - return HashPoint(self.factory, self.point, self) + return HashPoint(self.point, self) diff --git a/rainbowadn/testing/test_all.py b/rainbowadn/testing/test_all.py index cfce391..cf4cdd1 100644 --- a/rainbowadn/testing/test_all.py +++ b/rainbowadn/testing/test_all.py @@ -156,9 +156,14 @@ class TestAll(unittest.TestCase): tree = tree.add(HashPoint.of(Plain(b'NEWKEY'))) tree = tree.remove(HashPoint.of(Plain(b'Q'))) print(tree.reference.str(0)) - with self.subTest('encrypt'): + with self.subTest('encrypt and migrate'): target = tree.reference eeed = Encrypted.encrypt(target, key) print(Encrypted.ecc) dr.save(HashPoint.of(eeed)) print(ResolverMetaOrigin(dr).migrate(HashPoint.of(eeed)).resolve().decrypted.str(0)) + with self.subTest('re-encrypt'): + new_key = b'b' * 32 + target = eeed.decrypted + Encrypted.encrypt(target, new_key) + print(Encrypted.ecc)