diff --git a/hashtocurve.sage b/hashtocurve.sage index 5671a89..b2322bd 100755 --- a/hashtocurve.sage +++ b/hashtocurve.sage @@ -6,15 +6,8 @@ import sys from math import ceil, log from struct import pack - -import hashlib -if sys.version_info < (3, 6): - try: - import sha3 - except ImportError: - print('Please run:\n`sage -c "import sys; print(sys.executable)"` -m pip install pysha3\n') - raise -from hashlib import shake_128 +from pyblake2 import blake2b +from hashlib import sha256 if sys.version_info[0] == 2: range = xrange @@ -22,6 +15,10 @@ if sys.version_info[0] == 2: else: as_byte = lambda x: x +def as_bytes(x): + # + return bytes(bytearray(x)) + load('squareroottab.sage') DEBUG = True @@ -64,7 +61,12 @@ def is_good_Z(F, g, A, B, Z): # Point in Chudnovsky coordinates (Jacobian with Z^2 and Z^3 cached). class ChudnovskyPoint: - def __init__(self, E, x, y, z, z2, z3): + def __init__(self, E, x, y, z, z2=None, z3=None): + if z2 is None: + z2 = z^2 + if z3 is None: + z3 = z^3 + if DEBUG: (0, 0, 0, A, B) = E.a_invariants() assert z2 == z^2 @@ -159,7 +161,7 @@ class ChudnovskyPoint: return (self.x, self.y, self.z) def __repr__(self): - return "%r : %r : %r : %r : %r" % (hex(int(self.x)), hex(int(self.y)), hex(int(self.z)), hex(int(self.z2)), hex(int(self.z3))) + return "ChudnovskyPoint {\n 0x%064x\n: 0x%064x\n: 0x%064x\n: 0x%064x\n: 0x%064x\n}" % (int(self.x), int(self.y), int(self.z), int(self.z2), int(self.z3)) assert p == 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001 @@ -167,25 +169,25 @@ assert q == 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001 Fp = GF(p) Fq = GF(q) -E_isop_A = 10949663248450308183708987909873589833737836120165333298109615750520499732811 -E_isoq_A = 17413348858408915339762682399132325137863850198379221683097628341577494210225 -E_isop_B = 1265 -E_isoq_B = 1265 -E_isop = EllipticCurve(Fp, [E_isop_A, E_isop_B]) -E_isoq = EllipticCurve(Fq, [E_isoq_A, E_isoq_B]) -E_p = EllipticCurve(Fp, [0, 5]) -E_q = EllipticCurve(Fq, [0, 5]) +IsoEp_A = 10949663248450308183708987909873589833737836120165333298109615750520499732811 +IsoEq_A = 17413348858408915339762682399132325137863850198379221683097628341577494210225 +IsoEp_B = 1265 +IsoEq_B = 1265 +IsoEp = EllipticCurve(Fp, [IsoEp_A, IsoEp_B]) +IsoEq = EllipticCurve(Fq, [IsoEq_A, IsoEq_B]) +Ep = EllipticCurve(Fp, [0, 5]) +Eq = EllipticCurve(Fq, [0, 5]) -k = 128 +k = 256 Lp = (len(format(p, 'b')) + k + 7) // 8 Lq = (len(format(q, 'b')) + k + 7) // 8 -assert Lp == 48 and Lq == 48 -L = Lp +assert Lp == 64 and Lq == 64 +CHUNKLEN = Lp -Z_isop = find_z_sswu(E_isop) -Z_isoq = find_z_sswu(E_isoq) -assert Z_isop == Mod(-13, p) -assert Z_isoq == Mod(-13, q) +IsoEpZ = find_z_sswu(IsoEp) +IsoEqZ = find_z_sswu(IsoEq) +assert IsoEpZ == Mod(-13, p) +assert IsoEqZ == Mod(-13, q) def select_z_nz(s, ifz, ifnz): @@ -193,6 +195,8 @@ def select_z_nz(s, ifz, ifnz): return ifz if (s == 0) else ifnz def map_to_curve_simple_swu(F, E, Z, u, c): + if VERBOSE: print("map_to_curve(0x%064x)" % (u,)) + # would be precomputed h = F.g (0, 0, 0, A, B) = E.a_invariants() @@ -248,6 +252,7 @@ def map_to_curve_simple_swu(F, E, Z, u, c): # 7. If is_square(gx1), set x = x1 and y = sqrt(gx1) # 8. Else set x = x2 and y = sqrt(gx2) (y1, zero_if_gx1_square) = F.sarkar_divsqrt(U, D3, c) + if VERBOSE: print("zero_if_gx1_square = %064x" % (zero_if_gx1_square,)) # This magic also comes from a generalization of [WB2019, section 4.2]. # @@ -314,6 +319,7 @@ def isop_map_jacobian(P, c): z4 = c.sqr(z2) z6 = c.sqr(z3) + if VERBOSE: print("IsoEp { x: 0x%064x, y: 0x%064x, z: 0x%064x }" % (x, y, z)) Nx = ((( 6432893846517566412420610278260439325191790329320346825767705947633326140075 *x + 23989696149150192365340222745168215001509815558210986772351135915822265203574*z2)*x + 10492611921771203378452795982353351666191589197598957448093274638589204800759*z4)*x + @@ -335,6 +341,8 @@ def isop_map_jacobian(P, c): 28948022309329048855892746252171976963363056481941560715954676764349967629797*z6) * z3 c.muls += 6 + if VERBOSE: print("num_x = 0x%064x\ndiv_x = 0x%064x\nnum_y = 0x%064x\ndiv_y = 0x%064x" % (Nx, Dx, Ny, Dy)) + zo = c.mul(Dx, Dy) xo = c.mul(c.mul(Nx, Dy), zo) yo = c.mul(c.mul(Ny, Dx), c.sqr(zo)) @@ -405,22 +413,27 @@ def isoq_map_jacobian(P, c): assert isoq_map_affine(x / z2, y / z3, Cost()) == (xo / zo^2, yo / zo^3) return (xo, yo, zo) +def hex_bytes(bs): + return "[%s]" % (", ".join(["%02x" % (as_byte(b),) for b in bs]),) -def expand_message_xof(msg, DST, len_in_bytes): - assert len(DST) < 256 - len_in_bytes = int(len_in_bytes) +def hash(hasher, msg): + if VERBOSE: print(hex_bytes(msg)) + h = hasher() + h.update(msg) + return h.digest() - # This is horrible but matches the reference code. - xof = shake_128() - xof.update(msg) - xof.update(pack(">H", len_in_bytes)) - xof.update(pack("B", len(DST))) - xof.update(DST) - return xof.digest(len_in_bytes) +SHA256 = (sha256, 32, 64) +BLAKE2b = (blake2b, 64, 64) -def hash_to_field(modulus, msg, DST, count): - uniform_bytes = expand_message_xof(msg, DST, L*count) - return [Mod(OS2IP(uniform_bytes[L*i : L*(i+1)]), modulus) for i in range(count)] +def hash_to_field(modulus, message, DST, count): + outlen = int(count * CHUNKLEN) + uniform_bytes = expand_message_xmd(BLAKE2b, message, DST, outlen) + if VERBOSE: + print("uniform_bytes:") + print(hex_bytes(uniform_bytes[: 64])) + print(hex_bytes(uniform_bytes[64 :])) + + return [Mod(OS2IP(uniform_bytes[CHUNKLEN*i : CHUNKLEN*(i+1)]), modulus) for i in range(count)] def OS2IP(bs): acc = 0 @@ -428,50 +441,91 @@ def OS2IP(bs): acc = (acc<<8) + as_byte(b) return acc +def expand_message_xmd(H, msg, DST, len_in_bytes): + (hasher, b_in_bytes, r_in_bytes) = H + assert len(DST) <= 255 + ell = (len_in_bytes + b_in_bytes - 1)//b_in_bytes + assert ell <= 255 + + DST_prime = DST + as_bytes([len(DST)]) + msg_prime = b"\x00"*r_in_bytes + bytes(msg) + as_bytes([len_in_bytes >> 8, len_in_bytes & 0xFF, 0]) + DST_prime + + if VERBOSE: print("b_0:") + b_0 = hash(hasher, msg_prime) + if VERBOSE: print("b_1:") + b = hash(hasher, b_0 + b"\x01" + DST_prime) + for i in range(2, ell+1): + if VERBOSE: print("b_%d:" % (i,)) + b += hash(hasher, as_bytes(as_byte(x) ^^ as_byte(y) for x, y in zip(b_0, b[-64 :])) + as_bytes([i]) + DST_prime) + + return b[: len_in_bytes] + def hash_to_pallas_jacobian(msg, DST): c = Cost() us = hash_to_field(p, msg, DST, 2) - #print("u = ", u) - Q0 = map_to_curve_simple_swu(F_p, E_isop, Z_isop, us[0], c) - Q1 = map_to_curve_simple_swu(F_p, E_isop, Z_isop, us[1], c) + if VERBOSE: print("us = [0x%064x, 0x064%x]" % (us[0], us[1])) + Q0 = map_to_curve_simple_swu(F_p, IsoEp, IsoEpZ, us[0], c) + Q1 = map_to_curve_simple_swu(F_p, IsoEp, IsoEpZ, us[1], c) - R = Q0.add(Q1, E_isop, c) - # Q0.add(Q0, E_isop, Cost()) # check that unified addition works + R = Q0.add(Q1, IsoEp, c) + # Q0.add(Q0, IsoEp, Cost()) # check that unified addition works # no cofactor clearing needed since Pallas is prime-order (Px, Py, Pz) = isop_map_jacobian(R, c) - P = E_p((Px / Pz^2, Py / Pz^3)) - return (P, c) + P = Ep((Px / Pz^2, Py / Pz^3)) + return ((Px, Py, Pz), c) def hash_to_vesta_jacobian(msg, DST): c = Cost() us = hash_to_field(q, msg, DST, 2) - #print("u = ", u) - Q0 = map_to_curve_simple_swu(F_q, E_isoq, Z_isoq, us[0], c) - Q1 = map_to_curve_simple_swu(F_q, E_isoq, Z_isoq, us[1], c) + if VERBOSE: print("us = [0x%064x, 0x064%x]" % (us[0], us[1])) + Q0 = map_to_curve_simple_swu(F_q, IsoEq, IsoEqZ, us[0], c) + Q1 = map_to_curve_simple_swu(F_q, IsoEq, IsoEqZ, us[1], c) - R = Q0.add(Q1, E_isoq, c) - # Q0.add(Q0, E_isoq, Cost()) # check that unified addition works + R = Q0.add(Q1, IsoEq, c) + # Q0.add(Q0, IsoEq, Cost()) # check that unified addition works # no cofactor clearing needed since Vesta is prime-order (Px, Py, Pz) = isoq_map_jacobian(R, c) - P = E_q((Px / Pz^2, Py / Pz^3)) - return (P, c) + P = Eq((Px / Pz^2, Py / Pz^3)) + return ((Px, Py, Pz), c) -print(map_to_curve_simple_swu(F_p, E_isop, Z_isop, Mod(1, p), Cost())) print("") -print(map_to_curve_simple_swu(F_q, E_isoq, Z_isoq, Mod(1, q), Cost())) +print(map_to_curve_simple_swu(F_p, IsoEp, IsoEpZ, Mod(0, p), Cost())) +print("") +print(map_to_curve_simple_swu(F_p, IsoEp, IsoEpZ, Mod(1, p), Cost())) print("") -print(hash_to_pallas_jacobian("hello", "blah")) +print(map_to_curve_simple_swu(F_q, IsoEq, IsoEqZ, Mod(0, q), Cost())) print("") -print(hash_to_vesta_jacobian("hello", "blah")) +print(map_to_curve_simple_swu(F_q, IsoEq, IsoEqZ, Mod(1, q), Cost())) +print("") + +(x, y, z) = isop_map_jacobian( + ChudnovskyPoint(IsoEp, + Mod(0x0a881e4d556945aa9c6cfc47bce1aba6593c053e5e2337adc37f111df5c4419e, p), + Mod(0x035e5c8a06d5cfb4a62eec46f662cb4e6979f7f2b0acf188f234e04434502b47, p), + Mod(0x3af37975b09331256ac4e343558dcbf3575baa717958ef1f11ab791d4fb6f6b4, p)), + Cost()) +print("Ep { x: 0x%064x, y: 0x%064x, z: 0x%064x }" % (x, y, z)) +print("") + +# This test vector is chosen so that the first map_to_curve_simple_swu takes the gx1 square +# "branch" and the second takes the gx1 non-square "branch" (opposite to the Vesta test vector). +((x, y, z), c) = hash_to_pallas_jacobian(b"world", "z.cash:test-pallas_XMD:BLAKE2b_SSWU_RO_") +print("Ep { x: 0x%064x, y: 0x%064x, z: 0x%064x }" % (x, y, z)) +print("") + +# This test vector is chosen so that the first map_to_curve_simple_swu takes the gx1 non-square +# "branch" and the second takes the gx1 square "branch" (opposite to the Pallas test vector). +((x, y, z), c) = hash_to_vesta_jacobian(b"hello", "z.cash:test-vesta_XMD:BLAKE2b_SSWU_RO_") +print("Eq { x: 0x%064x, y: 0x%064x, z: 0x%064x }" % (x, y, z)) print("") if OP_COUNT: iters = 100 for i in range(iters): - (R, cost) = hash_to_pallas_jacobian(pack(">I", i), "blah") + (R, cost) = hash_to_pallas_jacobian(pack(">I", i), "z.cash:test-pallas_XMD:BLAKE2b_SSWU_RO_") print(R, cost) diff --git a/squareroottab.sage b/squareroottab.sage index 90ef70d..e1b6342 100755 --- a/squareroottab.sage +++ b/squareroottab.sage @@ -12,8 +12,8 @@ DEBUG = True VERBOSE = False EXPENSIVE = False -SUBGROUP_TEST = True -OP_COUNT = True +SUBGROUP_TEST = False +OP_COUNT = False class Cost: @@ -137,8 +137,8 @@ class SqrtField: return (res, zero_if_square) """ - Return (sqrt(N/D), True, c), if N/D is square in the field. - (sqrt(g*N/D), False, c), otherwise. + Return (sqrt(N/D), True ), if N/D is square in the field. + (sqrt(g*N/D), False), otherwise. This avoids the full cost of computing N/D. """