diff --git a/sapling_signatures.py b/sapling_signatures.py index a8ce8ab..3088417 100644 --- a/sapling_signatures.py +++ b/sapling_signatures.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -from binascii import hexlify import os from pyblake2 import blake2b @@ -7,7 +6,7 @@ from sapling_generators import SPENDING_KEY_BASE from sapling_jubjub import Fr, Point, r_j from sapling_key_components import to_scalar from sapling_utils import cldiv, leos2ip -from tv_output import chunk +from tv_output import tv_rust def H(x): @@ -75,20 +74,7 @@ def main(): return bytes(ret) rj = RedJubjub(SPENDING_KEY_BASE, randbytes) - print(''' - struct TestVector { - sk: [u8; 32], - vk: [u8; 32], - alpha: [u8; 32], - rsk: [u8; 32], - rvk: [u8; 32], - m: [u8; 32], - sig: [u8; 64], - rsig: [u8; 64], - }; - - // From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/sapling_signatures.py - let test_vectors = vec![''') + test_vectors = [] for i in range(0, 10): sk = rj.gen_private() vk = rj.derive_public(sk) @@ -104,42 +90,31 @@ def main(): assert not rj.verify(vk, M, rsig) assert not rj.verify(rvk, M, sig) - print(''' TestVector { - sk: [ - %s - ], - vk: [ - %s - ], - alpha: [ - %s - ], - rsk: [ - %s - ], - rvk: [ - %s - ], - m: [ - %s - ], - sig: [ - %s - ], - rsig: [ - %s - ], - },''' % ( - chunk(hexlify(bytes(sk))), - chunk(hexlify(bytes(vk))), - chunk(hexlify(bytes(alpha))), - chunk(hexlify(bytes(rsk))), - chunk(hexlify(bytes(rvk))), - chunk(hexlify(M)), - chunk(hexlify(sig)), - chunk(hexlify(rsig)), - )) - print(' ];') + test_vectors.append({ + 'sk': bytes(sk), + 'vk': bytes(vk), + 'alpha': bytes(alpha), + 'rsk': bytes(rsk), + 'rvk': bytes(rvk), + 'm': M, + 'sig': sig, + 'rsig': rsig, + }) + + tv_rust( + 'sapling_signatures', + ( + ('sk', '[u8; 32]'), + ('vk', '[u8; 32]'), + ('alpha', '[u8; 32]'), + ('rsk', '[u8; 32]'), + ('rvk', '[u8; 32]'), + ('m', '[u8; 32]'), + ('sig', '[u8; 64]'), + ('rsig', '[u8; 64]'), + ), + test_vectors, + ) if __name__ == '__main__': diff --git a/tv_output.py b/tv_output.py index b0a848d..fae030b 100644 --- a/tv_output.py +++ b/tv_output.py @@ -5,13 +5,17 @@ def chunk(h): h = str(h, 'utf-8') return '0x' + ', 0x'.join([h[i:i+2] for i in range(0, len(h), 2)]) -def tv_part_rust(name, value): - print(''' %s: [ - %s - ],''' % ( - name, - chunk(hexlify(value)) - )) +def tv_part_rust(name, value, indent=3): + pad = ' ' * indent + print('''%s%s: [ + %s%s +%s],''' % ( + pad, + name, + pad, + chunk(hexlify(value)), + pad, + )) def tv_rust(filename, parts, vectors): print(' struct TestVector {') @@ -25,5 +29,12 @@ def tv_rust(filename, parts, vectors): print(' let test_vector = TestVector {') [tv_part_rust(p[0], vectors[p[0]]) for p in parts] print(' };') + elif type(vectors) == type([]): + print(' let test_vectors = vec![') + for vector in vectors: + print(' TestVector {') + [tv_part_rust(p[0], vector[p[0]], 4) for p in parts] + print(' },') + print(' ];') else: raise ValueError('Invalid type(vectors)')