Implement the optimization from [WB2019, section 4.2] that removes the remaining inversion.

Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
Daira Hopwood 2021-01-01 03:36:05 +00:00
parent 391e67f250
commit 50d3e83467
3 changed files with 156 additions and 96 deletions

View File

@ -124,3 +124,17 @@ ss = sch.sqrmul(sr, 3, s1)
st = sch.sqr(ss, 4)
assert st == s, format(st, 'b')
print(sch)
# Target exponent for the second addition chain: t = 2^32 - 1 (32 one-bits).
t = (1<<32) - 1
# Consistency check between s (already verified against its own chain above)
# and q, n, which are defined outside this excerpt.
assert(s == q >> (n+1))
# Separate Chain instance so this chain is recorded and printed on its own.
tch = Chain()
# Build t by repeatedly doubling the length of the run of one-bits:
# tK is intended to equal 2^K - 1, which the final assert confirms for K = 32.
# NOTE(review): presumably sqrmul(x, k, y) records k squarings followed by one
# multiplication by y (exponent (x << k) + y) -- confirm against the Chain class.
t1 = 1
t2 = tch.sqrmul(t1, 1, t1)
t4 = tch.sqrmul(t2, 2, t2)
t8 = tch.sqrmul(t4, 4, t4)
t16 = tch.sqrmul(t8, 8, t8)
t32 = tch.sqrmul(t16, 16, t16)
# On failure, show the value reached in binary to make the bit pattern visible.
assert t32 == t, format(t32, 'b')
print(tch)

View File

@ -24,42 +24,7 @@ else:
load('squareroottab.sage')
class Cost:
    """
    Running tally of field operations: squarings (S), multiplications (M)
    and inversions (I).

    Each arithmetic helper performs the operation on its arguments, bumps the
    matching counter(s) on this instance, and returns the result.
    Note that this is Sage code: `^` denotes exponentiation.
    """

    def __init__(self, sqrs=0, muls=0, invs=0):
        # All three counters default to zero, so Cost() is an empty tally.
        self.sqrs, self.muls, self.invs = sqrs, muls, invs

    def sqr(self, x):
        """Count one squaring and return x squared."""
        self.sqrs += 1
        squared = x^2
        return squared

    def mul(self, x, y):
        """Count one multiplication and return x * y."""
        self.muls += 1
        product = x * y
        return product

    def div(self, x, y):
        """Count a division as one inversion plus one multiplication; return x / y."""
        self.invs += 1
        self.muls += 1
        quotient = x / y
        return quotient

    def batch_inv0(self, xs):
        """
        Return the list of inverses of xs, mapping 0 to 0.

        Charged as a single inversion plus 3 multiplications per element
        beyond the first.
        """
        extra_elements = len(xs) - 1
        self.invs += 1
        self.muls += 3 * extra_elements
        # This should use Montgomery's trick (with constant-time substitutions to handle zeros).
        inverses = []
        for x in xs:
            inverses.append(0 if x == 0 else x^-1)
        return inverses

    def sqrt(self, x):
        """
        Return the first component of F_p.sarkar_sqrt(x), charging a fixed
        247 squarings and 35 multiplications.

        F_p is a module-level name defined outside this class.
        """
        self.sqrs += 247
        self.muls += 35
        root, _second, _third = F_p.sarkar_sqrt(x)
        return root

    def __add__(self, other):
        """Return a new Cost whose counters are the component-wise sums."""
        return Cost(
            sqrs=self.sqrs + other.sqrs,
            muls=self.muls + other.muls,
            invs=self.invs + other.invs,
        )

    def __repr__(self):
        """Format the tally as '<sqrs>S + <muls>M + <invs>I'."""
        counters = (self.sqrs, self.muls, self.invs)
        return "%dS + %dM + %dI" % counters
DEBUG = True
# E: a short Weierstrass elliptic curve
def find_z_sswu(E):
@ -128,8 +93,9 @@ def select_z_nz(s, ifz, ifnz):
# This should be constant-time in a real implementation.
return ifz if (s == 0) else ifnz
def map_to_curve_simple_swu(E, Z, h, us, c):
def map_to_curve_simple_swu(F, E, Z, us, c):
# would be precomputed
h = F.g
(0, 0, 0, A, B) = E.a_invariants()
mBdivA = -B / A
BdivZA = B / (Z * A)
@ -137,37 +103,15 @@ def map_to_curve_simple_swu(E, Z, h, us, c):
assert (Z/h).is_square()
theta = sqrt(Z/h)
# 1. tv1 = inv0(Z^2 * u^4 + Z * u^2)
# = inv0((Z^2 * u^2 + Z) * u^2)
u2s = [c.sqr(u) for u in us]
tas = [c.mul((Z2*u2 + Z), u2) for u2 in u2s]
tv1s = c.batch_inv0(tas)
Qs = []
for i in range(len(us)):
(u, u2, tv1) = (us[i], u2s[i], tv1s[i])
for u in us:
# 1. tv1 = inv0(Z^2 * u^4 + Z * u^2)
# 2. x1 = (-B / A) * (1 + tv1)
# 3. If tv1 == 0, set x1 = B / (Z * A)
x1 = select_z_nz(tv1, BdivZA, mBdivA * (1 + tv1))
# 4. gx1 = x1^3 + A * x1 + B
# = x1*(x1^2 + A) + B
x1_2 = c.sqr(x1)
gx1 = c.mul(x1, x1_2 + A) + B
# 5. x2 = Z * u^2 * x1
Zu2 = Z * u2 # Z is small
x2 = c.mul(Zu2, x1)
# 6. gx2 = x2^3 + A * x2 + B [optimized out; see below]
# 7. If is_square(gx1), set x = x1 and y = sqrt(gx1)
# 8. Else set x = x2 and y = sqrt(gx2)
y1 = c.sqrt(gx1)
y1_2 = c.sqr(y1)
zero_if_gx1_square = y1_2 - gx1
# This magic comes from a generalization of [WB2019, section 4.2].
#
# We use the "Avoiding inversions" optimization in [WB2019, section 4.2]
# (not to be confused with section 4.3):
#
# here [WB2019]
# ------- ---------------------------------
@ -177,6 +121,36 @@ def map_to_curve_simple_swu(E, Z, h, us, c):
# gx1 g(X_0(t))
# gx2 g(X_1(t))
#
# X0(u) = N/D = [B*(Z^2 * u^4 + Z * u^2 + 1)] / [-A*(Z^2 * u^4 + Z * u^2)]
# g(X0(u)) = U/V = [N^3 + A * N * D^2 + B * D^3] / D^3
Zu2 = Z * c.sqr(u) # Z is small
ta = c.sqr(Zu2) + Zu2
N = c.mul(B, ta + 1)
D = c.mul(-A, ta)
N2 = c.sqr(N)
D2 = c.sqr(D)
D3 = c.mul(D2, D)
U = select_z_nz(ta, BdivZA, c.mul(N2 + A*D2, N) + B*D3)
V = select_z_nz(ta, 1, D3)
if DEBUG:
x1 = N/D
gx1 = U/V
tv1 = (0 if ta == 0 else 1/ta)
assert x1 == (BdivZA if tv1 == 0 else mBdivA * (1 + tv1))
assert gx1 == x1^3 + A * x1 + B
# 5. x2 = Z * u^2 * x1
x2 = c.mul(Zu2, x1)
# 6. gx2 = x2^3 + A * x2 + B [optimized out; see below]
# 7. If is_square(gx1), set x = x1 and y = sqrt(gx1)
# 8. Else set x = x2 and y = sqrt(gx2)
(y1, zero_if_gx1_square) = F.sarkar_divsqrt(U, V, c)
# This magic also comes from a generalization of [WB2019, section 4.2].
#
# The Sarkar square root algorithm with input s gives us a square root of
# h * s for free when s is not square, provided we choose h to be a generator
# of the order 2^n multiplicative subgroup (where n = 32 for Pallas and Vesta).
@ -191,8 +165,8 @@ def map_to_curve_simple_swu(E, Z, h, us, c):
# is a square root of gx2. Note that we don't actually need to compute gx2.
y2 = c.mul(theta, c.mul(Zu2, c.mul(u, y1)))
if zero_if_gx1_square != 0:
assert y1_2 == h * gx1, (y1_2, Z, gx1)
if DEBUG and zero_if_gx1_square != 0:
assert y1^2 == h * gx1, (y1_2, Z, gx1)
assert y2^2 == x2^3 + A * x2 + B, (y2, x2, A, B)
x = select_z_nz(zero_if_gx1_square, x1, x2)
@ -326,7 +300,7 @@ def hash_to_curve_affine(msg, DST, uniform=True):
c = Cost()
us = hash_to_field(msg, DST, 2 if uniform else 1)
#print("u = ", u)
Qs = map_to_curve_simple_swu(E_isop, Z_isop, h_p, us, c)
Qs = map_to_curve_simple_swu(F_p, E_isop, Z_isop, us, c)
if uniform:
# Complete addition using affine coordinates: I + 2M + 2S
@ -349,7 +323,7 @@ def hash_to_curve_jacobian(msg, DST):
c = Cost()
us = hash_to_field(msg, DST, 2)
#print("u = ", u)
Qs = map_to_curve_simple_swu(E_isop, Z_isop, h_p, us, c)
Qs = map_to_curve_simple_swu(F_p, E_isop, Z_isop, us, c)
R = Qs[0] + Qs[1]
#print("R = ", R)

View File

@ -4,7 +4,6 @@
# from <https://eprint.iacr.org/2020/1407>, for the Pasta fields.
import sys
from copy import copy
if sys.version_info[0] == 2:
range = xrange
@ -16,16 +15,41 @@ EXPENSIVE = False
SUBGROUP_TEST = True
OP_COUNT = True
class Cost:
def __init__(self, sqrs, muls):
def __init__(self, sqrs=0, muls=0, invs=0):
self.sqrs = sqrs
self.muls = muls
self.invs = invs
def sqr(self, x):
self.sqrs += 1
return x^2
def mul(self, x, y):
self.muls += 1
return x * y
def div(self, x, y):
self.invs += 1
self.muls += 1
return x / y
def batch_inv0(self, xs):
self.invs += 1
self.muls += 3*(len(xs)-1)
# This should use Montgomery's trick (with constant-time substitutions to handle zeros).
return [0 if x == 0 else x^-1 for x in xs]
def __repr__(self):
return repr((self.sqrs, self.muls))
return "%dS + %dM + %dI" % (self.sqrs, self.muls, self.invs)
def __add__(self, other):
return Cost(self.sqrs + other.sqrs, self.muls + other.muls)
return Cost(self.sqrs + other.sqrs, self.muls + other.muls, self.invs + other.invs)
def include(self, other):
self.sqrs += other.sqrs
self.muls += other.muls
def divide(self, divisor):
return Cost((self.sqrs / divisor).numerical_approx(), (self.muls / divisor).numerical_approx())
@ -97,17 +121,63 @@ class SqrtField:
print("best is hash_xor=0x%X, hash_mod=%d" % (hash_xor, hash_mod))
return (hash_xor, hash_mod)
def sarkar_sqrt(self, u):
"""
Return (sqrt(u), True ), if u is square in the field.
(sqrt(g*u), False), otherwise.
"""
def sarkar_sqrt(self, u, c):
if VERBOSE: print("u = %r" % (u,))
# This would actually be done using the addition chain.
v = u^((self.m-1)/2)
cost = copy(self.base_cost)
c.include(self.base_cost)
uv = u * v
uv = c.mul(u, v)
(res, zero_if_square) = self.sarkar_sqrt_common(u, 1, uv, v, c)
return (res, zero_if_square)
def sarkar_divsqrt(self, N, D, c):
    """
    Compute a square root of N/D without first performing the division.

    Returns a pair (res, zero_if_square):
      (sqrt(N/D),   zero_if_square == 0), if N/D is square in the field.
      (sqrt(g*N/D), zero_if_square != 0), otherwise.

    The field operations performed are tallied into the cost accumulator c.
    This avoids the full cost of computing N/D.

    NOTE(review): assumes D is nonzero (the DEBUG path divides by D) -- confirm
    that every caller guarantees this.
    """
    if DEBUG:
        # Reference value, used only for the consistency checks below.
        u = N/D
        if VERBOSE: print("N/D = %r/%r\n = %r" % (N, D, u))

    # This would actually be done using addition chains for 2^n - 1 and (m-1)/2
    # (see addchain_sqrt.py).
    s = D^(2^self.n - 1)
    # The exponentiation above is not routed through c, so its cost is charged
    # by hand. NOTE(review): 31 squarings + 5 multiplications matches a chain
    # for 2^32 - 1; presumably only right when self.n == 32 -- confirm.
    c.sqrs += 31
    c.muls += 5

    # t = s^2 * D = D^(2^(n+1) - 1)
    t = c.mul(c.sqr(s), D)
    if DEBUG: assert t == D^(2^(self.n+1) - 1)

    # w is chosen so that w*D = u^((m-1)/2) and N*w = u * u^((m-1)/2),
    # where u = N/D (both identities are asserted under DEBUG below).
    w = c.mul(c.mul(N, t)^((self.m-1)/2), s)
    # Charge the (m-1)/2 exponentiation, which is likewise not routed through c.
    c.include(self.base_cost)
    v = c.mul(w, D)
    uv = c.mul(N, w)
    if DEBUG:
        assert v == u^((self.m-1)/2)
        assert uv == u * v

    (res, zero_if_square) = self.sarkar_sqrt_common(N, D, uv, v, c)

    if DEBUG:
        # Cross-check against the plain square root of u; a throwaway Cost()
        # keeps the reference computation out of the caller's tally.
        (res_ref, zero_if_square_ref) = self.sarkar_sqrt(u, Cost())
        assert res == res_ref
        assert (zero_if_square == 0) == (zero_if_square_ref == 0)

    return (res, zero_if_square)
def sarkar_sqrt_common(self, N, D, uv, v, c):
x3 = uv * v
cost.muls += 2
if DEBUG: assert x3 == u^self.m
c.muls += 2
if DEBUG:
u = N/D
assert x3 == u^self.m
if EXPENSIVE:
x3_order = x3.multiplicative_order()
if VERBOSE: print("x3_order = %r" % (x3_order,))
@ -122,44 +192,44 @@ class SqrtField:
assert x1 == x3^(1<<(self.n-1-15))
assert x2 == x3^(1<<(self.n-1-23))
cost.sqrs += 8+8+8
c.sqrs += 8+8+8
# i = 0, 1
t_ = self.invtab[self.hash(x0)] # = t >> 16
if DEBUG: assert 1 == x0 * self.g^(t_ << 24), (x0, t_)
assert t_ < 0x100, t_
alpha = x1 * self.gtab[2][t_]
cost.muls += 1
c.muls += 1
# i = 2
t_ += self.invtab[self.hash(alpha)] << 8 # = t >> 8
if DEBUG: assert 1 == x1 * self.g^(t_ << 16), (x1, t_)
assert t_ < 0x10000, t_
alpha = x2 * self.gtab[1][t_ % 256] * self.gtab[2][t_ >> 8]
cost.muls += 2
c.muls += 2
# i = 3
t_ += self.invtab[self.hash(alpha)] << 16 # = t
if DEBUG: assert 1 == x2 * self.g^(t_ << 8), (x2, t_)
assert t_ < 0x1000000, t_
alpha = x3 * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][t_ >> 16]
cost.muls += 3
c.muls += 3
t_ += self.invtab[self.hash(alpha)] << 24 # = t << 1
if DEBUG: assert 1 == x3 * self.g^t_, (x3, t_)
t_ = (t_ + 1) >> 1
assert t_ <= 0x80000000, t_
res = uv * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][(t_ >> 16) % 256] * self.gtab[3][t_ >> 24]
cost.muls += 4
c.muls += 4
issq = (res^2 == u)
cost.sqrs += 1
zero_if_square = c.mul(c.sqr(res), D) - N
if DEBUG:
assert issq == u.is_square()
if EXPENSIVE: assert issq == (x3_order != 2^self.n), (issq, x3_order)
if not issq:
assert (zero_if_square == 0) == u.is_square()
if EXPENSIVE: assert (zero_if_square == 0) == (x3_order != 2^self.n), (zero_if_square, x3_order)
if zero_if_square != 0:
assert(res^2 == u * self.g)
return (res, issq, cost)
return (res, zero_if_square)
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
@ -172,26 +242,28 @@ F_q = SqrtField(q, 5, Cost(223, 24), hash_xor=0x116A9E, hash_mod=1206)
print("p = %r" % (p,))
x = Mod(0x1234567890123456789012345678901234567890123456789012345678901234, p)
print(F_p.sarkar_sqrt(x))
print(F_p.sarkar_sqrt(x, Cost()))
Dx = Mod(0x123456, p)
print(F_p.sarkar_divsqrt(x*Dx, Dx, Cost()))
x = Mod(0x2345678901234567890123456789012345678901234567890123456789012345, p)
print(F_p.sarkar_sqrt(x))
print(F_p.sarkar_sqrt(x, Cost()))
# nonsquare
x = Mod(0x3456789012345678901234567890123456789012345678901234567890123456, p)
print(F_p.sarkar_sqrt(x))
print(F_p.sarkar_sqrt(x, Cost()))
if SUBGROUP_TEST:
for i in range(33):
x = F_p.g^(2^i)
print(F_p.sarkar_sqrt(x))
print(F_p.sarkar_sqrt(x, Cost()))
if OP_COUNT:
total_cost = Cost(0, 0)
cost = Cost()
iters = 50
for i in range(iters):
x = GF(p).random_element()
(_, _, cost) = F_p.sarkar_sqrt(x)
total_cost += cost
y = GF(p).random_element()
(_, _) = F_p.sarkar_divsqrt(x, y, cost)
print total_cost.divide(iters)
print cost.divide(iters)