Merge pull request #52 from zcash-hackworks/merkle-path

Merkle path test vectors for a depth-4 tree
This commit is contained in:
str4d 2021-09-17 01:16:21 +12:00 committed by GitHub
commit 3e0835b140
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 3 deletions

View File

@ -6,6 +6,7 @@ from binascii import unhexlify
from orchard_pallas import Fp
from orchard_sinsemilla import sinsemilla_hash
from tv_output import render_args, render_tv
from utils import i2lebsp, leos2bsp
# https://zips.z.cash/protocol/nu5.pdf#constants
@ -14,11 +15,11 @@ L_MERKLE = 255
UNCOMMITTED_ORCHARD = Fp(2)
# https://zips.z.cash/protocol/nu5.pdf#orchardmerklecrh
def merkle_crh(layer, left, right):
assert layer < MERKLE_DEPTH
def merkle_crh(layer, left, right, depth=MERKLE_DEPTH):
assert layer < depth
assert len(left) == L_MERKLE
assert len(right) == L_MERKLE
l = i2lebsp(10, MERKLE_DEPTH - 1 - layer)
l = i2lebsp(10, depth - 1 - layer)
return sinsemilla_hash(b"z.cash:Orchard-MerkleCRH", l + left + right)
left = unhexlify("87a086ae7d2252d58729b30263fb7b66308bf94ef59a76c9c86e7ea016536505")[::-1]
@ -38,3 +39,96 @@ def empty_roots():
bits = i2lebsp(L_MERKLE, empty_roots[-1].s)
empty_roots.append(merkle_crh(layer, bits, bits))
return empty_roots
def main():
args = render_args()
from random import Random
from tv_rand import Rand
rng = Random(0xabad533d)
def randbytes(l):
ret = []
while len(ret) < l:
ret.append(rng.randrange(0, 256))
return bytes(ret)
rand = Rand(randbytes)
SMALL_DEPTH = 4
# Derive path for each leaf in a tree of depth 4.
def get_paths_and_root(leaves):
assert(len(leaves) == (1 << SMALL_DEPTH))
paths = [[] for _ in range(1 << SMALL_DEPTH)]
# At layer 0, we want:
# - leaf 0: sibling 1
# - leaf 1: sibling 0
# - leaf 2: sibling 3
# - leaf 3: sibling 2 (etc.)
# We repeat this all the way up, just with shorter arrays.
cur_layer = leaves
next_layer = []
for l in range(0, SMALL_DEPTH):
# Iterate over nodes in the current layer.
for i in range(0, len(cur_layer)):
is_left = (i % 2) == 0
sibling = cur_layer[i + 1] if is_left else cur_layer[i - 1]
# As we compute the tree, we start appending siblings to
# multiple paths. Each sibling corresponds to (1 << layer)
# leaves.
leaves_per_sibling = (1 << l)
for j in range(leaves_per_sibling * i, leaves_per_sibling * (i+1)):
paths[j].append(sibling)
# Compute the parent of the current pair of siblings.
if is_left:
layer = SMALL_DEPTH - 1 - l
left = leos2bsp(bytes(cur_layer[i]))[:L_MERKLE]
right = leos2bsp(bytes(sibling))[:L_MERKLE]
next_layer.append(merkle_crh(layer, left, right, depth=SMALL_DEPTH))
cur_layer = next_layer
next_layer = []
# We should have reached the root of the tree.
assert(len(cur_layer) == 1)
return (paths, cur_layer[0])
# Test vectors:
# - Create empty tree of depth 4.
# - Append random leaves
# - After each leaf is appended, derive the Merkle paths for every leaf
# position (using the empty leaf for positions that have not been filled).
test_vectors = []
leaves = [UNCOMMITTED_ORCHARD] * (1 << SMALL_DEPTH)
for i in range(0, (1 << SMALL_DEPTH)):
print("Appending leaf", i + 1, file = sys.stderr)
# Append next leaf
leaves[i] = Fp.random(rand)
# Derive Merkle paths for all leaves
(paths, root) = get_paths_and_root(leaves)
test_vectors.append({
'leaves': [bytes(leaf) for leaf in leaves],
'paths': [[bytes(node) for node in path] for path in paths],
'root': bytes(root),
})
render_tv(
args,
'orchard_merkle_tree',
(
('leaves', '[[u8; 32]; %d]' % (1 << SMALL_DEPTH)),
('paths', '[[[u8; 32]; %d]; %d]' % (SMALL_DEPTH, (1 << SMALL_DEPTH))),
('root', '[u8; 32]'),
),
test_vectors,
)
if __name__ == '__main__':
main()

View File

@ -153,6 +153,23 @@ def tv_part_rust(name, value, config, indent=3):
' ' * (indent + 1),
chunk(hexlify(item)),
))
elif type(item) == list:
print('''%s[''' % (
' ' * (indent + 1)
))
for subitem in item:
if type(subitem) == bytes:
print('''%s[%s],''' % (
' ' * (indent + 2),
chunk(hexlify(subitem)),
))
else:
raise ValueError('Invalid sublist type(%s): %s' % (name, type(subitem)))
print('''%s],''' % (
' ' * (indent + 1)
))
else:
raise ValueError('Invalid list type(%s): %s' % (name, type(item)))
print('''%s],''' % (
pad,
))