Skip to content

Commit

Permalink
add: any size leaves proof
Browse files Browse the repository at this point in the history
  • Loading branch information
olivmath committed Jul 7, 2023
1 parent a837cf8 commit 44fe95c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 20 deletions.
30 changes: 30 additions & 0 deletions merkly/mtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from merkly.node import Node, Side
from merkly.utils import (
hash_function_type_checking,
is_power_2,
slice_in_pairs,
keccak,
half,
Expand Down Expand Up @@ -109,6 +110,9 @@ def make_proof(self, leafs: List[str], proof: List[Node], leaf: str) -> List[Nod
msg = f"Leaf: {leaf} does not exist in the tree: {leafs}"
raise ValueError(msg) from err

if is_power_2(len(leafs)) is False:
return self.mix_tree(leafs, [], index)

if len(leafs) == 2:
if index == 1:
proof.append(Node(data=leafs[0], side=Side.LEFT))
Expand All @@ -125,3 +129,29 @@ def make_proof(self, leafs: List[str], proof: List[Node], leaf: str) -> List[Nod
else:
proof.append(Node(data=self.make_root(left)[0], side=Side.LEFT))
return self.make_proof(right, proof, leaf)

def mix_tree(
self, leaves: List[str], proof: List[Node], leaf_index: int
) -> List[Node]:
if len(leaves) == 1:
return proof

if leaf_index % 2 == 0:
if leaf_index + 1 < len(leaves):
node = Node(data=leaves[leaf_index + 1], side=Side.RIGHT)
proof.append(node)
else:
node = Node(data=leaves[leaf_index - 1], side=Side.LEFT)
proof.append(node)

return self.mix_tree(self.up_layer(leaves), proof, leaf_index // 2)

def up_layer(self, leaves: List[str]) -> List[str]:
new_layer = []
for pair in slice_in_pairs(leaves):
if len(pair) == 1:
new_layer.append(pair[0])
else:
data = self.hash_function(pair[0], pair[1])
new_layer.append(data)
return new_layer
17 changes: 10 additions & 7 deletions merkly/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,15 @@ def slice_in_pairs(list_item: list):


def hash_function_type_checking(hash_function: Callable[[str], str]) -> bool:
is_valid = (
isinstance(hash_function, types.FunctionType)
and callable(hash_function)
and isinstance(hash_function(str()), str)
)
if hash_function is not None and not is_valid:
a = isinstance(hash_function, types.FunctionType)
b = callable(hash_function)
try:
c = isinstance(hash_function(str(), str()), str)
except TypeError:
c = False

valid = a and b and c
if not valid:
raise InvalidHashFunctionError()


Expand All @@ -113,4 +116,4 @@ def is_power_2(number: int) -> bool:
if left and right:
return True
else:
raise PowerOfTwoError(number)
return False
17 changes: 9 additions & 8 deletions test/test_merkle_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,13 @@ def test_proof_simple_odd_merkle():
Instantiated a simple Merkle Tree
"""
leafs = ["a", "b", "c", "d", "e"]
tree = MerkleTree(leafs)
tree = MerkleTree(leafs, lambda x, y: x + y)
proof = [
Node(right="b5553de315e0edf504d9150af82dafa5c4667fa618ed0a6f19c69b41166c5510"),
Node(right="ed3a2f2a068b98ea5eb600912326df7b62037603d2633eba4ccc9a3845674b90"),
Node(data="abcd", side=Side.LEFT),
]

assert tree.proof("a") == proof
assert tree.verify(proof, "a") == True
assert tree.proof("e") == proof, "Proofs dont's match"
assert tree.verify(proof, "e"), "Proof dont's right"


def test_proof_simple_merkle():
Expand All @@ -86,8 +85,7 @@ def test_proof_simple_merkle():
"""
leafs = ["a", "b", "c", "d"]
tree = MerkleTree(leafs)

assert tree.proof("a") == [
proof = [
Node(
side=Side.RIGHT,
data="b5553de315e0edf504d9150af82dafa5c4667fa618ed0a6f19c69b41166c5510",
Expand Down Expand Up @@ -150,4 +148,7 @@ def invalid_hash_function_that_returns_an_integer_instead_of_a_string(data):
return 123

with raises(InvalidHashFunctionError):
MerkleTree(["a", "b", "c", "d"], invalid_hash_function_that_returns_an_integer_instead_of_a_string)
MerkleTree(
["a", "b", "c", "d"],
invalid_hash_function_that_returns_an_integer_instead_of_a_string,
)
6 changes: 1 addition & 5 deletions test/utils/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,4 @@
],
)
def test_of_is_power_2(number: int, ok: bool):
if ok:
assert ok == is_power_2(number)
else:
with raises(PowerOfTwoError):
is_power_2(number)
assert ok == is_power_2(number)

0 comments on commit 44fe95c

Please sign in to comment.