diff --git a/hashtocurve.sage b/hashtocurve.sage index 1113853..8c2ed0e 100755 --- a/hashtocurve.sage +++ b/hashtocurve.sage @@ -42,9 +42,11 @@ class Cost: self.muls += 1 return x / y - def inv0(self, x): + def batch_inv0(self, xs): self.invs += 1 - return 0 if x == 0 else x^-1 + self.muls += 3*(len(xs)-1) + # This should use Montgomery's trick (with constant-time substitutions to handle zeros). + return [0 if x == 0 else x^-1 for x in xs] def sqrt(self, x): self.sqrs += 247 @@ -98,57 +100,57 @@ def select_z_nz(s, ifz, ifnz): # This should be constant-time in a real implementation. return ifz if (s == 0) else ifnz -def map_to_curve_simple_swu(E, Z, u, c): +def map_to_curve_simple_swu(E, Z, us, c): # would be precomputed (0, 0, 0, A, B) = E.a_invariants() mBdivA = -B / A BdivZA = B / (Z * A) - #print("A = ", A) - #print("B = ", B) - #print("Z = ", Z) - #print("-B/A = ", mBdivA) - #print("B/ZA = ", BdivZA) + Z2 = Z^2 # 1. tv1 = inv0(Z^2 * u^4 + Z * u^2) - Z2 = c.sqr(Z) - u2 = c.sqr(u) - u4 = c.sqr(u2) - ta = c.mul(Z2, u4) + c.mul(Z, u2) - tv1 = c.inv0(ta) + # = inv0((Z^2 * u^2 + Z) * u^2) + u2s = [c.sqr(u) for u in us] + tas = [c.mul((Z2*u2 + Z), u2) for u2 in u2s] + tv1s = c.batch_inv0(tas) - # 2. x1 = (-B / A) * (1 + tv1) - # 3. If tv1 == 0, set x1 = B / (Z * A) - x1 = select_z_nz(tv1, BdivZA, mBdivA * (1 + tv1)) + Qs = [] + for i in range(len(us)): + (u, u2, tv1) = (us[i], u2s[i], tv1s[i]) - # 4. gx1 = x1^3 + A * x1 + B - # = x1*(x1^2 + A) + B - x1_2 = c.sqr(x1) - gx1 = c.mul(x1, x1_2 + A) + B + # 2. x1 = (-B / A) * (1 + tv1) + # 3. If tv1 == 0, set x1 = B / (Z * A) + x1 = select_z_nz(tv1, BdivZA, mBdivA * (1 + tv1)) - # 5. x2 = Z * u^2 * x1 - tb = c.mul(Z, u2) - x2 = c.mul(tb, x1) + # 4. gx1 = x1^3 + A * x1 + B + # = x1*(x1^2 + A) + B + x1_2 = c.sqr(x1) + gx1 = c.mul(x1, x1_2 + A) + B - # 6. gx2 = x2^3 + A * x2 + B - # = x2*(x2^2 + A) + B - x2_2 = c.sqr(x2) - gx2 = c.mul(x2, x2_2 + A) + B + # 5. x2 = Z * u^2 * x1 + tb = c.mul(Z, u2) + x2 = c.mul(tb, x1) - # 7. If is_square(gx1), set x = x1 and y = sqrt(gx1) - # 8. Else set x = x2 and y = sqrt(gx2) - y1 = c.sqrt(gx1) - y1_2 = c.sqr(y1) - if CONSTANT_TIME or y1_2 != gx1: - y2 = c.sqrt(gx2) - x = select_z_nz(y1_2 - gx1, x1, x2) - y = select_z_nz(y1_2 - gx1, y1, y2) - else: - (x, y) = (x1, y1) + # 6. gx2 = x2^3 + A * x2 + B + # = x2*(x2^2 + A) + B + x2_2 = c.sqr(x2) + gx2 = c.mul(x2, x2_2 + A) + B - # 9. If sgn0(u) != sgn0(y), set y = -y - y = select_z_nz((int(u) % 2) - (int(y) % 2), y, -y) + # 7. If is_square(gx1), set x = x1 and y = sqrt(gx1) + # 8. Else set x = x2 and y = sqrt(gx2) + y1 = c.sqrt(gx1) + y1_2 = c.sqr(y1) + if CONSTANT_TIME or y1_2 != gx1: + y2 = c.sqrt(gx2) + x = select_z_nz(y1_2 - gx1, x1, x2) + y = select_z_nz(y1_2 - gx1, y1, y2) + else: + (x, y) = (x1, y1) - return E((x, y)) + # 9. If sgn0(u) != sgn0(y), set y = -y + y = select_z_nz((int(u) % 2) - (int(y) % 2), y, -y) + Qs.append(E((x, y))) + + return Qs # iso_Ep = Isogeny of degree 3 from Elliptic Curve defined by y^2 = x^3 + 10949663248450308183708987909873589833737836120165333298109615750520499732811*x + 1265 over Fp @@ -198,25 +200,21 @@ def OS2IP(bs): def hash_to_curve(msg, DST, uniform=True): c = Cost() - u = hash_to_field(msg, DST, 2) + us = hash_to_field(msg, DST, 2 if uniform else 1) #print("u = ", u) - R = map_to_curve_simple_swu(E_isop, Z_isop, u[0], c) + Qs = map_to_curve_simple_swu(E_isop, Z_isop, us, c) if uniform: - Q1 = map_to_curve_simple_swu(E_isop, Z_isop, u[1], c) - - # We could batch the two inv0 inversions (not done above for simplicity). - c.invs -= 1 - c.muls += 3 - # Complete addition using affine coordinates: I + 2M + 2S # (S for x1^2; compute numerator and denominator of the division for the correct case; # I + M to divide; S + M to compute x and y of the result.) - R = R + Q1 + R = Qs[0] + Qs[1] #print("R = ", R) c.invs += 1 c.sqrs += 2 c.muls += 2 + else: + R = Qs[0] # no cofactor clearing needed since Pallas and Vesta are prime-order (x, y) = R.xy() @@ -228,5 +226,5 @@ def hash_to_curve(msg, DST, uniform=True): iters = 100 for i in range(iters): - (res, cost) = hash_to_curve(pack(">I", i), "blah", uniform=False) + (res, cost) = hash_to_curve(pack(">I", i), "blah", uniform=True) print(res, cost)