diff --git a/orchard_group_hash.py b/orchard_group_hash.py new file mode 100644 index 0000000..5ef0c9e --- /dev/null +++ b/orchard_group_hash.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + +import math + +import orchard_iso_pallas + +from pyblake2 import blake2b +from orchard_pallas import Fp, p, q, PALLAS_B, Point +from orchard_iso_pallas import PALLAS_ISO_B, PALLAS_ISO_A +from sapling_utils import i2beosp, cldiv, beos2ip, i2leosp, lebs2ip +from tv_output import render_args, render_tv +from tv_rand import Rand + +# https://stackoverflow.com/questions/2612720/how-to-do-bitwise-exclusive-or-of-two-strings-in-python +def sxor(s1,s2): + return bytes([a ^ b for a,b in zip(s1,s2)]) + +def expand_message_xmd(msg, dst, len_in_bytes): + assert len(dst) <= 255 + + b_in_bytes = 64 # hash function output size + r_in_bytes = 128 + + ell = cldiv(len_in_bytes, b_in_bytes) + + assert ell <= 255 + + dst_prime = dst + i2beosp(8, len(dst)) + z_pad = b"\x00" * r_in_bytes + l_i_b_str = i2beosp(16, len_in_bytes) + msg_prime = z_pad + msg + l_i_b_str + i2beosp(8, 0) + dst_prime + + b = [] + + b0_ctx = blake2b(digest_size=b_in_bytes, person=i2beosp(128,0)) + b0_ctx.update(msg_prime) + b.append(b0_ctx.digest()) + assert len(b[0]) == b_in_bytes + + b1_ctx = blake2b(digest_size=b_in_bytes, person=i2beosp(128,0)) + b1_ctx.update(b[0] + i2beosp(8, 1) + dst_prime) + b.append(b1_ctx.digest()) + assert len(b[1]) == b_in_bytes + + for i in range(2, ell + 1): + bi_input = sxor(b[0], b[i-1]) + + assert len(bi_input) == b_in_bytes + + bi_input += i2beosp(8, i) + dst_prime + + bi_ctx = blake2b(digest_size=b_in_bytes, person=i2beosp(128,0)) + bi_ctx.update(bi_input) + + b.append(bi_ctx.digest()) + assert len(b[i]) == b_in_bytes + + return b''.join(b[1:])[0:len_in_bytes] + +def hash_to_field(msg, dst): + k = 256 + count = 2 + m = 1 + + L = cldiv(math.ceil(math.log2(p)) + k, 8) + assert L == 512/8 + + len_in_bytes = count * m * L + uniform_bytes = expand_message_xmd(msg, dst, len_in_bytes) + + elements = [] + for i in range(0, count): + for j in range(0, m): + elm_offset = L * (j + i * m) + tv = uniform_bytes[elm_offset:elm_offset+L] + elements.append(Fp(beos2ip(tv), False)) + + assert len(elements) == count + + return elements + +def map_to_curve_simple_swu(u): + # The notation below follows Appendix F.2 of the Internet Draft + zero = Fp(0) + assert zero.inv() == Fp(0) + + A = PALLAS_ISO_A + B = PALLAS_ISO_B + Z = Fp(-13, False) + c1 = -B / A + c2 = Fp(-1) / Z + + tv1 = Z * u.exp(2) + tv2 = tv1.exp(2) + x1 = tv1 + tv2 + + x1 = x1.inv() + e1 = x1 == Fp(0) + x1 = x1 + Fp(1) + + x1 = c2 if e1 else x1 # If (tv1 + tv2) == 0, set x1 = -1 / Z + + x1 = x1 * c1 # x1 = (-B / A) * (1 + (1 / (Z^2 * u^4 + Z * u^2))) + gx1 = x1.exp(2) + gx1 = gx1 + A + gx1 = gx1 * x1 + gx1 = gx1 + B # gx1 = g(x1) = x1^3 + A * x1 + B + x2 = tv1 * x1 # x2 = Z * u^2 * x1 + tv2 = tv1 * tv2 + gx2 = gx1 * tv2 # gx2 = (Z * u^2)^3 * gx1 + + e2 = (gx1.sqrt() is not None) + + x = x1 if e2 else x2 # If is_square(gx1), x = x1, else x = x2 + yy = gx1 if e2 else gx2 # If is_square(gx1), yy = gx1, else yy = gx2 + y = yy.sqrt() + + e3 = u.sgn0() == y.sgn0() + + y = y if e3 else -y #y = CMOV(-y, y, e3) + + return orchard_iso_pallas.Point(x, y) + +def group_hash(d, m): + dst = d + b"-" + b"pallas" + b"_XMD:BLAKE2b_SSWU_RO_" + + elems = hash_to_field(m, dst) + assert len(elems) == 2 + + q = [map_to_curve_simple_swu(elems[0]).iso_map(), map_to_curve_simple_swu(elems[1]).iso_map()] + + return q[0] + q[1] + + +def main(): + fixed_test_vectors = [ + # This is the Pallas test vector from the Sage and Rust code (in affine coordinates). + (b"z.cash:test", b"Trans rights now!", Point(Fp(10899331951394555178876036573383466686793225972744812919361819919497009261523), + Fp(851679174277466283220362715537906858808436854303373129825287392516025427980))), + ] + + for (domain, msg, point) in fixed_test_vectors: + gh = group_hash(domain, msg) + assert gh == point + + test_vectors = [(domain, msg) for (domain, msg, _) in fixed_test_vectors] + + from random import Random + rng = Random(0xabad533d) + def randbytes(l): + ret = [] + while len(ret) < l: + ret.append(rng.randrange(0, 256)) + return bytes(ret) + rand = Rand(randbytes) + + # Generate test vectors with the following properties: + # - One of two domains. + # - Random message lengths between 0 and 255 bytes. + # - Random message contents. + for _ in range(10): + domain = b"z.cash:test-longer" if rand.bool() else b"z.cash:test" + msg_len = rand.u8() + msg = bytes([rand.u8() for _ in range(msg_len)]) + test_vectors.append((domain, msg)) + + render_tv( + render_args(), + 'orchard_group_hash', + ( + ('domain', 'Vec'), + ('msg', 'Vec'), + ('point', '[u8; 32]'), + ), + [{ + 'domain': domain, + 'msg': msg, + 'point': bytes(group_hash(domain, msg)), + } for (domain, msg) in test_vectors], + ) + + +if __name__ == "__main__": + main() diff --git a/orchard_iso_pallas.py b/orchard_iso_pallas.py new file mode 100755 index 0000000..fbdfe93 --- /dev/null +++ b/orchard_iso_pallas.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# -*- coding: utf8 -*- +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + +import orchard_pallas +from orchard_pallas import Fp, p, q, Scalar + +# +# Point arithmetic +# + +PALLAS_ISO_B = Fp(1265) +PALLAS_ISO_A = Fp(0x18354a2eb0ea8c9c49be2d7258370742b74134581a27a59f92bb4b0b657a014b) + +class Point(object): + + @staticmethod + def from_bytes(buf): + assert len(buf) == 32 + if buf == bytes([0]*32): + return Point.identity() + + y_sign = buf[31] >> 7 + buf = buf[:31] + bytes([buf[31] & 0b01111111]) + try: + x = Fp.from_bytes(buf) + except ValueError: + return None + + x3 = x * x * x + y2 = x3 + PALLAS_ISO_A * x + PALLAS_ISO_B + + y = y2.sqrt() + if y is None: + return None + + if y.s % 2 != y_sign: + y = -y + + return Point(x, y) + + # Maps a point on iso-Pallas to a point on Pallas + def iso_map(self): + + c = [ + None, # make the indices 1-based + Fp(0x0e38e38e38e38e38e38e38e38e38e38e4081775473d8375b775f6034aaaaaaab), + Fp(0x3509afd51872d88e267c7ffa51cf412a0f93b82ee4b994958cf863b02814fb76), + Fp(0x17329b9ec525375398c7d7ac3d98fd13380af066cfeb6d690eb64faef37ea4f7), + Fp(0x1c71c71c71c71c71c71c71c71c71c71c8102eea8e7b06eb6eebec06955555580), + Fp(0x1d572e7ddc099cff5a607fcce0494a799c434ac1c96b6980c47f2ab668bcd71f), + Fp(0x325669becaecd5d11d13bf2a7f22b105b4abf9fb9a1fc81c2aa3af1eae5b6604), + Fp(0x1a12f684bda12f684bda12f684bda12f7642b01ad461bad25ad985b5e38e38e4), + Fp(0x1a84d7ea8c396c47133e3ffd28e7a09507c9dc17725cca4ac67c31d8140a7dbb), + Fp(0x3fb98ff0d2ddcadd303216cce1db9ff11765e924f745937802e2be87d225b234), + Fp(0x025ed097b425ed097b425ed097b425ed0ac03e8e134eb3e493e53ab371c71c4f), + Fp(0x0c02c5bcca0e6b7f0790bfb3506defb65941a3a4a97aa1b35a28279b1d1b42ae), + Fp(0x17033d3c60c68173573b3d7f7d681310d976bbfabbc5661d4d90ab820b12320a), + Fp(0x40000000000000000000000000000000224698fc094cf91b992d30ecfffffde5) + ] + + if self == Point.identity(): + return orchard_pallas.identity() + else: + numerator_a = c[1] * self.x * self.x * self.x + c[2] * self.x * self.x + c[3] * self.x + c[4] + denominator_a = self.x * self.x + c[5] * self.x + c[6] + + numerator_b = (c[7] * self.x * self.x * self.x + c[8] * self.x * self.x + c[9] * self.x + c[10]) * self.y + denominator_b = self.x * self.x * self.x + c[11] * self.x * self.x + c[12] * self.x + c[13] + + return orchard_pallas.Point(numerator_a / denominator_a, numerator_b / denominator_b) + + def __init__(self, x, y, is_identity=False): + self.x = x + self.y = y + self.is_identity = is_identity + + if is_identity: + assert self.x == Fp.ZERO + assert self.y == Fp.ZERO + else: + assert self.y * self.y == self.x * self.x * self.x + PALLAS_ISO_A * self.x + PALLAS_ISO_B + + def identity(): + p = Point(Fp.ZERO, Fp.ZERO, True) + return p + + def __neg__(self): + if self.is_identity: + return self + else: + return Point(Fp(self.x.s), -Fp(self.y.s)) + + def __add__(self, a): + if self.is_identity: + return a + elif a.is_identity: + return self + else: + # Hüseyin Hışıl. “Elliptic Curves, Group Law, and Efficient Computation”. PhD thesis. + # section 4.1 + (x1, y1) = (self.x, self.y) + (x2, y2) = (a.x, a.y) + + if x1 == x2: + if (y1 != y2) or (y1 == Fp(0)): + return Point.identity() + else: + return self.double() + else: + λ = (y1 - y2) / (x1 - x2) + x3 = λ*λ - x1 - x2 + y3 = λ*(x1 - x3) - y1 + return Point(x3, y3) + + def __sub__(self, a): + return (-a) + self + + def double(self): + if self.is_identity: + return self + + # Hüseyin Hışıl. “Elliptic Curves, Group Law, and Efficient Computation”. PhD thesis. + # section 4.1 + λ = (Fp(3) * self.x * self.x + PALLAS_ISO_A) / (self.y + self.y) + x3 = λ*λ - self.x - self.x + y3 = λ*(self.x - x3) - self.y + return Point(x3, y3) + + def __mul__(self, s): + s = format(s.s, '0256b') + ret = self.ZERO + for c in s: + ret = ret.double() + if int(c): + ret = ret + self + return ret + + def __bytes__(self): + if self.is_identity: + return bytes([0] * 32) + + buf = bytes(self.x) + if self.y.s % 2 == 1: + buf = buf[:31] + bytes([buf[31] | (1 << 7)]) + return buf + + def __eq__(self, a): + if a is None: + return False + if not (self.is_identity or a.is_identity): + return self.x == a.x and self.y == a.y + else: + return self.is_identity == a.is_identity + + def __str__(self): + if self.is_identity: + return 'Point(identity)' + else: + return 'Point(%s, %s)' % (self.x, self.y) + + +Point.ZERO = Point.identity() + +# This is an arbitrarily-chosen generator for testing purposes only, NOT a +# formally-selected common generator for iso-Pallas. +x = Fp(2) +y2 = x * x * x + PALLAS_ISO_A * x + PALLAS_ISO_B +y = y2.sqrt() +assert y is not None + +Point.GENERATOR = Point(x, y) + +assert Point.ZERO + Point.ZERO == Point.ZERO +assert Point.GENERATOR - Point.GENERATOR == Point.ZERO +assert Point.GENERATOR + Point.GENERATOR + Point.GENERATOR == Point.GENERATOR * Scalar(3) +assert Point.GENERATOR + Point.GENERATOR - Point.GENERATOR == Point.GENERATOR + +assert Point.from_bytes(bytes([0]*32)) == Point.ZERO +assert Point.from_bytes(bytes(Point.GENERATOR)) == Point.GENERATOR diff --git a/orchard_map_to_curve.py b/orchard_map_to_curve.py new file mode 100755 index 0000000..4d43dd1 --- /dev/null +++ b/orchard_map_to_curve.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +from orchard_group_hash import map_to_curve_simple_swu +from orchard_iso_pallas import Point as IsoPoint +from orchard_pallas import Fp +from sapling_utils import leos2ip +from tv_output import render_args, render_tv +from tv_rand import Rand + + +def main(): + fixed_test_vectors = [ + (Fp(0), IsoPoint(Fp(19938918781445865934736160264407396416050199005817793816893455093350997047296), + Fp(1448774895934493446148762800986014913165975534940595774801697325542407056356))), + (Fp(1), IsoPoint(Fp(5290181550357368025040301950220623271393946308300025648720253222947454165280), + Fp(24520995241805476578231005891941079870703368870355132644748659103632565232759))), + (Fp(0x123456789abcdef123456789abcdef123456789abcdef123456789abcdef0123), + IsoPoint(Fp(16711718778908753690082328243251803703269853000652055785581237369882690082595), + Fp(1764705856161931038824461929646873031992914829456409784642560948827969833589))), + ] + + for (u, point) in fixed_test_vectors: + P = map_to_curve_simple_swu(u) + assert P == point + + test_vectors = [u for (u, _) in fixed_test_vectors] + + from random import Random + rng = Random(0xabad533d) + def randbytes(l): + ret = [] + while len(ret) < l: + ret.append(rng.randrange(0, 256)) + return bytes(ret) + rand = Rand(randbytes) + + # Generate random test vectors + for _ in range(10): + test_vectors.append(Fp(leos2ip(rand.b(32)))) + + render_tv( + render_args(), + 'orchard_map_to_curve', + ( + ('u', '[u8; 32]'), + ('point', '[u8; 32]'), + ), + [{ + 'u': bytes(u), + 'point': bytes(map_to_curve_simple_swu(u)), + } for u in test_vectors], + ) + + +if __name__ == "__main__": + main() diff --git a/orchard_pallas.py b/orchard_pallas.py index 4451cec..29d5ad3 100644 --- a/orchard_pallas.py +++ b/orchard_pallas.py @@ -37,6 +37,10 @@ class Fp(FieldElement): def __str__(self): return 'Fp(%s)' % self.s + def sgn0(self): + # https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-10#section-4.1 + return (self.s % 2) == 1 + def sqrt(self): # Tonelli-Shank's algorithm for p mod 16 = 1 # https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) @@ -138,14 +142,19 @@ class Point(object): return Point(x, y) - def __init__(self, x, y): + def __init__(self, x, y, is_identity=False): self.x = x self.y = y - self.is_identity = False + self.is_identity = is_identity + + if is_identity: + assert self.x == Fp.ZERO + assert self.y == Fp.ZERO + else: + assert self.y * self.y == self.x * self.x * self.x + PALLAS_B def identity(): - p = Point(Fp.ZERO, Fp.ZERO) - p.is_identity = True + p = Point(Fp.ZERO, Fp.ZERO, True) return p def __neg__(self): @@ -174,6 +183,13 @@ class Point(object): else: return self.double() + def checked_incomplete_add(self, a): + assert self != a + assert self != -a + assert self != Point.identity() + assert a != Point.identity() + return self + a + def __sub__(self, a): return (-a) + self @@ -186,6 +202,11 @@ class Point(object): x = λ*λ - self.x - self.x y = λ*(self.x - x) - self.y return Point(x, y) + + def extract(self): + if self.is_identity: + return Fp.ZERO + return self.x def __mul__(self, s): s = format(s.s, '0256b') diff --git a/orchard_sinsemilla.py b/orchard_sinsemilla.py new file mode 100755 index 0000000..2662272 --- /dev/null +++ b/orchard_sinsemilla.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +import sys; assert sys.version_info[0] >= 3, "Python 3 required." + +import math + +import orchard_iso_pallas + +from orchard_pallas import Fp, Point +from sapling_utils import cldiv, lebs2ip, i2leosp +from orchard_group_hash import group_hash +from tv_output import render_args, render_tv +from tv_rand import Rand + +SINSEMILLA_K = 10 + +# Interprets a string or a list as a sequence of bits. +def str_to_bits(s): + for c in s: + assert c in ['0', '1', 0, 1, False, True] + # Regular Python truthiness is fine here except for bool('0') == True. + return [c != '0' and bool(c) for c in s] + +def pad(n, m): + padding_needed = n * SINSEMILLA_K - len(m) + zeros = [0] * padding_needed + m = list(m) + zeros + + return [lebs2ip(str_to_bits(m[i*SINSEMILLA_K : (i+1)*SINSEMILLA_K])) for i in range(n)] + +def sinsemilla_hash_to_point(d, m): + n = cldiv(len(m), SINSEMILLA_K) + m = pad(n, m) + acc = group_hash(b"z.cash:SinsemillaQ", d) + + for m_i in m: + acc = acc.checked_incomplete_add( + group_hash(b"z.cash:SinsemillaS", i2leosp(32, m_i)) + ).checked_incomplete_add(acc) + + return acc + +def sinsemilla_hash(d, m): + return sinsemilla_hash_to_point(d, m).extract() + + +def main(): + test_vectors = [ + # 40 bits, so no padding + (b"z.cash:test-Sinsemilla", [0,0,0,1,0,1,1,0,1,0,1,0,0,1,1,0,0,0,1,1,0,1,1,0,0,0,1,1,0,1,1,0,1,1,1,1,0,1,1,0]), + ] + + sh = sinsemilla_hash_to_point(test_vectors[0][0], test_vectors[0][1]) + assert sh == Point(Fp(19681977528872088480295086998934490146368213853811658798708435106473481753752), + Fp(14670850419772526047574141291705097968771694788047376346841674072293161339903)) + + from random import Random + rng = Random(0xabad533d) + def randbytes(l): + ret = [] + while len(ret) < l: + ret.append(rng.randrange(0, 256)) + return bytes(ret) + rand = Rand(randbytes) + + # Generate test vectors with the following properties: + # - One of two domains. + # - Random message lengths between 0 and 255 bytes. + # - Random message bits. + for _ in range(10): + domain = b"z.cash:test-Sinsemilla-longer" if rand.bool() else b"z.cash:test-Sinsemilla" + msg_len = rand.u8() + msg = bytes([rand.bool() for _ in range(msg_len)]) + test_vectors.append((domain, msg)) + + test_vectors = [{ + 'domain': domain, + 'msg': msg, + 'point': bytes(sinsemilla_hash_to_point(domain, msg)), + 'hash': bytes(sinsemilla_hash(domain, msg)), + } for (domain, msg) in test_vectors] + + render_tv( + render_args(), + 'orchard_sinsemilla', + ( + ('domain', 'Vec'), + ('msg', { + 'rust_type': 'Vec', + 'rust_fmt': lambda x: str_to_bits(x), + }), + ('point', '[u8; 32]'), + ('hash', '[u8; 32]'), + ), + test_vectors, + ) + + +if __name__ == "__main__": + main() diff --git a/transaction.py b/transaction.py index 7a2582b..aa3ca64 100644 --- a/transaction.py +++ b/transaction.py @@ -155,7 +155,7 @@ RAND_OPCODES = [ class Script(object): def __init__(self, rand): self._script = bytes([ - rand.a(RAND_OPCODES) for i in range(rand.u8() % 10) + rand.a(RAND_OPCODES) for i in range(rand.i8() % 10) ]) def raw(self): @@ -212,11 +212,11 @@ class Transaction(object): self.nVersion = rand.u32() & ((1 << 31) - 1) self.vin = [] - for i in range(rand.u8() % 3): + for i in range(rand.i8() % 3): self.vin.append(TxIn(rand)) self.vout = [] - for i in range(rand.u8() % 3): + for i in range(rand.i8() % 3): self.vout.append(TxOut(rand)) self.nLockTime = rand.u32() @@ -227,14 +227,14 @@ class Transaction(object): self.vShieldedSpends = [] self.vShieldedOutputs = [] if self.nVersion >= SAPLING_TX_VERSION: - for _ in range(rand.u8() % 5): + for _ in range(rand.i8() % 5): self.vShieldedSpends.append(SpendDescription(rand)) - for _ in range(rand.u8() % 5): + for _ in range(rand.i8() % 5): self.vShieldedOutputs.append(OutputDescription(rand)) self.vJoinSplit = [] if self.nVersion >= 2: - for i in range(rand.u8() % 3): + for i in range(rand.i8() % 3): self.vJoinSplit.append(JoinSplit(rand, self.fOverwintered and self.nVersion >= SAPLING_TX_VERSION)) if len(self.vJoinSplit) > 0: self.joinSplitPubKey = rand.b(32) # Potentially invalid diff --git a/tv_output.py b/tv_output.py index fcb225d..90180a2 100644 --- a/tv_output.py +++ b/tv_output.py @@ -75,6 +75,17 @@ def tv_vec_bytes_rust(name, value, pad): pad, )) +def tv_vec_bool_rust(name, value, pad): + print('''%s%s: vec![ + %s%s +%s],''' % ( + pad, + name, + pad, + ', '.join(['true' if x else 'false' for x in value]), + pad, + )) + def tv_option_bytes_rust(name, value, pad): if value: print('''%s%s: Some([ @@ -121,6 +132,8 @@ def tv_part_rust(name, value, config, indent=3): tv_option_vec_bytes_rust(name, value, pad) elif config['rust_type'] == 'Vec': tv_vec_bytes_rust(name, value, pad) + elif config['rust_type'] == 'Vec': + tv_vec_bool_rust(name, value, pad) elif config['rust_type'].startswith('Option<['): tv_option_bytes_rust(name, value, pad) elif type(value) == bytes: diff --git a/tv_rand.py b/tv_rand.py index 3035286..e064332 100644 --- a/tv_rand.py +++ b/tv_rand.py @@ -12,9 +12,12 @@ class Rand(object): def v(self, l, f): return struct.unpack(f, self.b(l))[0] - def u8(self): + def i8(self): return self.v(1, 'b') + def u8(self): + return self.v(1, 'B') + def u32(self): return self.v(4, '