re-encryption fix + factory moved into origin

This commit is contained in:
AF 2022-05-26 10:20:24 +03:00
parent 6c109a7354
commit d5d4026630
7 changed files with 20 additions and 9 deletions

View File

@ -75,7 +75,7 @@ class Encrypted(RecursiveMentionable, Generic[EncryptedType]):
if isinstance(hashpoint.origin, ResolverOrigin): if isinstance(hashpoint.origin, ResolverOrigin):
resolver: HashResolver = hashpoint.origin.resolver resolver: HashResolver = hashpoint.origin.resolver
assert isinstance(resolver, HashResolver) assert isinstance(resolver, HashResolver)
if isinstance(resolver, EncryptedResolver): if isinstance(resolver, EncryptedResolver) and resolver.encrypted.key == key:
return ShortcutOrigin( return ShortcutOrigin(
hashpoint.factory, hashpoint.factory,
resolver.encrypted.mapping[hashpoint.point], resolver.encrypted.mapping[hashpoint.point],
@ -156,6 +156,7 @@ class ShortcutOrigin(Origin[Encrypted[EncryptedType]], Generic[EncryptedType]):
self.factory: RainbowFactory[Encrypted[EncryptedType]] = EncryptedFactory(factory, key) self.factory: RainbowFactory[Encrypted[EncryptedType]] = EncryptedFactory(factory, key)
assert isinstance(self.factory, RainbowFactory) assert isinstance(self.factory, RainbowFactory)
self.hashpoint = hashpoint self.hashpoint = hashpoint
super().__init__(self.factory)
def resolve(self) -> Encrypted[EncryptedType]: def resolve(self) -> Encrypted[EncryptedType]:
encrypted = self.hashpoint.resolve() encrypted = self.hashpoint.resolve()
@ -165,7 +166,6 @@ class ShortcutOrigin(Origin[Encrypted[EncryptedType]], Generic[EncryptedType]):
def hash_point(self) -> HashPoint[Encrypted[EncryptedType]]: def hash_point(self) -> HashPoint[Encrypted[EncryptedType]]:
return HashPoint( return HashPoint(
self.factory,
self.hashpoint.point, self.hashpoint.point,
self self
) )

View File

@ -20,17 +20,16 @@ def _hash(source: bytes) -> bytes:
class HashPoint(Generic[HashMentioned]): class HashPoint(Generic[HashMentioned]):
def __init__( def __init__(
self, self,
factory: RainbowFactory[HashMentioned],
point: bytes, point: bytes,
origin: Origin[HashMentioned] origin: Origin[HashMentioned]
): ):
assert isinstance(factory, RainbowFactory)
assert isinstance(point, bytes) assert isinstance(point, bytes)
assert isinstance(origin, Origin) assert isinstance(origin, Origin)
assert len(point) == self.HASH_LENGTH assert len(point) == self.HASH_LENGTH
self.factory = factory
self.point = point self.point = point
self.origin = origin self.origin = origin
self.factory = origin.factory
assert isinstance(self.factory, RainbowFactory)
def __bytes__(self): def __bytes__(self):
return self.point return self.point
@ -55,7 +54,7 @@ class HashPoint(Generic[HashMentioned]):
def of(cls, mentioned: HashMentioned) -> 'HashPoint[HashMentioned]': def of(cls, mentioned: HashMentioned) -> 'HashPoint[HashMentioned]':
assert isinstance(mentioned, HashMentionable) assert isinstance(mentioned, HashMentionable)
return cls( return cls(
mentioned.__factory__(), cls.hash(cls.bytes_of_mentioned(mentioned)), LocalOrigin(mentioned) cls.hash(cls.bytes_of_mentioned(mentioned)), LocalOrigin(mentioned)
) )
def resolve(self) -> HashMentioned: def resolve(self) -> HashMentioned:

View File

@ -11,6 +11,7 @@ OriginType = TypeVar('OriginType')
class LocalOrigin(Origin[OriginType], Generic[OriginType]): class LocalOrigin(Origin[OriginType], Generic[OriginType]):
def __init__(self, value: OriginType): def __init__(self, value: OriginType):
assert isinstance(value, HashMentionable) assert isinstance(value, HashMentionable)
super().__init__(value.__factory__())
self.value: OriginType = value self.value: OriginType = value
def resolve(self) -> OriginType: def resolve(self) -> OriginType:

View File

@ -17,4 +17,4 @@ class MetaOrigin(Generic[OriginType]):
assert isinstance(factory, RainbowFactory) assert isinstance(factory, RainbowFactory)
assert isinstance(point, bytes) assert isinstance(point, bytes)
assert len(point) == HashPoint.HASH_LENGTH assert len(point) == HashPoint.HASH_LENGTH
return HashPoint(factory, point, self.origin(factory, point)) return HashPoint(point, self.origin(factory, point))

View File

@ -1,10 +1,15 @@
from typing import Generic, TypeVar from typing import Generic, TypeVar
from rainbowadn.hashing.rainbow_factory import RainbowFactory
__all__ = ('Origin',) __all__ = ('Origin',)
OriginType = TypeVar('OriginType') OriginType = TypeVar('OriginType')
class Origin(Generic[OriginType]): class Origin(Generic[OriginType]):
def __init__(self, factory: RainbowFactory[OriginType]):
self.factory = factory
def resolve(self) -> OriginType: def resolve(self) -> OriginType:
raise NotImplementedError raise NotImplementedError

View File

@ -24,6 +24,7 @@ class ResolverOrigin(Origin[OriginType], Generic[OriginType]):
self.factory = factory self.factory = factory
self.point = point self.point = point
self.resolver = resolver self.resolver = resolver
super().__init__(factory)
def resolve(self) -> OriginType: def resolve(self) -> OriginType:
resolved, resolver = self.resolver.resolve(self.point) resolved, resolver = self.resolver.resolve(self.point)
@ -36,4 +37,4 @@ class ResolverOrigin(Origin[OriginType], Generic[OriginType]):
return mentioned return mentioned
def hash_point(self) -> HashPoint[OriginType]: def hash_point(self) -> HashPoint[OriginType]:
return HashPoint(self.factory, self.point, self) return HashPoint(self.point, self)

View File

@ -156,9 +156,14 @@ class TestAll(unittest.TestCase):
tree = tree.add(HashPoint.of(Plain(b'NEWKEY'))) tree = tree.add(HashPoint.of(Plain(b'NEWKEY')))
tree = tree.remove(HashPoint.of(Plain(b'Q'))) tree = tree.remove(HashPoint.of(Plain(b'Q')))
print(tree.reference.str(0)) print(tree.reference.str(0))
with self.subTest('encrypt'): with self.subTest('encrypt and migrate'):
target = tree.reference target = tree.reference
eeed = Encrypted.encrypt(target, key) eeed = Encrypted.encrypt(target, key)
print(Encrypted.ecc) print(Encrypted.ecc)
dr.save(HashPoint.of(eeed)) dr.save(HashPoint.of(eeed))
print(ResolverMetaOrigin(dr).migrate(HashPoint.of(eeed)).resolve().decrypted.str(0)) print(ResolverMetaOrigin(dr).migrate(HashPoint.of(eeed)).resolve().decrypted.str(0))
with self.subTest('re-encrypt'):
new_key = b'b' * 32
target = eeed.decrypted
Encrypted.encrypt(target, new_key)
print(Encrypted.ecc)