mirror of https://github.com/zcash/pasta.git
Implement the optimization from [WB2019, section 4.2] that removes the remaining inversion.
Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
parent
391e67f250
commit
50d3e83467
|
@ -124,3 +124,17 @@ ss = sch.sqrmul(sr, 3, s1)
|
|||
st = sch.sqr(ss, 4)
|
||||
assert st == s, format(st, 'b')
|
||||
print(sch)
|
||||
|
||||
|
||||
t = (1<<32) - 1
|
||||
|
||||
assert(s == q >> (n+1))
|
||||
tch = Chain()
|
||||
t1 = 1
|
||||
t2 = tch.sqrmul(t1, 1, t1)
|
||||
t4 = tch.sqrmul(t2, 2, t2)
|
||||
t8 = tch.sqrmul(t4, 4, t4)
|
||||
t16 = tch.sqrmul(t8, 8, t8)
|
||||
t32 = tch.sqrmul(t16, 16, t16)
|
||||
assert t32 == t, format(t32, 'b')
|
||||
print(tch)
|
||||
|
|
110
hashtocurve.sage
110
hashtocurve.sage
|
@ -24,42 +24,7 @@ else:
|
|||
|
||||
load('squareroottab.sage')
|
||||
|
||||
class Cost:
|
||||
def __init__(self, sqrs=0, muls=0, invs=0):
|
||||
self.sqrs = sqrs
|
||||
self.muls = muls
|
||||
self.invs = invs
|
||||
|
||||
def sqr(self, x):
|
||||
self.sqrs += 1
|
||||
return x^2
|
||||
|
||||
def mul(self, x, y):
|
||||
self.muls += 1
|
||||
return x * y
|
||||
|
||||
def div(self, x, y):
|
||||
self.invs += 1
|
||||
self.muls += 1
|
||||
return x / y
|
||||
|
||||
def batch_inv0(self, xs):
|
||||
self.invs += 1
|
||||
self.muls += 3*(len(xs)-1)
|
||||
# This should use Montgomery's trick (with constant-time substitutions to handle zeros).
|
||||
return [0 if x == 0 else x^-1 for x in xs]
|
||||
|
||||
def sqrt(self, x):
|
||||
self.sqrs += 247
|
||||
self.muls += 35
|
||||
(res, _, _) = F_p.sarkar_sqrt(x)
|
||||
return res
|
||||
|
||||
def __add__(self, other):
|
||||
return Cost(self.sqrs + other.sqrs, self.muls + other.muls, self.invs + other.invs)
|
||||
|
||||
def __repr__(self):
|
||||
return "%dS + %dM + %dI" % (self.sqrs, self.muls, self.invs)
|
||||
DEBUG = True
|
||||
|
||||
# E: a short Weierstrass elliptic curve
|
||||
def find_z_sswu(E):
|
||||
|
@ -128,8 +93,9 @@ def select_z_nz(s, ifz, ifnz):
|
|||
# This should be constant-time in a real implementation.
|
||||
return ifz if (s == 0) else ifnz
|
||||
|
||||
def map_to_curve_simple_swu(E, Z, h, us, c):
|
||||
def map_to_curve_simple_swu(F, E, Z, us, c):
|
||||
# would be precomputed
|
||||
h = F.g
|
||||
(0, 0, 0, A, B) = E.a_invariants()
|
||||
mBdivA = -B / A
|
||||
BdivZA = B / (Z * A)
|
||||
|
@ -137,37 +103,15 @@ def map_to_curve_simple_swu(E, Z, h, us, c):
|
|||
assert (Z/h).is_square()
|
||||
theta = sqrt(Z/h)
|
||||
|
||||
# 1. tv1 = inv0(Z^2 * u^4 + Z * u^2)
|
||||
# = inv0((Z^2 * u^2 + Z) * u^2)
|
||||
u2s = [c.sqr(u) for u in us]
|
||||
tas = [c.mul((Z2*u2 + Z), u2) for u2 in u2s]
|
||||
tv1s = c.batch_inv0(tas)
|
||||
|
||||
Qs = []
|
||||
for i in range(len(us)):
|
||||
(u, u2, tv1) = (us[i], u2s[i], tv1s[i])
|
||||
|
||||
for u in us:
|
||||
# 1. tv1 = inv0(Z^2 * u^4 + Z * u^2)
|
||||
# 2. x1 = (-B / A) * (1 + tv1)
|
||||
# 3. If tv1 == 0, set x1 = B / (Z * A)
|
||||
x1 = select_z_nz(tv1, BdivZA, mBdivA * (1 + tv1))
|
||||
|
||||
# 4. gx1 = x1^3 + A * x1 + B
|
||||
# = x1*(x1^2 + A) + B
|
||||
x1_2 = c.sqr(x1)
|
||||
gx1 = c.mul(x1, x1_2 + A) + B
|
||||
|
||||
# 5. x2 = Z * u^2 * x1
|
||||
Zu2 = Z * u2 # Z is small
|
||||
x2 = c.mul(Zu2, x1)
|
||||
|
||||
# 6. gx2 = x2^3 + A * x2 + B [optimized out; see below]
|
||||
# 7. If is_square(gx1), set x = x1 and y = sqrt(gx1)
|
||||
# 8. Else set x = x2 and y = sqrt(gx2)
|
||||
y1 = c.sqrt(gx1)
|
||||
y1_2 = c.sqr(y1)
|
||||
zero_if_gx1_square = y1_2 - gx1
|
||||
|
||||
# This magic comes from a generalization of [WB2019, section 4.2].
|
||||
#
|
||||
# We use the "Avoiding inversions" optimization in [WB2019, section 4.2]
|
||||
# (not to be confused with section 4.3):
|
||||
#
|
||||
# here [WB2019]
|
||||
# ------- ---------------------------------
|
||||
|
@ -177,6 +121,36 @@ def map_to_curve_simple_swu(E, Z, h, us, c):
|
|||
# gx1 g(X_0(t))
|
||||
# gx2 g(X_1(t))
|
||||
#
|
||||
# X0(u) = N/D = [B*(Z^2 * u^4 + Z * u^2 + 1)] / [-A*(Z^2 * u^4 + Z * u^2)]
|
||||
# g(X0(u)) = U/V = [N^3 + A * N * D^2 + B * D^3] / D^3
|
||||
|
||||
Zu2 = Z * c.sqr(u) # Z is small
|
||||
ta = c.sqr(Zu2) + Zu2
|
||||
N = c.mul(B, ta + 1)
|
||||
D = c.mul(-A, ta)
|
||||
N2 = c.sqr(N)
|
||||
D2 = c.sqr(D)
|
||||
D3 = c.mul(D2, D)
|
||||
U = select_z_nz(ta, BdivZA, c.mul(N2 + A*D2, N) + B*D3)
|
||||
V = select_z_nz(ta, 1, D3)
|
||||
|
||||
if DEBUG:
|
||||
x1 = N/D
|
||||
gx1 = U/V
|
||||
tv1 = (0 if ta == 0 else 1/ta)
|
||||
assert x1 == (BdivZA if tv1 == 0 else mBdivA * (1 + tv1))
|
||||
assert gx1 == x1^3 + A * x1 + B
|
||||
|
||||
# 5. x2 = Z * u^2 * x1
|
||||
x2 = c.mul(Zu2, x1)
|
||||
|
||||
# 6. gx2 = x2^3 + A * x2 + B [optimized out; see below]
|
||||
# 7. If is_square(gx1), set x = x1 and y = sqrt(gx1)
|
||||
# 8. Else set x = x2 and y = sqrt(gx2)
|
||||
(y1, zero_if_gx1_square) = F.sarkar_divsqrt(U, V, c)
|
||||
|
||||
# This magic also comes from a generalization of [WB2019, section 4.2].
|
||||
#
|
||||
# The Sarkar square root algorithm with input s gives us a square root of
|
||||
# h * s for free when s is not square, provided we choose h to be a generator
|
||||
# of the order 2^n multiplicative subgroup (where n = 32 for Pallas and Vesta).
|
||||
|
@ -191,8 +165,8 @@ def map_to_curve_simple_swu(E, Z, h, us, c):
|
|||
# is a square root of gx2. Note that we don't actually need to compute gx2.
|
||||
|
||||
y2 = c.mul(theta, c.mul(Zu2, c.mul(u, y1)))
|
||||
if zero_if_gx1_square != 0:
|
||||
assert y1_2 == h * gx1, (y1_2, Z, gx1)
|
||||
if DEBUG and zero_if_gx1_square != 0:
|
||||
assert y1^2 == h * gx1, (y1_2, Z, gx1)
|
||||
assert y2^2 == x2^3 + A * x2 + B, (y2, x2, A, B)
|
||||
|
||||
x = select_z_nz(zero_if_gx1_square, x1, x2)
|
||||
|
@ -326,7 +300,7 @@ def hash_to_curve_affine(msg, DST, uniform=True):
|
|||
c = Cost()
|
||||
us = hash_to_field(msg, DST, 2 if uniform else 1)
|
||||
#print("u = ", u)
|
||||
Qs = map_to_curve_simple_swu(E_isop, Z_isop, h_p, us, c)
|
||||
Qs = map_to_curve_simple_swu(F_p, E_isop, Z_isop, us, c)
|
||||
|
||||
if uniform:
|
||||
# Complete addition using affine coordinates: I + 2M + 2S
|
||||
|
@ -349,7 +323,7 @@ def hash_to_curve_jacobian(msg, DST):
|
|||
c = Cost()
|
||||
us = hash_to_field(msg, DST, 2)
|
||||
#print("u = ", u)
|
||||
Qs = map_to_curve_simple_swu(E_isop, Z_isop, h_p, us, c)
|
||||
Qs = map_to_curve_simple_swu(F_p, E_isop, Z_isop, us, c)
|
||||
|
||||
R = Qs[0] + Qs[1]
|
||||
#print("R = ", R)
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# from <https://eprint.iacr.org/2020/1407>, for the Pasta fields.
|
||||
|
||||
import sys
|
||||
from copy import copy
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
range = xrange
|
||||
|
@ -16,16 +15,41 @@ EXPENSIVE = False
|
|||
SUBGROUP_TEST = True
|
||||
OP_COUNT = True
|
||||
|
||||
|
||||
class Cost:
|
||||
def __init__(self, sqrs, muls):
|
||||
def __init__(self, sqrs=0, muls=0, invs=0):
|
||||
self.sqrs = sqrs
|
||||
self.muls = muls
|
||||
self.invs = invs
|
||||
|
||||
def sqr(self, x):
|
||||
self.sqrs += 1
|
||||
return x^2
|
||||
|
||||
def mul(self, x, y):
|
||||
self.muls += 1
|
||||
return x * y
|
||||
|
||||
def div(self, x, y):
|
||||
self.invs += 1
|
||||
self.muls += 1
|
||||
return x / y
|
||||
|
||||
def batch_inv0(self, xs):
|
||||
self.invs += 1
|
||||
self.muls += 3*(len(xs)-1)
|
||||
# This should use Montgomery's trick (with constant-time substitutions to handle zeros).
|
||||
return [0 if x == 0 else x^-1 for x in xs]
|
||||
|
||||
def __repr__(self):
|
||||
return repr((self.sqrs, self.muls))
|
||||
return "%dS + %dM + %dI" % (self.sqrs, self.muls, self.invs)
|
||||
|
||||
def __add__(self, other):
|
||||
return Cost(self.sqrs + other.sqrs, self.muls + other.muls)
|
||||
return Cost(self.sqrs + other.sqrs, self.muls + other.muls, self.invs + other.invs)
|
||||
|
||||
def include(self, other):
|
||||
self.sqrs += other.sqrs
|
||||
self.muls += other.muls
|
||||
|
||||
def divide(self, divisor):
|
||||
return Cost((self.sqrs / divisor).numerical_approx(), (self.muls / divisor).numerical_approx())
|
||||
|
@ -97,17 +121,63 @@ class SqrtField:
|
|||
print("best is hash_xor=0x%X, hash_mod=%d" % (hash_xor, hash_mod))
|
||||
return (hash_xor, hash_mod)
|
||||
|
||||
def sarkar_sqrt(self, u):
|
||||
"""
|
||||
Return (sqrt(u), True ), if u is square in the field.
|
||||
(sqrt(g*u), False), otherwise.
|
||||
"""
|
||||
def sarkar_sqrt(self, u, c):
|
||||
if VERBOSE: print("u = %r" % (u,))
|
||||
|
||||
# This would actually be done using the addition chain.
|
||||
v = u^((self.m-1)/2)
|
||||
cost = copy(self.base_cost)
|
||||
c.include(self.base_cost)
|
||||
|
||||
uv = u * v
|
||||
uv = c.mul(u, v)
|
||||
(res, zero_if_square) = self.sarkar_sqrt_common(u, 1, uv, v, c)
|
||||
return (res, zero_if_square)
|
||||
|
||||
"""
|
||||
Return (sqrt(N/D), zero_if_square) with zero_if_square == 0, if N/D is square in the field.
|
||||
(sqrt(g*N/D), zero_if_square) with zero_if_square != 0, otherwise.
|
||||
|
||||
This avoids the full cost of computing N/D.
|
||||
"""
|
||||
def sarkar_divsqrt(self, N, D, c):
|
||||
if DEBUG:
|
||||
u = N/D
|
||||
if VERBOSE: print("N/D = %r/%r\n = %r" % (N, D, u))
|
||||
|
||||
# This would actually be done using addition chains for 2^n - 1 and (m-1)/2
|
||||
# (see addchain_sqrt.py).
|
||||
s = D^(2^self.n - 1)
|
||||
c.sqrs += 31
|
||||
c.muls += 5
|
||||
t = c.mul(c.sqr(s), D)
|
||||
if DEBUG: assert t == D^(2^(self.n+1) - 1)
|
||||
w = c.mul(c.mul(N, t)^((self.m-1)/2), s)
|
||||
c.include(self.base_cost)
|
||||
v = c.mul(w, D)
|
||||
uv = c.mul(N, w)
|
||||
|
||||
if DEBUG:
|
||||
assert v == u^((self.m-1)/2)
|
||||
assert uv == u * v
|
||||
|
||||
(res, zero_if_square) = self.sarkar_sqrt_common(N, D, uv, v, c)
|
||||
|
||||
if DEBUG:
|
||||
(res_ref, zero_if_square_ref) = self.sarkar_sqrt(u, Cost())
|
||||
assert res == res_ref
|
||||
assert (zero_if_square == 0) == (zero_if_square_ref == 0)
|
||||
|
||||
return (res, zero_if_square)
|
||||
|
||||
def sarkar_sqrt_common(self, N, D, uv, v, c):
|
||||
x3 = uv * v
|
||||
cost.muls += 2
|
||||
if DEBUG: assert x3 == u^self.m
|
||||
c.muls += 2
|
||||
if DEBUG:
|
||||
u = N/D
|
||||
assert x3 == u^self.m
|
||||
if EXPENSIVE:
|
||||
x3_order = x3.multiplicative_order()
|
||||
if VERBOSE: print("x3_order = %r" % (x3_order,))
|
||||
|
@ -122,44 +192,44 @@ class SqrtField:
|
|||
assert x1 == x3^(1<<(self.n-1-15))
|
||||
assert x2 == x3^(1<<(self.n-1-23))
|
||||
|
||||
cost.sqrs += 8+8+8
|
||||
c.sqrs += 8+8+8
|
||||
|
||||
# i = 0, 1
|
||||
t_ = self.invtab[self.hash(x0)] # = t >> 16
|
||||
if DEBUG: assert 1 == x0 * self.g^(t_ << 24), (x0, t_)
|
||||
assert t_ < 0x100, t_
|
||||
alpha = x1 * self.gtab[2][t_]
|
||||
cost.muls += 1
|
||||
c.muls += 1
|
||||
|
||||
# i = 2
|
||||
t_ += self.invtab[self.hash(alpha)] << 8 # = t >> 8
|
||||
if DEBUG: assert 1 == x1 * self.g^(t_ << 16), (x1, t_)
|
||||
assert t_ < 0x10000, t_
|
||||
alpha = x2 * self.gtab[1][t_ % 256] * self.gtab[2][t_ >> 8]
|
||||
cost.muls += 2
|
||||
c.muls += 2
|
||||
|
||||
# i = 3
|
||||
t_ += self.invtab[self.hash(alpha)] << 16 # = t
|
||||
if DEBUG: assert 1 == x2 * self.g^(t_ << 8), (x2, t_)
|
||||
assert t_ < 0x1000000, t_
|
||||
alpha = x3 * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][t_ >> 16]
|
||||
cost.muls += 3
|
||||
c.muls += 3
|
||||
|
||||
t_ += self.invtab[self.hash(alpha)] << 24 # = t << 1
|
||||
if DEBUG: assert 1 == x3 * self.g^t_, (x3, t_)
|
||||
t_ = (t_ + 1) >> 1
|
||||
assert t_ <= 0x80000000, t_
|
||||
res = uv * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][(t_ >> 16) % 256] * self.gtab[3][t_ >> 24]
|
||||
cost.muls += 4
|
||||
c.muls += 4
|
||||
|
||||
issq = (res^2 == u)
|
||||
cost.sqrs += 1
|
||||
zero_if_square = c.mul(c.sqr(res), D) - N
|
||||
if DEBUG:
|
||||
assert issq == u.is_square()
|
||||
if EXPENSIVE: assert issq == (x3_order != 2^self.n), (issq, x3_order)
|
||||
if not issq:
|
||||
assert (zero_if_square == 0) == u.is_square()
|
||||
if EXPENSIVE: assert (zero_if_square == 0) == (x3_order != 2^self.n), (zero_if_square, x3_order)
|
||||
if zero_if_square != 0:
|
||||
assert(res^2 == u * self.g)
|
||||
return (res, issq, cost)
|
||||
|
||||
return (res, zero_if_square)
|
||||
|
||||
|
||||
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
|
||||
|
@ -172,26 +242,28 @@ F_q = SqrtField(q, 5, Cost(223, 24), hash_xor=0x116A9E, hash_mod=1206)
|
|||
print("p = %r" % (p,))
|
||||
|
||||
x = Mod(0x1234567890123456789012345678901234567890123456789012345678901234, p)
|
||||
print(F_p.sarkar_sqrt(x))
|
||||
print(F_p.sarkar_sqrt(x, Cost()))
|
||||
Dx = Mod(0x123456, p)
|
||||
print(F_p.sarkar_divsqrt(x*Dx, Dx, Cost()))
|
||||
|
||||
x = Mod(0x2345678901234567890123456789012345678901234567890123456789012345, p)
|
||||
print(F_p.sarkar_sqrt(x))
|
||||
print(F_p.sarkar_sqrt(x, Cost()))
|
||||
|
||||
# nonsquare
|
||||
x = Mod(0x3456789012345678901234567890123456789012345678901234567890123456, p)
|
||||
print(F_p.sarkar_sqrt(x))
|
||||
print(F_p.sarkar_sqrt(x, Cost()))
|
||||
|
||||
if SUBGROUP_TEST:
|
||||
for i in range(33):
|
||||
x = F_p.g^(2^i)
|
||||
print(F_p.sarkar_sqrt(x))
|
||||
print(F_p.sarkar_sqrt(x, Cost()))
|
||||
|
||||
if OP_COUNT:
|
||||
total_cost = Cost(0, 0)
|
||||
cost = Cost()
|
||||
iters = 50
|
||||
for i in range(iters):
|
||||
x = GF(p).random_element()
|
||||
(_, _, cost) = F_p.sarkar_sqrt(x)
|
||||
total_cost += cost
|
||||
y = GF(p).random_element()
|
||||
(_, _) = F_p.sarkar_divsqrt(x, y, cost)
|
||||
|
||||
print total_cost.divide(iters)
|
||||
print cost.divide(iters)
|
||||
|
|
Loading…
Reference in New Issue