diff --git a/ff1.py b/ff1.py new file mode 100644 index 0000000..9524ea6 --- /dev/null +++ b/ff1.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + +import os +from binascii import unhexlify, hexlify + +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend + +from sapling_utils import bebs2ip, i2bebsp, beos2ip, bebs2osp, cldiv + +# Morris Dworkin +# NIST Special Publication 800-38G +# Recommendation for Block Cipher Modes of Operation: Methods for Format-Preserving Encryption +# +# specialized to the parameters below and a single-block PRF; unoptimized + +radix = 2 +minlen = maxlen = 88 +maxTlen = 255 +assert 2 <= radix and radix < 256 +assert radix**minlen >= 100 +assert 2 <= minlen and minlen <= maxlen and maxlen < 256 + +NUM_2 = bebs2ip +STR_2 = i2bebsp + + +def ff1_aes256_encrypt(key, tweak, x): + n = len(x) + t = len(tweak) + assert minlen <= n and n <= maxlen + assert t <= maxTlen + + u = n//2; v = n-u + assert u == v + A = x[:u]; B = x[u:] + assert radix == 2 + b = cldiv(v, 8) + d = 4*cldiv(b, 4) + 4 + assert d <= 16 + P = bytes([1, 2, 1, 0, 0, radix, 10, u % 256, 0, 0, 0, n, 0, 0, 0, t]) + for i in range(10): + Q = tweak + b'\0'*((-t-b-1) % 16) + bytes([i]) + bebs2osp(B) + y = beos2ip(aes_cbcmac(key, P + Q)[:d]) + c = (NUM_2(A)+y) % (1<. + + key = unhexlify("2B7E151628AED2A6ABF7158809CF4F3CEF4359D8D580AA4F7F036D6F04FC6A94") + + tweak = b'' + x = [0]*88 + ct = ff1_aes256_encrypt(key, tweak, x) + assert ''.join(map(str, ct)) == "0000100100110101011101111111110011000001101100111110011101110101011010100100010011001111", ct + pt = ff1_aes256_decrypt(key, tweak, ct) + assert pt == x, (ct, pt) + + x = list(map(int, "0000100100110101011101111111110011000001101100111110011101110101011010100100010011001111")) + ct = ff1_aes256_encrypt(key, tweak, x) + assert ''.join(map(str, ct)) == "1101101011010001100011110000010011001111110110011101010110100001111001000101011111011000", ct + pt = ff1_aes256_decrypt(key, tweak, ct) + assert pt == x, (ct, pt) + + x = [0, 1]*44 + ct = ff1_aes256_encrypt(key, tweak, x) + assert ''.join(map(str, ct)) == "0000111101000001111011010111011111110001100101000000001101101110100010010111001100100110", ct + pt = ff1_aes256_decrypt(key, tweak, ct) + assert pt == x, (ct, pt) + + tweak = bytes(range(maxTlen)) + ct = ff1_aes256_encrypt(key, tweak, x) + assert ''.join(map(str, ct)) == "0111110110001000000111010110000100010101101000000011100111100100100010101101111010100011", ct + pt = ff1_aes256_decrypt(key, tweak, ct) + assert pt == x, (ct, pt) + + key = os.urandom(32) + tweak = b'' + ct = ff1_aes256_encrypt(key, tweak, x) + pt = ff1_aes256_decrypt(key, tweak, ct) + assert pt == x, (ct, pt) + + tweak = os.urandom(maxTlen) + ct = ff1_aes256_encrypt(key, tweak, x) + pt = ff1_aes256_decrypt(key, tweak, ct) + assert pt == x, (ct, pt) + + +def aes_cbcmac(key, input): + encryptor = Cipher(algorithms.AES(key), modes.CBC(b'\0'*16), backend=default_backend()).encryptor() + return (encryptor.update(input) + encryptor.finalize())[-16:] + +def test_aes(): + # Check we're actually using AES-256. + + # + # + + # Simple test (this wouldn't catch a byte order error in the key): + # ECBVarTxt256.rsp COUNT = 0 + KEY = unhexlify("0000000000000000000000000000000000000000000000000000000000000000") + PLAINTEXT = unhexlify("80000000000000000000000000000000") + CIPHERTEXT = unhexlify("ddc6bf790c15760d8d9aeb6f9a75fd4e") + assert aes_cbcmac(KEY, PLAINTEXT) == CIPHERTEXT + + # Now something more rigorous: + # ECBMCT256.rsp COUNT = 0 + key = unhexlify("f9e8389f5b80712e3886cc1fa2d28a3b8c9cd88a2d4a54c6aa86ce0fef944be0") + acc = unhexlify("b379777f9050e2a818f2940cbbd9aba4") + ct = unhexlify("6893ebaf0a1fccc704326529fdfb60db") + for i in range(1000): + acc = aes_cbcmac(key, acc) + assert acc == ct, hexlify(acc) + + +if __name__ == '__main__': + test_aes() + test_ff1() diff --git a/sapling_generators.py b/sapling_generators.py index 1dc7e50..573383d 100644 --- a/sapling_generators.py +++ b/sapling_generators.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + from pyblake2 import blake2s from sapling_jubjub import Point, JUBJUB_COFACTOR diff --git a/sapling_jubjub.py b/sapling_jubjub.py index ede62d2..6bfc31a 100644 --- a/sapling_jubjub.py +++ b/sapling_jubjub.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + from sapling_utils import i2lebsp, leos2ip, i2leosp q_j = 52435875175126190479447740508185965837690552500527637822603658699938581184513 diff --git a/sapling_key_components.py b/sapling_key_components.py index 3264029..d1e345a 100644 --- a/sapling_key_components.py +++ b/sapling_key_components.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + from pyblake2 import blake2b, blake2s from sapling_generators import PROVING_KEY_BASE, SPENDING_KEY_BASE, group_hash @@ -33,6 +35,8 @@ def crh_ivk(ak, nk): ivk = digest.digest() return leos2ip(ivk) % 2**251 +def diversify_hash(d): + return group_hash(b'Zcash_gd', d) # # Key components @@ -47,22 +51,8 @@ def cached(f): return self._cached[f] return wrapper -class SpendingKey(object): - def __init__(self, data): - self.data = data - - @cached - def ask(self): - return to_scalar(prf_expand(self.data, b'\0')) - - @cached - def nsk(self): - return to_scalar(prf_expand(self.data, b'\1')) - - @cached - def ovk(self): - return prf_expand(self.data, b'\2')[:32] +class DerivedAkNk(object): @cached def ak(self): return SPENDING_KEY_BASE * self.ask() @@ -71,23 +61,42 @@ class SpendingKey(object): def nk(self): return PROVING_KEY_BASE * self.nsk() + +class DerivedIvk(object): @cached def ivk(self): return Fr(crh_ivk(bytes(self.ak()), bytes(self.nk()))) + +class SpendingKey(DerivedAkNk, DerivedIvk): + def __init__(self, data): + self.data = data + + @cached + def ask(self): + return to_scalar(prf_expand(self.data, b'\x00')) + + @cached + def nsk(self): + return to_scalar(prf_expand(self.data, b'\x01')) + + @cached + def ovk(self): + return prf_expand(self.data, b'\x02')[:32] + @cached def default_d(self): i = 0 while True: d = prf_expand(self.data, bytes([3, i]))[:11] - if group_hash(b'Zcash_gd', d): + if diversify_hash(d): return d i += 1 assert i < 256 @cached def default_pkd(self): - return group_hash(b'Zcash_gd', self.default_d()) * self.ivk() + return diversify_hash(self.default_d()) * self.ivk() def main(): @@ -100,7 +109,7 @@ def main(): note_r = Fr(8890123457840276890326754358439057438290574382905).exp(i+1) note_cm = note_commit( note_r, - leos2bsp(bytes(group_hash(b'Zcash_gd', sk.default_d()))), + leos2bsp(bytes(diversify_hash(sk.default_d()))), leos2bsp(bytes(sk.default_pkd())), note_v) note_pos = (980705743285409327583205473820957432*i) % 2**MERKLE_DEPTH diff --git a/sapling_merkle_tree.py b/sapling_merkle_tree.py index 16d91c7..6d7b7f7 100644 --- a/sapling_merkle_tree.py +++ b/sapling_merkle_tree.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + from binascii import unhexlify from sapling_pedersen import pedersen_hash diff --git a/sapling_notes.py b/sapling_notes.py index 5fecb57..aa27258 100644 --- a/sapling_notes.py +++ b/sapling_notes.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + from pyblake2 import blake2s from sapling_pedersen import ( diff --git a/sapling_pedersen.py b/sapling_pedersen.py index a3a89dc..f782efe 100644 --- a/sapling_pedersen.py +++ b/sapling_pedersen.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + from sapling_generators import ( find_group_hash, NOTE_POSITION_BASE, diff --git a/sapling_signatures.py b/sapling_signatures.py index ead3bfd..f774935 100644 --- a/sapling_signatures.py +++ b/sapling_signatures.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + import os from pyblake2 import blake2b diff --git a/sapling_utils.py b/sapling_utils.py index 8ea08d4..a334134 100644 --- a/sapling_utils.py +++ b/sapling_utils.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." def cldiv(n, divisor): return (n + (divisor - 1)) // divisor @@ -9,26 +10,57 @@ def i2lebsp(l, x): def leos2ip(S): return int.from_bytes(S, byteorder='little') +def beos2ip(S): + return int.from_bytes(S, byteorder='big') + # This should be equivalent to LEBS2OSP(I2LEBSP(l, x)) def i2leosp(l, x): return x.to_bytes(cldiv(l, 8), byteorder='little') -def ledna(bits): +# This should be equivalent to BEBS2OSP(I2BEBSP(l, x)) +def i2beosp(l, x): + return x.to_bytes(cldiv(l, 8), byteorder='big') + +def bebs2ip(bits): ret = 0 - for b in bits[::-1]: + for b in bits: ret = ret * 2 if b: ret += 1 return ret +def lebs2ip(bits): + return bebs2ip(bits[::-1]) + +def i2bebsp(m, x): + assert 0 <= x and x < (1 << m) + return [(x >> (m-1-i)) & 1 for i in range(m)] + def lebs2osp(bits): l = len(bits) bits = bits + [0] * (8 * cldiv(l, 8) - l) - return bytes([ledna(bits[i:i + 8]) for i in range(0, len(bits), 8)]) + return bytes([lebs2ip(bits[i:i + 8]) for i in range(0, len(bits), 8)]) def leos2bsp(buf): return sum([[(c >> i) & 1 for i in range(8)] for c in buf], []) +def bebs2osp(bits, m=None): + l = len(bits) + bits = [0] * (8 * cldiv(l, 8) - l) + bits + return bytes([bebs2ip(bits[i:i + 8]) for i in range(0, len(bits), 8)]) assert i2leosp(5, 7) == lebs2osp(i2lebsp(5, 7)) assert i2leosp(32, 1234567890) == lebs2osp(i2lebsp(32, 1234567890)) + +assert i2beosp(5, 7) == bebs2osp(i2bebsp(5, 7)) +assert i2beosp(32, 1234567890) == bebs2osp(i2bebsp(32, 1234567890)) + +assert leos2ip(bytes(range(256))) == lebs2ip(leos2bsp(bytes(range(256)))) + +assert bebs2ip(i2bebsp(5, 7)) == 7 +try: + i2bebsp(3, 12) +except AssertionError: + pass +else: + raise AssertionError("invalid input not caught by i2bebsp") diff --git a/sapling_zip32.py b/sapling_zip32.py new file mode 100644 index 0000000..8fcf4d7 --- /dev/null +++ b/sapling_zip32.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + +from pyblake2 import blake2b + +from sapling_key_components import to_scalar, prf_expand, diversify_hash, DerivedAkNk, DerivedIvk +from sapling_generators import SPENDING_KEY_BASE, PROVING_KEY_BASE +from sapling_utils import i2leosp, i2lebsp, lebs2osp +from ff1 import ff1_aes256_encrypt +from tv_output import render_args, render_tv, option, Some + + +def encode_xsk_parts(ask, nsk, ovk, dk): + # bytes = i2leosp_256 for Fr + return bytes(ask) + bytes(nsk) + ovk + dk + +def encode_xfvk_parts(ak, nk, ovk, dk): + return bytes(ak) + bytes(nk) + ovk + dk + + +class ExtendedBase(object): + def ovk(self): + return self._ovk + + def dk(self): + return self._dk + + def c(self): + return self._c + + def depth(self): + return self._depth + + def parent_tag(self): + return self._parent_tag + + def i(self): + return self._i + + def diversifier(self, j): + d = lebs2osp(ff1_aes256_encrypt(self.dk(), b'', i2lebsp(88, j))) + return d if diversify_hash(d) else None + + def fingerprint(self): + FVK = bytes(self.ak()) + bytes(self.nk()) + self.ovk() + return blake2b(person=b'ZcashSaplingFVFP', digest_size=32, data=FVK).digest() + + def tag(self): + return self.fingerprint()[:4] + + +class ExtendedSpendingKey(DerivedAkNk, DerivedIvk, ExtendedBase): + def __init__(self, ask, nsk, ovk, dk, c, depth=0, parent_tag=i2leosp(32, 0), i=0): + self._ask = ask + self._nsk = nsk + self._ovk = ovk + self._dk = dk + self._c = c + self._depth = depth + self._parent_tag = parent_tag + self._i = i + + @classmethod + def master(cls, S): + I = blake2b(person=b'ZcashIP32Sapling', data=S).digest() + I_L = I[:32] + I_R = I[32:] + sk_m = I_L + ask_m = to_scalar(prf_expand(sk_m, b'\x00')) + nsk_m = to_scalar(prf_expand(sk_m, b'\x01')) + ovk_m = prf_expand(sk_m, b'\x02')[:32] + dk_m = prf_expand(sk_m, b'\x10')[:32] + c_m = I_R + return cls(ask_m, nsk_m, ovk_m, dk_m, c_m) + + def ask(self): + return self._ask + + def nsk(self): + return self._nsk + + def is_xsk(self): + return True + + def __bytes__(self): + return (i2leosp(8, self.depth()) + + self.parent_tag() + + i2leosp(32, self.i()) + + self.c() + + encode_xsk_parts(self.ask(), self.nsk(), self.ovk(), self.dk())) + + def to_extended_fvk(self): + return ExtendedFullViewingKey(self.ak(), self.nk(), self.ovk(), self.dk(), self.c(), + self.depth(), self.parent_tag(), self.i()) + + def child(self, i): + if i >= 1<<31: # child is a hardened key + prefix = b'\x11' + encode_xsk_parts(self.ask(), self.nsk(), self.ovk(), self.dk()) + else: + prefix = b'\x12' + encode_xfvk_parts(self.ak(), self.nk(), self.ovk(), self.dk()) + + I = prf_expand(self.c(), prefix + i2leosp(32, i)) + I_L = I[:32] + I_R = I[32:] + I_ask = to_scalar(prf_expand(I_L, b'\x13')) + I_nsk = to_scalar(prf_expand(I_L, b'\x14')) + ask_i = I_ask + self.ask() + nsk_i = I_nsk + self.nsk() + ovk_i = prf_expand(I_L, b'\x15' + self.ovk())[:32] + dk_i = prf_expand(I_L, b'\x16' + self.dk())[:32] + c_i = I_R + return self.__class__(ask_i, nsk_i, ovk_i, dk_i, c_i, self.depth()+1, self.tag(), i) + + +class ExtendedFullViewingKey(DerivedIvk, ExtendedBase): + def __init__(self, ak, nk, ovk, dk, c, depth=0, parent_tag=i2leosp(32, 0), i=0): + self._ak = ak + self._nk = nk + self._ovk = ovk + self._dk = dk + self._c = c + self._depth = depth + self._parent_tag = parent_tag + self._i = i + + @classmethod + def master(cls, S): + return ExtendedSpendingKey.master(S).to_extended_fvk() + + def ak(self): + return self._ak + + def nk(self): + return self._nk + + def is_xsk(self): + return False + + def __bytes__(self): + return (i2leosp(8, self.depth()) + + self.parent_tag() + + i2leosp(32, self.i()) + + self.c() + + encode_xfvk_parts(self.ak(), self.nk(), self.ovk(), self.dk())) + + def to_extended_fvk(self): + return self + + def child(self, i): + if i >= 1<<31: + raise ValueError("can't derive a child hardened key from an extended full viewing key") + else: + prefix = b'\x12' + encode_xfvk_parts(self.ak(), self.nk(), self.ovk(), self.dk()) + + I = prf_expand(self.c(), prefix + i2leosp(32, i)) + I_L = I[:32] + I_R = I[32:] + I_ask = to_scalar(prf_expand(I_L, b'\x13')) + I_nsk = to_scalar(prf_expand(I_L, b'\x14')) + ak_i = SPENDING_KEY_BASE * I_ask + self.ak() + nk_i = PROVING_KEY_BASE * I_nsk + self.nk() + ovk_i = prf_expand(I_L, b'\x15' + self.ovk())[:32] + dk_i = prf_expand(I_L, b'\x16' + self.dk())[:32] + c_i = I_R + return self.__class__(ak_i, nk_i, ovk_i, dk_i, c_i, self.depth()+1, self.tag(), i) + + +def main(): + args = render_args() + + def hardened(i): return i + (1<<31) + + seed = bytes(range(32)) + m = ExtendedSpendingKey.master(seed) + m_1 = m.child(1) + m_1_2h = m_1.child(hardened(2)) + m_1_2hv = m_1_2h.to_extended_fvk() + m_1_2hv_3 = m_1_2hv.child(3) + + test_vectors = [ + {'ask' : Some(bytes(k.ask())) if k.is_xsk() else None, + 'nsk' : Some(bytes(k.nsk())) if k.is_xsk() else None, + 'ovk' : k.ovk(), + 'dk' : k.dk(), + 'c' : k.c(), + 'ak' : bytes(k.ak()), + 'nk' : bytes(k.nk()), + 'ivk' : bytes(k.ivk()), + 'xsk' : Some(bytes(k)) if k.is_xsk() else None, + 'xfvk': bytes(k.to_extended_fvk()), + 'fp' : k.fingerprint(), + 'd0' : option(k.diversifier(0)), + 'd1' : option(k.diversifier(1)), + 'd2' : option(k.diversifier(2)), + 'dmax': option(k.diversifier((1<<88)-1)), + } + for k in (m, m_1, m_1_2h, m_1_2hv, m_1_2hv_3) + ] + + render_tv( + args, + 'sapling_zip32', + ( + ('ask', 'Option<[u8; 32]>'), + ('nsk', 'Option<[u8; 32]>'), + ('ovk', '[u8; 32]'), + ('dk', '[u8; 32]'), + ('c', '[u8; 32]'), + ('ak', '[u8; 32]'), + ('nk', '[u8; 32]'), + ('ivk', '[u8; 32]'), + ('xsk', 'Option<[u8; 169]>'), + ('xfvk','[u8; 169]'), + ('fp', '[u8; 32]'), + ('d0', 'Option<[u8; 11]>'), + ('d1', 'Option<[u8; 11]>'), + ('d2', 'Option<[u8; 11]>'), + ('dmax','Option<[u8; 11]>'), + ), + test_vectors, + ) + +if __name__ == '__main__': + main() diff --git a/tv_output.py b/tv_output.py index 00bf573..38c182f 100644 --- a/tv_output.py +++ b/tv_output.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + import argparse from binascii import hexlify import json @@ -7,6 +10,12 @@ def chunk(h): hstr = str(h, 'utf-8') return '0x' + ', 0x'.join([hstr[i:i+2] for i in range(0, len(hstr), 2)]) +class Some(object): + def __init__(self, thing): + self.thing = thing + +def option(x): + return Some(x) if x else None # # JSON (with string comments) @@ -14,6 +23,9 @@ def chunk(h): # def tv_value_json(value, bitcoin_flavoured): + if isinstance(value, Some): + value = value.thing + if type(value) == bytes: if bitcoin_flavoured and len(value) == 32: value = value[::-1] @@ -51,6 +63,20 @@ def tv_bytes_rust(name, value, pad): pad, )) +def tv_option_bytes_rust(name, value, pad): + if value: + print('''%s%s: Some([ + %s%s +%s]),''' % ( + pad, + name, + pad, + chunk(hexlify(value.thing)), + pad, + )) + else: + print('%s%s: None,' % (pad, name)) + def tv_int_rust(name, value, pad): print('%s%s: %d,' % (pad, name, value)) @@ -58,6 +84,8 @@ def tv_part_rust(name, value, indent=3): pad = ' ' * indent if type(value) == bytes: tv_bytes_rust(name, value, pad) + elif isinstance(value, Some) or value is None: + tv_option_bytes_rust(name, value, pad) elif type(value) == int: tv_int_rust(name, value, pad) else: @@ -65,7 +93,7 @@ def tv_part_rust(name, value, indent=3): def tv_rust(filename, parts, vectors): print(' struct TestVector {') - [print(' %s: %s,' % p) for p in parts] + for p in parts: print(' %s: %s,' % p) print(''' }; // From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/%s.py''' % ( @@ -73,13 +101,13 @@ def tv_rust(filename, parts, vectors): )) if type(vectors) == type({}): print(' let test_vector = TestVector {') - [tv_part_rust(p[0], vectors[p[0]]) for p in parts] + for p in parts: tv_part_rust(p[0], vectors[p[0]]) print(' };') elif type(vectors) == type([]): print(' let test_vectors = vec![') for vector in vectors: print(' TestVector {') - [tv_part_rust(p[0], vector[p[0]], 4) for p in parts] + for p in parts: tv_part_rust(p[0], vector[p[0]], 4) print(' },') print(' ];') else: @@ -92,7 +120,7 @@ def tv_rust(filename, parts, vectors): def render_args(): parser = argparse.ArgumentParser() - parser.add_argument('-t', '--target', choices=['zcash', 'rust'], default='rust') + parser.add_argument('-t', '--target', choices=['zcash', 'json', 'rust'], default='rust') return parser.parse_args() def render_tv(args, filename, parts, vectors): @@ -100,3 +128,5 @@ def render_tv(args, filename, parts, vectors): tv_rust(filename, parts, vectors) elif args.target == 'zcash': tv_json(filename, parts, vectors, True) + elif args.target == 'json': + tv_json(filename, parts, vectors, False)