hashtocurve.sage: more realistic use of Montgomery's trick.

Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
Daira Hopwood 2020-12-29 17:52:35 +00:00
parent 96fd2c794e
commit 7df33f4ce4
1 changed files with 48 additions and 50 deletions

View File

@ -42,9 +42,11 @@ class Cost:
self.muls += 1 self.muls += 1
return x / y return x / y
def inv0(self, x): def batch_inv0(self, xs):
self.invs += 1 self.invs += 1
return 0 if x == 0 else x^-1 self.muls += 3*(len(xs)-1)
# This should use Montgomery's trick (with constant-time substitutions to handle zeros).
return [0 if x == 0 else x^-1 for x in xs]
def sqrt(self, x): def sqrt(self, x):
self.sqrs += 247 self.sqrs += 247
@ -98,57 +100,57 @@ def select_z_nz(s, ifz, ifnz):
# This should be constant-time in a real implementation. # This should be constant-time in a real implementation.
return ifz if (s == 0) else ifnz return ifz if (s == 0) else ifnz
def map_to_curve_simple_swu(E, Z, u, c): def map_to_curve_simple_swu(E, Z, us, c):
# would be precomputed # would be precomputed
(0, 0, 0, A, B) = E.a_invariants() (0, 0, 0, A, B) = E.a_invariants()
mBdivA = -B / A mBdivA = -B / A
BdivZA = B / (Z * A) BdivZA = B / (Z * A)
#print("A = ", A) Z2 = Z^2
#print("B = ", B)
#print("Z = ", Z)
#print("-B/A = ", mBdivA)
#print("B/ZA = ", BdivZA)
# 1. tv1 = inv0(Z^2 * u^4 + Z * u^2) # 1. tv1 = inv0(Z^2 * u^4 + Z * u^2)
Z2 = c.sqr(Z) # = inv0((Z^2 * u^2 + Z) * u^2)
u2 = c.sqr(u) u2s = [c.sqr(u) for u in us]
u4 = c.sqr(u2) tas = [c.mul((Z2*u2 + Z), u2) for u2 in u2s]
ta = c.mul(Z2, u4) + c.mul(Z, u2) tv1s = c.batch_inv0(tas)
tv1 = c.inv0(ta)
# 2. x1 = (-B / A) * (1 + tv1) Qs = []
# 3. If tv1 == 0, set x1 = B / (Z * A) for i in range(len(us)):
x1 = select_z_nz(tv1, BdivZA, mBdivA * (1 + tv1)) (u, u2, tv1) = (us[i], u2s[i], tv1s[i])
# 4. gx1 = x1^3 + A * x1 + B # 2. x1 = (-B / A) * (1 + tv1)
# = x1*(x1^2 + A) + B # 3. If tv1 == 0, set x1 = B / (Z * A)
x1_2 = c.sqr(x1) x1 = select_z_nz(tv1, BdivZA, mBdivA * (1 + tv1))
gx1 = c.mul(x1, x1_2 + A) + B
# 5. x2 = Z * u^2 * x1 # 4. gx1 = x1^3 + A * x1 + B
tb = c.mul(Z, u2) # = x1*(x1^2 + A) + B
x2 = c.mul(tb, x1) x1_2 = c.sqr(x1)
gx1 = c.mul(x1, x1_2 + A) + B
# 6. gx2 = x2^3 + A * x2 + B # 5. x2 = Z * u^2 * x1
# = x2*(x2^2 + A) + B tb = c.mul(Z, u2)
x2_2 = c.sqr(x2) x2 = c.mul(tb, x1)
gx2 = c.mul(x2, x2_2 + A) + B
# 7. If is_square(gx1), set x = x1 and y = sqrt(gx1) # 6. gx2 = x2^3 + A * x2 + B
# 8. Else set x = x2 and y = sqrt(gx2) # = x2*(x2^2 + A) + B
y1 = c.sqrt(gx1) x2_2 = c.sqr(x2)
y1_2 = c.sqr(y1) gx2 = c.mul(x2, x2_2 + A) + B
if CONSTANT_TIME or y1_2 != gx1:
y2 = c.sqrt(gx2)
x = select_z_nz(y1_2 - gx1, x1, x2)
y = select_z_nz(y1_2 - gx1, y1, y2)
else:
(x, y) = (x1, y1)
# 9. If sgn0(u) != sgn0(y), set y = -y # 7. If is_square(gx1), set x = x1 and y = sqrt(gx1)
y = select_z_nz((int(u) % 2) - (int(y) % 2), y, -y) # 8. Else set x = x2 and y = sqrt(gx2)
y1 = c.sqrt(gx1)
y1_2 = c.sqr(y1)
if CONSTANT_TIME or y1_2 != gx1:
y2 = c.sqrt(gx2)
x = select_z_nz(y1_2 - gx1, x1, x2)
y = select_z_nz(y1_2 - gx1, y1, y2)
else:
(x, y) = (x1, y1)
return E((x, y)) # 9. If sgn0(u) != sgn0(y), set y = -y
y = select_z_nz((int(u) % 2) - (int(y) % 2), y, -y)
Qs.append(E((x, y)))
return Qs
# iso_Ep = Isogeny of degree 3 from Elliptic Curve defined by y^2 = x^3 + 10949663248450308183708987909873589833737836120165333298109615750520499732811*x + 1265 over Fp # iso_Ep = Isogeny of degree 3 from Elliptic Curve defined by y^2 = x^3 + 10949663248450308183708987909873589833737836120165333298109615750520499732811*x + 1265 over Fp
@ -198,25 +200,21 @@ def OS2IP(bs):
def hash_to_curve(msg, DST, uniform=True): def hash_to_curve(msg, DST, uniform=True):
c = Cost() c = Cost()
u = hash_to_field(msg, DST, 2) us = hash_to_field(msg, DST, 2 if uniform else 1)
#print("u = ", u) #print("u = ", u)
R = map_to_curve_simple_swu(E_isop, Z_isop, u[0], c) Qs = map_to_curve_simple_swu(E_isop, Z_isop, us, c)
if uniform: if uniform:
Q1 = map_to_curve_simple_swu(E_isop, Z_isop, u[1], c)
# We could batch the two inv0 inversions (not done above for simplicity).
c.invs -= 1
c.muls += 3
# Complete addition using affine coordinates: I + 2M + 2S # Complete addition using affine coordinates: I + 2M + 2S
# (S for x1^2; compute numerator and denominator of the division for the correct case; # (S for x1^2; compute numerator and denominator of the division for the correct case;
# I + M to divide; S + M to compute x and y of the result.) # I + M to divide; S + M to compute x and y of the result.)
R = R + Q1 R = Qs[0] + Qs[1]
#print("R = ", R) #print("R = ", R)
c.invs += 1 c.invs += 1
c.sqrs += 2 c.sqrs += 2
c.muls += 2 c.muls += 2
else:
R = Qs[0]
# no cofactor clearing needed since Pallas and Vesta are prime-order # no cofactor clearing needed since Pallas and Vesta are prime-order
(x, y) = R.xy() (x, y) = R.xy()
@ -228,5 +226,5 @@ def hash_to_curve(msg, DST, uniform=True):
iters = 100 iters = 100
for i in range(iters): for i in range(iters):
(res, cost) = hash_to_curve(pack(">I", i), "blah", uniform=False) (res, cost) = hash_to_curve(pack(">I", i), "blah", uniform=True)
print(res, cost) print(res, cost)