Merge pull request #7 from daira/daira-zip32

Implement ZIP 32
This commit is contained in:
Daira Hopwood 2018-08-03 14:29:12 +01:00 committed by GitHub
commit 726688e6cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 480 additions and 25 deletions

148
ff1.py Normal file
View File

@ -0,0 +1,148 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
import os
from binascii import unhexlify, hexlify
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from sapling_utils import bebs2ip, i2bebsp, beos2ip, bebs2osp, cldiv
# Morris Dworkin
# NIST Special Publication 800-38G
# Recommendation for Block Cipher Modes of Operation: Methods for Format-Preserving Encryption
# <http://dx.doi.org/10.6028/NIST.SP.800-38G>
# specialized to the parameters below and a single-block PRF; unoptimized
radix = 2
minlen = maxlen = 88
maxTlen = 255
assert 2 <= radix and radix < 256
assert radix**minlen >= 100
assert 2 <= minlen and minlen <= maxlen and maxlen < 256
NUM_2 = bebs2ip
STR_2 = i2bebsp
def ff1_aes256_encrypt(key, tweak, x):
n = len(x)
t = len(tweak)
assert minlen <= n and n <= maxlen
assert t <= maxTlen
u = n//2; v = n-u
assert u == v
A = x[:u]; B = x[u:]
assert radix == 2
b = cldiv(v, 8)
d = 4*cldiv(b, 4) + 4
assert d <= 16
P = bytes([1, 2, 1, 0, 0, radix, 10, u % 256, 0, 0, 0, n, 0, 0, 0, t])
for i in range(10):
Q = tweak + b'\0'*((-t-b-1) % 16) + bytes([i]) + bebs2osp(B)
y = beos2ip(aes_cbcmac(key, P + Q)[:d])
c = (NUM_2(A)+y) % (1<<u)
C = STR_2(u, c)
A = B
B = C
return A + B
# This is not used except by tests.
def ff1_aes256_decrypt(key, tweak, x):
n = len(x)
t = len(tweak)
assert minlen <= n and n <= maxlen
assert t <= maxTlen
u = n//2; v = n-u
assert u == v
A = x[:u]; B = x[u:]
assert radix == 2
b = cldiv(v, 8)
d = 4*cldiv(b, 4) + 4
assert d <= 16
P = bytes([1, 2, 1, 0, 0, radix, 10, u % 256, 0, 0, 0, n, 0, 0, 0, t])
for i in range(9, -1, -1):
Q = tweak + b'\0'*((-t-b-1) % 16) + bytes([i]) + bebs2osp(A)
y = beos2ip(aes_cbcmac(key, P + Q)[:d])
c = (NUM_2(B)-y) % (1<<u)
C = STR_2(u, c)
B = A
A = C
return A + B
def test_ff1():
# Test vectors consistent with the Java implementation at
# <https://git.code.sf.net/p/format-preserving-encryption/code>.
key = unhexlify("2B7E151628AED2A6ABF7158809CF4F3CEF4359D8D580AA4F7F036D6F04FC6A94")
tweak = b''
x = [0]*88
ct = ff1_aes256_encrypt(key, tweak, x)
assert ''.join(map(str, ct)) == "0000100100110101011101111111110011000001101100111110011101110101011010100100010011001111", ct
pt = ff1_aes256_decrypt(key, tweak, ct)
assert pt == x, (ct, pt)
x = list(map(int, "0000100100110101011101111111110011000001101100111110011101110101011010100100010011001111"))
ct = ff1_aes256_encrypt(key, tweak, x)
assert ''.join(map(str, ct)) == "1101101011010001100011110000010011001111110110011101010110100001111001000101011111011000", ct
pt = ff1_aes256_decrypt(key, tweak, ct)
assert pt == x, (ct, pt)
x = [0, 1]*44
ct = ff1_aes256_encrypt(key, tweak, x)
assert ''.join(map(str, ct)) == "0000111101000001111011010111011111110001100101000000001101101110100010010111001100100110", ct
pt = ff1_aes256_decrypt(key, tweak, ct)
assert pt == x, (ct, pt)
tweak = bytes(range(maxTlen))
ct = ff1_aes256_encrypt(key, tweak, x)
assert ''.join(map(str, ct)) == "0111110110001000000111010110000100010101101000000011100111100100100010101101111010100011", ct
pt = ff1_aes256_decrypt(key, tweak, ct)
assert pt == x, (ct, pt)
key = os.urandom(32)
tweak = b''
ct = ff1_aes256_encrypt(key, tweak, x)
pt = ff1_aes256_decrypt(key, tweak, ct)
assert pt == x, (ct, pt)
tweak = os.urandom(maxTlen)
ct = ff1_aes256_encrypt(key, tweak, x)
pt = ff1_aes256_decrypt(key, tweak, ct)
assert pt == x, (ct, pt)
def aes_cbcmac(key, input):
encryptor = Cipher(algorithms.AES(key), modes.CBC(b'\0'*16), backend=default_backend()).encryptor()
return (encryptor.update(input) + encryptor.finalize())[-16:]
def test_aes():
# Check we're actually using AES-256.
# <https://csrc.nist.gov/Projects/Cryptographic-Algorithm-Validation-Program/Block-Ciphers>
# <https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Algorithm-Validation-Program/documents/aes/aesmct.zip>
# Simple test (this wouldn't catch a byte order error in the key):
# ECBVarTxt256.rsp COUNT = 0
KEY = unhexlify("0000000000000000000000000000000000000000000000000000000000000000")
PLAINTEXT = unhexlify("80000000000000000000000000000000")
CIPHERTEXT = unhexlify("ddc6bf790c15760d8d9aeb6f9a75fd4e")
assert aes_cbcmac(KEY, PLAINTEXT) == CIPHERTEXT
# Now something more rigorous:
# ECBMCT256.rsp COUNT = 0
key = unhexlify("f9e8389f5b80712e3886cc1fa2d28a3b8c9cd88a2d4a54c6aa86ce0fef944be0")
acc = unhexlify("b379777f9050e2a818f2940cbbd9aba4")
ct = unhexlify("6893ebaf0a1fccc704326529fdfb60db")
for i in range(1000):
acc = aes_cbcmac(key, acc)
assert acc == ct, hexlify(acc)
if __name__ == '__main__':
test_aes()
test_ff1()

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
from pyblake2 import blake2s
from sapling_jubjub import Point, JUBJUB_COFACTOR

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
from sapling_utils import i2lebsp, leos2ip, i2leosp
q_j = 52435875175126190479447740508185965837690552500527637822603658699938581184513

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
from pyblake2 import blake2b, blake2s
from sapling_generators import PROVING_KEY_BASE, SPENDING_KEY_BASE, group_hash
@ -33,6 +35,8 @@ def crh_ivk(ak, nk):
ivk = digest.digest()
return leos2ip(ivk) % 2**251
def diversify_hash(d):
return group_hash(b'Zcash_gd', d)
#
# Key components
@ -47,22 +51,8 @@ def cached(f):
return self._cached[f]
return wrapper
class SpendingKey(object):
def __init__(self, data):
self.data = data
@cached
def ask(self):
return to_scalar(prf_expand(self.data, b'\0'))
@cached
def nsk(self):
return to_scalar(prf_expand(self.data, b'\1'))
@cached
def ovk(self):
return prf_expand(self.data, b'\2')[:32]
class DerivedAkNk(object):
@cached
def ak(self):
return SPENDING_KEY_BASE * self.ask()
@ -71,23 +61,42 @@ class SpendingKey(object):
def nk(self):
return PROVING_KEY_BASE * self.nsk()
class DerivedIvk(object):
@cached
def ivk(self):
return Fr(crh_ivk(bytes(self.ak()), bytes(self.nk())))
class SpendingKey(DerivedAkNk, DerivedIvk):
def __init__(self, data):
self.data = data
@cached
def ask(self):
return to_scalar(prf_expand(self.data, b'\x00'))
@cached
def nsk(self):
return to_scalar(prf_expand(self.data, b'\x01'))
@cached
def ovk(self):
return prf_expand(self.data, b'\x02')[:32]
@cached
def default_d(self):
i = 0
while True:
d = prf_expand(self.data, bytes([3, i]))[:11]
if group_hash(b'Zcash_gd', d):
if diversify_hash(d):
return d
i += 1
assert i < 256
@cached
def default_pkd(self):
return group_hash(b'Zcash_gd', self.default_d()) * self.ivk()
return diversify_hash(self.default_d()) * self.ivk()
def main():
@ -100,7 +109,7 @@ def main():
note_r = Fr(8890123457840276890326754358439057438290574382905).exp(i+1)
note_cm = note_commit(
note_r,
leos2bsp(bytes(group_hash(b'Zcash_gd', sk.default_d()))),
leos2bsp(bytes(diversify_hash(sk.default_d()))),
leos2bsp(bytes(sk.default_pkd())),
note_v)
note_pos = (980705743285409327583205473820957432*i) % 2**MERKLE_DEPTH

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
from binascii import unhexlify
from sapling_pedersen import pedersen_hash

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
from pyblake2 import blake2s
from sapling_pedersen import (

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
from sapling_generators import (
find_group_hash,
NOTE_POSITION_BASE,

View File

@ -1,4 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
import os
from pyblake2 import blake2b

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
def cldiv(n, divisor):
return (n + (divisor - 1)) // divisor
@ -9,26 +10,57 @@ def i2lebsp(l, x):
def leos2ip(S):
return int.from_bytes(S, byteorder='little')
def beos2ip(S):
return int.from_bytes(S, byteorder='big')
# This should be equivalent to LEBS2OSP(I2LEBSP(l, x))
def i2leosp(l, x):
return x.to_bytes(cldiv(l, 8), byteorder='little')
def ledna(bits):
# This should be equivalent to BEBS2OSP(I2BEBSP(l, x))
def i2beosp(l, x):
return x.to_bytes(cldiv(l, 8), byteorder='big')
def bebs2ip(bits):
ret = 0
for b in bits[::-1]:
for b in bits:
ret = ret * 2
if b:
ret += 1
return ret
def lebs2ip(bits):
return bebs2ip(bits[::-1])
def i2bebsp(m, x):
assert 0 <= x and x < (1 << m)
return [(x >> (m-1-i)) & 1 for i in range(m)]
def lebs2osp(bits):
l = len(bits)
bits = bits + [0] * (8 * cldiv(l, 8) - l)
return bytes([ledna(bits[i:i + 8]) for i in range(0, len(bits), 8)])
return bytes([lebs2ip(bits[i:i + 8]) for i in range(0, len(bits), 8)])
def leos2bsp(buf):
return sum([[(c >> i) & 1 for i in range(8)] for c in buf], [])
def bebs2osp(bits, m=None):
l = len(bits)
bits = [0] * (8 * cldiv(l, 8) - l) + bits
return bytes([bebs2ip(bits[i:i + 8]) for i in range(0, len(bits), 8)])
assert i2leosp(5, 7) == lebs2osp(i2lebsp(5, 7))
assert i2leosp(32, 1234567890) == lebs2osp(i2lebsp(32, 1234567890))
assert i2beosp(5, 7) == bebs2osp(i2bebsp(5, 7))
assert i2beosp(32, 1234567890) == bebs2osp(i2bebsp(32, 1234567890))
assert leos2ip(bytes(range(256))) == lebs2ip(leos2bsp(bytes(range(256))))
assert bebs2ip(i2bebsp(5, 7)) == 7
try:
i2bebsp(3, 12)
except AssertionError:
pass
else:
raise AssertionError("invalid input not caught by i2bebsp")

224
sapling_zip32.py Normal file
View File

@ -0,0 +1,224 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
from pyblake2 import blake2b
from sapling_key_components import to_scalar, prf_expand, diversify_hash, DerivedAkNk, DerivedIvk
from sapling_generators import SPENDING_KEY_BASE, PROVING_KEY_BASE
from sapling_utils import i2leosp, i2lebsp, lebs2osp
from ff1 import ff1_aes256_encrypt
from tv_output import render_args, render_tv, option, Some
def encode_xsk_parts(ask, nsk, ovk, dk):
# bytes = i2leosp_256 for Fr
return bytes(ask) + bytes(nsk) + ovk + dk
def encode_xfvk_parts(ak, nk, ovk, dk):
return bytes(ak) + bytes(nk) + ovk + dk
class ExtendedBase(object):
def ovk(self):
return self._ovk
def dk(self):
return self._dk
def c(self):
return self._c
def depth(self):
return self._depth
def parent_tag(self):
return self._parent_tag
def i(self):
return self._i
def diversifier(self, j):
d = lebs2osp(ff1_aes256_encrypt(self.dk(), b'', i2lebsp(88, j)))
return d if diversify_hash(d) else None
def fingerprint(self):
FVK = bytes(self.ak()) + bytes(self.nk()) + self.ovk()
return blake2b(person=b'ZcashSaplingFVFP', digest_size=32, data=FVK).digest()
def tag(self):
return self.fingerprint()[:4]
class ExtendedSpendingKey(DerivedAkNk, DerivedIvk, ExtendedBase):
def __init__(self, ask, nsk, ovk, dk, c, depth=0, parent_tag=i2leosp(32, 0), i=0):
self._ask = ask
self._nsk = nsk
self._ovk = ovk
self._dk = dk
self._c = c
self._depth = depth
self._parent_tag = parent_tag
self._i = i
@classmethod
def master(cls, S):
I = blake2b(person=b'ZcashIP32Sapling', data=S).digest()
I_L = I[:32]
I_R = I[32:]
sk_m = I_L
ask_m = to_scalar(prf_expand(sk_m, b'\x00'))
nsk_m = to_scalar(prf_expand(sk_m, b'\x01'))
ovk_m = prf_expand(sk_m, b'\x02')[:32]
dk_m = prf_expand(sk_m, b'\x10')[:32]
c_m = I_R
return cls(ask_m, nsk_m, ovk_m, dk_m, c_m)
def ask(self):
return self._ask
def nsk(self):
return self._nsk
def is_xsk(self):
return True
def __bytes__(self):
return (i2leosp(8, self.depth()) +
self.parent_tag() +
i2leosp(32, self.i()) +
self.c() +
encode_xsk_parts(self.ask(), self.nsk(), self.ovk(), self.dk()))
def to_extended_fvk(self):
return ExtendedFullViewingKey(self.ak(), self.nk(), self.ovk(), self.dk(), self.c(),
self.depth(), self.parent_tag(), self.i())
def child(self, i):
if i >= 1<<31: # child is a hardened key
prefix = b'\x11' + encode_xsk_parts(self.ask(), self.nsk(), self.ovk(), self.dk())
else:
prefix = b'\x12' + encode_xfvk_parts(self.ak(), self.nk(), self.ovk(), self.dk())
I = prf_expand(self.c(), prefix + i2leosp(32, i))
I_L = I[:32]
I_R = I[32:]
I_ask = to_scalar(prf_expand(I_L, b'\x13'))
I_nsk = to_scalar(prf_expand(I_L, b'\x14'))
ask_i = I_ask + self.ask()
nsk_i = I_nsk + self.nsk()
ovk_i = prf_expand(I_L, b'\x15' + self.ovk())[:32]
dk_i = prf_expand(I_L, b'\x16' + self.dk())[:32]
c_i = I_R
return self.__class__(ask_i, nsk_i, ovk_i, dk_i, c_i, self.depth()+1, self.tag(), i)
class ExtendedFullViewingKey(DerivedIvk, ExtendedBase):
def __init__(self, ak, nk, ovk, dk, c, depth=0, parent_tag=i2leosp(32, 0), i=0):
self._ak = ak
self._nk = nk
self._ovk = ovk
self._dk = dk
self._c = c
self._depth = depth
self._parent_tag = parent_tag
self._i = i
@classmethod
def master(cls, S):
return ExtendedSpendingKey.master(S).to_extended_fvk()
def ak(self):
return self._ak
def nk(self):
return self._nk
def is_xsk(self):
return False
def __bytes__(self):
return (i2leosp(8, self.depth()) +
self.parent_tag() +
i2leosp(32, self.i()) +
self.c() +
encode_xfvk_parts(self.ak(), self.nk(), self.ovk(), self.dk()))
def to_extended_fvk(self):
return self
def child(self, i):
if i >= 1<<31:
raise ValueError("can't derive a child hardened key from an extended full viewing key")
else:
prefix = b'\x12' + encode_xfvk_parts(self.ak(), self.nk(), self.ovk(), self.dk())
I = prf_expand(self.c(), prefix + i2leosp(32, i))
I_L = I[:32]
I_R = I[32:]
I_ask = to_scalar(prf_expand(I_L, b'\x13'))
I_nsk = to_scalar(prf_expand(I_L, b'\x14'))
ak_i = SPENDING_KEY_BASE * I_ask + self.ak()
nk_i = PROVING_KEY_BASE * I_nsk + self.nk()
ovk_i = prf_expand(I_L, b'\x15' + self.ovk())[:32]
dk_i = prf_expand(I_L, b'\x16' + self.dk())[:32]
c_i = I_R
return self.__class__(ak_i, nk_i, ovk_i, dk_i, c_i, self.depth()+1, self.tag(), i)
def main():
args = render_args()
def hardened(i): return i + (1<<31)
seed = bytes(range(32))
m = ExtendedSpendingKey.master(seed)
m_1 = m.child(1)
m_1_2h = m_1.child(hardened(2))
m_1_2hv = m_1_2h.to_extended_fvk()
m_1_2hv_3 = m_1_2hv.child(3)
test_vectors = [
{'ask' : Some(bytes(k.ask())) if k.is_xsk() else None,
'nsk' : Some(bytes(k.nsk())) if k.is_xsk() else None,
'ovk' : k.ovk(),
'dk' : k.dk(),
'c' : k.c(),
'ak' : bytes(k.ak()),
'nk' : bytes(k.nk()),
'ivk' : bytes(k.ivk()),
'xsk' : Some(bytes(k)) if k.is_xsk() else None,
'xfvk': bytes(k.to_extended_fvk()),
'fp' : k.fingerprint(),
'd0' : option(k.diversifier(0)),
'd1' : option(k.diversifier(1)),
'd2' : option(k.diversifier(2)),
'dmax': option(k.diversifier((1<<88)-1)),
}
for k in (m, m_1, m_1_2h, m_1_2hv, m_1_2hv_3)
]
render_tv(
args,
'sapling_zip32',
(
('ask', 'Option<[u8; 32]>'),
('nsk', 'Option<[u8; 32]>'),
('ovk', '[u8; 32]'),
('dk', '[u8; 32]'),
('c', '[u8; 32]'),
('ak', '[u8; 32]'),
('nk', '[u8; 32]'),
('ivk', '[u8; 32]'),
('xsk', 'Option<[u8; 169]>'),
('xfvk','[u8; 169]'),
('fp', '[u8; 32]'),
('d0', 'Option<[u8; 11]>'),
('d1', 'Option<[u8; 11]>'),
('d2', 'Option<[u8; 11]>'),
('dmax','Option<[u8; 11]>'),
),
test_vectors,
)
if __name__ == '__main__':
main()

View File

@ -1,3 +1,6 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."
import argparse
from binascii import hexlify
import json
@ -7,6 +10,12 @@ def chunk(h):
hstr = str(h, 'utf-8')
return '0x' + ', 0x'.join([hstr[i:i+2] for i in range(0, len(hstr), 2)])
class Some(object):
def __init__(self, thing):
self.thing = thing
def option(x):
return Some(x) if x else None
#
# JSON (with string comments)
@ -14,6 +23,9 @@ def chunk(h):
#
def tv_value_json(value, bitcoin_flavoured):
if isinstance(value, Some):
value = value.thing
if type(value) == bytes:
if bitcoin_flavoured and len(value) == 32:
value = value[::-1]
@ -51,6 +63,20 @@ def tv_bytes_rust(name, value, pad):
pad,
))
def tv_option_bytes_rust(name, value, pad):
if value:
print('''%s%s: Some([
%s%s
%s]),''' % (
pad,
name,
pad,
chunk(hexlify(value.thing)),
pad,
))
else:
print('%s%s: None,' % (pad, name))
def tv_int_rust(name, value, pad):
print('%s%s: %d,' % (pad, name, value))
@ -58,6 +84,8 @@ def tv_part_rust(name, value, indent=3):
pad = ' ' * indent
if type(value) == bytes:
tv_bytes_rust(name, value, pad)
elif isinstance(value, Some) or value is None:
tv_option_bytes_rust(name, value, pad)
elif type(value) == int:
tv_int_rust(name, value, pad)
else:
@ -65,7 +93,7 @@ def tv_part_rust(name, value, indent=3):
def tv_rust(filename, parts, vectors):
print(' struct TestVector {')
[print(' %s: %s,' % p) for p in parts]
for p in parts: print(' %s: %s,' % p)
print(''' };
// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/%s.py''' % (
@ -73,13 +101,13 @@ def tv_rust(filename, parts, vectors):
))
if type(vectors) == type({}):
print(' let test_vector = TestVector {')
[tv_part_rust(p[0], vectors[p[0]]) for p in parts]
for p in parts: tv_part_rust(p[0], vectors[p[0]])
print(' };')
elif type(vectors) == type([]):
print(' let test_vectors = vec![')
for vector in vectors:
print(' TestVector {')
[tv_part_rust(p[0], vector[p[0]], 4) for p in parts]
for p in parts: tv_part_rust(p[0], vector[p[0]], 4)
print(' },')
print(' ];')
else:
@ -92,7 +120,7 @@ def tv_rust(filename, parts, vectors):
def render_args():
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--target', choices=['zcash', 'rust'], default='rust')
parser.add_argument('-t', '--target', choices=['zcash', 'json', 'rust'], default='rust')
return parser.parse_args()
def render_tv(args, filename, parts, vectors):
@ -100,3 +128,5 @@ def render_tv(args, filename, parts, vectors):
tv_rust(filename, parts, vectors)
elif args.target == 'zcash':
tv_json(filename, parts, vectors, True)
elif args.target == 'json':
tv_json(filename, parts, vectors, False)