20201129 10:43:07 08:00



#!/usr/bin/env sage








# This implements a prototype of Palash Sarkar's square root algorithm




# from <https://eprint.iacr.org/2020/1407>, for the Pasta fields.





20201130 05:28:25 08:00



import sys

20201129 10:43:07 08:00




20201130 02:13:16 08:00



# Python 2 compatibility: use the lazy xrange wherever range is called.
if sys.version_info[:1] == (2,):
    range = xrange

# Switches consulted by SqrtField and by the demo code at the end of the file.
DEBUG = True           # run internal consistency assertions
VERBOSE = False        # print intermediate values
EXPENSIVE = False      # additionally run slow multiplicative-order checks
SUBGROUP_TEST = False  # take square roots of g^(2^i) for i = 0..32
OP_COUNT = False       # report the average Cost of sarkar_divsqrt on random inputs

20201129 10:43:07 08:00




20201231 19:36:05 08:00




20201129 10:43:07 08:00



class Cost:
    """Running tally of field operations: squarings, multiplications and inversions.

    The helper methods both perform an operation and record it, so an algorithm can
    be written as c.mul(a, b), c.sqr(a), ... and its cost read off afterwards.
    """

    def __init__(self, sqrs=0, muls=0, invs=0):
        self.sqrs = sqrs  # number of squarings
        self.muls = muls  # number of multiplications
        self.invs = invs  # number of inversions

    def sqr(self, x):
        """Return x^2, counting one squaring."""
        self.sqrs += 1
        return x^2

    def mul(self, x, y):
        """Return x * y, counting one multiplication."""
        self.muls += 1
        return x * y

    def div(self, x, y):
        """Return x / y, counted as one inversion plus one multiplication."""
        self.invs += 1
        self.muls += 1
        return x / y

    def batch_inv0(self, xs):
        """Return the inverses of xs, mapping 0 to 0.

        Charged at the cost of Montgomery's batch-inversion trick: one inversion and
        3*(len(xs) - 1) multiplications. An empty batch costs nothing (previously it
        was charged one inversion and a negative number of multiplications).
        """
        if not xs:
            return []
        self.invs += 1
        self.muls += 3*(len(xs)-1)
        # This should use Montgomery's trick (with constant-time substitutions to handle zeros).
        return [0 if x == 0 else x^-1 for x in xs]

    def __repr__(self):
        # %s rather than %d so that the fractional averages produced by divide()
        # are shown instead of being truncated to integers.
        return "%sS + %sM + %sI" % (self.sqrs, self.muls, self.invs)

    def __add__(self, other):
        """Return a new Cost holding the sum of both tallies."""
        return Cost(self.sqrs + other.sqrs, self.muls + other.muls, self.invs + other.invs)

    def include(self, other):
        """Add all three counters of other into self, in place (consistent with __add__)."""
        self.sqrs += other.sqrs
        self.muls += other.muls
        self.invs += other.invs

    def divide(self, divisor):
        """Return a new Cost with every counter divided by divisor, as numerical approximations."""
        return Cost((self.sqrs / divisor).numerical_approx(),
                    (self.muls / divisor).numerical_approx(),
                    (self.invs / divisor).numerical_approx())












class SqrtField:
    """Precomputed tables for Sarkar's table-based square root over GF(p), p - 1 = m * 2^32.

    Attributes set by __init__:
      p, n, m     -- the modulus, n = 32, and m with p = 1 + m * 2^n.
      g           -- z^m, which generates the subgroup of order 2^n.
      gtab[i][j]  -- g^(256^i * j); gtab[3] is truncated to j <= 128.
      invtab      -- maps self.hash(g^(2^24 * j)) to (256 - j) % 256.
      hash_xor, hash_mod -- parameters of self.hash.
      base_cost   -- Cost charged for computing u^((m-1)/2).
    """

    def __init__(self, p, z, base_cost, hash_xor=None, hash_mod=None):
        """Build the lookup tables for GF(p).

        p         -- prime with p = 1 + m * 2^32 (asserted).
        z         -- multiplicative generator of GF(p) (checked only when EXPENSIVE).
        base_cost -- Cost of computing u^((m-1)/2) by an addition chain.
        hash_xor, hash_mod -- perfect-hash parameters; searched for with
                     find_perfect_hash() (slow) when hash_xor is None.
        """
        n = 32
        m = p >> n
        assert p == 1 + m * 2^n
        if EXPENSIVE: assert Mod(z, p).multiplicative_order() == p-1
        g = Mod(z, p)^m
        if EXPENSIVE: assert g.multiplicative_order() == 2^n

        # gtab[i][j] = g^(256^i * j) for 0 <= i < 4, 0 <= j < 256.
        gtab = [[0]*256 for _ in range(4)]
        gi = g
        for i in range(4):
            if DEBUG: assert gi == g^(256^i), (i, gi)
            acc = Mod(1, p)
            for j in range(256):
                if DEBUG: assert acc == g^(256^i * j), (i, j, acc)
                gtab[i][j] = acc
                acc *= gi
            # After 256 steps acc = gi^256 = g^(256^(i+1)).
            gi = acc

        if hash_xor is None:
            (hash_xor, hash_mod) = self.find_perfect_hash(gtab[3])
        (self.hash_xor, self.hash_mod) = (hash_xor, hash_mod)

        # Now invert gtab[3]: for y = g^(2^24 * j), y * g^(invtab[hash(y)] << 24) == 1.
        invtab = [1]*hash_mod
        for j in range(256):
            h = self.hash(gtab[3][j])
            # 1 is the last value to be assigned (j = 255); every earlier j writes a
            # different value, so seeing anything but 1 here means a collision.
            assert invtab[h] == 1
            invtab[h] = (256-j) % 256

        # Reconstruction of the root only indexes gtab[3] with t >> 24 <= 128.
        gtab[3] = gtab[3][:129]

        (self.p, self.n, self.m, self.g, self.gtab, self.invtab, self.base_cost) = (
            p, n, m, g, gtab, invtab, base_cost)

    def hash(self, x):
        """Map x to an index into invtab using its low 32 bits.

        Collision-free on the entries of gtab[3] (enforced by an assert in __init__).
        """
        return ((int(x) & 0xFFFFFFFF) ^^ self.hash_xor) % self.hash_mod

    def find_perfect_hash(self, gt):
        """Search for (hash_xor, hash_mod) making ((x & 0xFFFFFFFF) ^ hash_xor) % hash_mod
        injective on gt, preferring the smallest modulus found.

        Tries every hash_xor in [1, 0x200000); for each, tries moduli from 256 up to
        (excluding) the best modulus found so far. Prints each improvement and the
        final choice. This is slow; __init__ skips it when hash_xor is supplied.
        Raises ValueError if no suitable parameters exist in the search range.
        """
        gt = [int(x) & 0xFFFFFFFF for x in gt]
        # If the low 32 bits collide, no hash of this form can separate the entries.
        assert len(set(gt)) == len(gt)

        def is_ok(c_invtab, c_xor, c_mod):
            # c_invtab is reused across moduli: a slot is taken in the current attempt
            # iff it holds c_mod, so it never needs clearing between attempts.
            for x in gt:
                h = (x ^^ c_xor) % c_mod
                if c_invtab[h] == c_mod:
                    return False
                c_invtab[h] = c_mod
            return True

        hash_xor = None
        hash_mod = 10000
        for c_xor in range(1, 0x200000):
            c_invtab = [0]*hash_mod
            for c_mod in range(256, hash_mod):
                if is_ok(c_invtab, c_xor, c_mod):
                    (hash_xor, hash_mod) = (c_xor, c_mod)
                    print("0x%X: %d" % (hash_xor, hash_mod))
                    break

        if hash_xor is None:
            raise ValueError("no perfect hash found")
        print("best is hash_xor=0x%X, hash_mod=%d" % (hash_xor, hash_mod))
        return (hash_xor, hash_mod)

    def sarkar_sqrt(self, u, c):
        """Square root of u, accumulating operation counts into the Cost c.

        Returns (res, zero_if_square) with zero_if_square = res^2 - u:
          res = sqrt(u)   and zero_if_square == 0, if u is square in the field;
          res = sqrt(g*u) and zero_if_square != 0, otherwise.
        """
        if VERBOSE: print("u = %r" % (u,))

        # This would actually be done using the addition chain.
        v = u^((self.m-1)/2)
        c.include(self.base_cost)

        uv = c.mul(u, v)
        (res, zero_if_square) = self.sarkar_sqrt_common(u, 1, uv, v, c)
        return (res, zero_if_square)

    def sarkar_divsqrt(self, N, D, c):
        """Square root of N/D, accumulating operation counts into the Cost c.

        Returns (res, zero_if_square) with zero_if_square = res^2 * D - N:
          res = sqrt(N/D)   and zero_if_square == 0, if N/D is square in the field;
          res = sqrt(g*N/D) and zero_if_square != 0, otherwise.

        This avoids the full cost of computing N/D.
        """
        if DEBUG:
            u = N/D
            if VERBOSE: print("N/D = %r/%r\n    = %r" % (N, D, u))

        # We need to calculate uv and v, where v = u^((m-1)/2), u = N/D, and p-1 = m * 2^n.
        # We can rewrite as follows:
        #
        #      v = (N/D)^((m-1)/2)
        #        = N^((m-1)/2) * D^(p-1 - (m-1)/2)    [Fermat's Little Theorem]
        #        =      "      * D^(m * 2^n - (m-1)/2)
        #        =      "      * D^((2^(n+1) - 1)*(m-1)/2 + 2^n)
        #        = (N * D^(2^(n+1) - 1))^((m-1)/2) * D^(2^n)
        #
        # Let w = (N * D^(2^(n+1) - 1))^((m-1)/2) * D^(2^n - 1).
        # Then v = w * D, and uv = N * v/D = N * w.
        #
        # We calculate:
        #
        #      s = D^(2^n - 1) using an addition chain
        #      t = D^(2^(n+1) - 1) = s^2 * D
        #      w = (N * t)^((m-1)/2) * s using another addition chain
        #
        # then v and uv as above. The addition chains are given in addchain_sqrt.py .
        # The overall cost of this part is similar to a single full-width exponentiation,
        # regardless of n.

        s = D^(2^self.n - 1)
        # Hand-counted cost of the addition chain for D^(2^32 - 1).
        c.sqrs += 31
        c.muls += 5
        t = c.mul(c.sqr(s), D)
        if DEBUG: assert t == D^(2^(self.n+1) - 1)
        w = c.mul(c.mul(N, t)^((self.m-1)/2), s)
        c.include(self.base_cost)
        v = c.mul(w, D)
        uv = c.mul(N, w)

        if DEBUG:
            assert v == u^((self.m-1)/2)
            assert uv == u * v

        (res, zero_if_square) = self.sarkar_sqrt_common(N, D, uv, v, c)

        if DEBUG:
            (res_ref, zero_if_square_ref) = self.sarkar_sqrt(u, Cost())
            assert res == res_ref
            assert (zero_if_square == 0) == (zero_if_square_ref == 0)

        return (res, zero_if_square)

    def sarkar_sqrt_common(self, N, D, uv, v, c):
        """Shared core of sarkar_sqrt and sarkar_divsqrt, for u = N/D.

        Expects v = u^((m-1)/2) and uv = u * v. Returns (res, res^2 * D - N), where
        res^2 = u if u is square and res^2 = u * g otherwise.

        NOTE(review): the DEBUG checks assume u != 0. For u = 0, x3 = x0 = 0 and the
        first table-lookup assertion fails, although the returned (0, 0) is correct.
        """
        x3 = uv * v
        # NOTE(review): x3 = uv * v is one multiplication but 2 are charged; confirm intent.
        c.muls += 2
        if DEBUG:
            u = N/D
            assert x3 == u^self.m
            if EXPENSIVE:
                x3_order = x3.multiplicative_order()
                if VERBOSE: print("x3_order = %r" % (x3_order,))
                # x3_order is 2^n iff u is nonsquare, otherwise it divides 2^(n-1).
                # (Previously this tested x3.divides(...), which is vacuous in a field.)
                assert x3_order.divides(2^self.n)

        x2 = x3^(1<<8)
        x1 = x2^(1<<8)
        x0 = x1^(1<<8)
        if DEBUG:
            assert x0 == x3^(1<<(self.n-1-7))
            assert x1 == x3^(1<<(self.n-1-15))
            assert x2 == x3^(1<<(self.n-1-23))
        c.sqrs += 8+8+8

        # Write x3 = g^(-T) with 0 <= T < 2^32. Each step recovers 8 more low bits of T
        # from a single invtab lookup, as the DEBUG assertions below check.

        # i = 0, 1
        t_ = self.invtab[self.hash(x0)]  # low 8 bits of T
        if DEBUG: assert 1 == x0 * self.g^(t_ << 24), (x0, t_)
        assert t_ < 0x100, t_
        alpha = x1 * self.gtab[2][t_]
        c.muls += 1

        # i = 2
        t_ += self.invtab[self.hash(alpha)] << 8  # low 16 bits of T
        if DEBUG: assert 1 == x1 * self.g^(t_ << 16), (x1, t_)
        assert t_ < 0x10000, t_
        alpha = x2 * self.gtab[1][t_ % 256] * self.gtab[2][t_ >> 8]
        c.muls += 2

        # i = 3
        t_ += self.invtab[self.hash(alpha)] << 16  # low 24 bits of T
        if DEBUG: assert 1 == x2 * self.g^(t_ << 8), (x2, t_)
        assert t_ < 0x1000000, t_
        alpha = x3 * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][t_ >> 16]
        c.muls += 3

        t_ += self.invtab[self.hash(alpha)] << 24  # all of T
        if DEBUG: assert 1 == x3 * self.g^t_, (x3, t_)

        # uv^2 = u * x3 = u * g^(-T), so res = uv * g^ceil(T/2) gives res^2 = u when T
        # is even and res^2 = u * g when T is odd.
        t_ = (t_ + 1) >> 1
        assert t_ <= 0x80000000, t_
        res = uv * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][(t_ >> 16) % 256] * self.gtab[3][t_ >> 24]
        c.muls += 4

        zero_if_square = c.mul(c.sqr(res), D) - N
        if DEBUG:
            assert (zero_if_square == 0) == u.is_square()
            if EXPENSIVE: assert (zero_if_square == 0) == (x3_order != 2^self.n), (zero_if_square, x3_order)
            if zero_if_square != 0:
                assert res^2 == u * self.g

        return (res, zero_if_square)

20201129 10:43:07 08:00











# The two Pasta field moduli.
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001

# z = 5 for both fields. The base Costs are those of u^{(m-1)/2}
# (see addchain_sqrt.py for base costs of u^{(m-1)/2}); the perfect-hash
# parameters are supplied so that find_perfect_hash() is not run.
F_p = SqrtField(p, 5, Cost(223, 23), hash_xor=0x11BE, hash_mod=1098)
F_q = SqrtField(q, 5, Cost(223, 24), hash_xor=0x116A9E, hash_mod=1206)

print("p = %r" % (p,))

# First sample: taken directly, and again as the ratio (x*Dx)/Dx.
x = Mod(0x1234567890123456789012345678901234567890123456789012345678901234, p)
Dx = Mod(0x123456, p)
print(F_p.sarkar_sqrt(x, Cost()))
print(F_p.sarkar_divsqrt(x*Dx, Dx, Cost()))

# Two more samples; the second one is a nonsquare.
for x in (Mod(0x2345678901234567890123456789012345678901234567890123456789012345, p),
          Mod(0x3456789012345678901234567890123456789012345678901234567890123456, p)):
    print(F_p.sarkar_sqrt(x, Cost()))

if SUBGROUP_TEST:
    # Square roots of g^(2^i) for i = 0..32.
    for i in range(33):
        x = F_p.g^(1 << i)
        print(F_p.sarkar_sqrt(x, Cost()))

if OP_COUNT:
    # Average operation count of sarkar_divsqrt over random numerator/denominator pairs.
    iters = 50
    cost = Cost()
    for i in range(iters):
        x, y = GF(p).random_element(), GF(p).random_element()
        (_, _) = F_p.sarkar_divsqrt(x, y, cost)
    print(cost.divide(iters))
