#!/usr/bin/env sage

# This implements a prototype of Palash Sarkar's square root algorithm
# from <https://eprint.iacr.org/2020/1407>, for the Pasta fields.
#
# NOTE: this is Sage source, not plain Python: `^` is exponentiation and
# `^^` is bitwise XOR (both rewritten by the Sage preparser).

import sys
from copy import copy

if sys.version_info[0] == 2:
    range = xrange

# Enable the cheap internal consistency assertions.
DEBUG = True
# Print intermediate values while computing.
VERBOSE = False
# Enable assertions that need multiplicative_order() computations.
EXPENSIVE = False
# Run sarkar_sqrt on g^(2^i) for i in 0..32 at the end of the script.
SUBGROUP_TEST = True
# Report the average operation count over random inputs.
OP_COUNT = True


class Cost:
    """Operation counter: a number of squarings and a number of multiplications."""

    def __init__(self, sqrs, muls):
        self.sqrs = sqrs
        self.muls = muls

    def __repr__(self):
        return repr((self.sqrs, self.muls))

    def __add__(self, other):
        """Return a new Cost that is the component-wise sum of self and other."""
        return Cost(self.sqrs + other.sqrs, self.muls + other.muls)

    def divide(self, divisor):
        """Return a new Cost with both counts divided by `divisor`.

        The quotients are converted with `.numerical_approx()`, so the counts
        and the divisor are expected to be Sage numbers (not Python ints).
        """
        return Cost((self.sqrs / divisor).numerical_approx(),
                    (self.muls / divisor).numerical_approx())


class SqrtField:
    """Precomputed tables for table-based square roots modulo p = 1 + m * 2^32."""

    def __init__(self, p, z, base_cost, hash_xor=None, hash_mod=None):
        """Build the lookup tables used by `sarkar_sqrt`.

        p         -- the modulus; must satisfy p == 1 + (p >> 32) * 2^32.
        z         -- base used to derive g = z^m mod p (checked to have
                     multiplicative order p-1 only when EXPENSIVE is set).
        base_cost -- Cost charged for computing u^((m-1)/2); it is copied at
                     the start of every `sarkar_sqrt` call.
        hash_xor, hash_mod -- parameters of `hash`. If either one is missing,
                     both are searched for with `find_perfect_hash`.
        """
        n = 32
        m = p >> n
        assert p == 1 + m * 2^n
        if EXPENSIVE:
            assert Mod(z, p).multiplicative_order() == p-1
        g = Mod(z, p)^m
        if EXPENSIVE:
            assert g.multiplicative_order() == 2^n

        # gtab[i][j] == g^(256^i * j) for i in 0..3 and j in 0..255.
        gtab = [[0]*256 for i in range(4)]
        gi = g
        for i in range(4):
            if DEBUG:
                assert gi == g^(256^i), (i, gi)
            acc = Mod(1, p)
            for j in range(256):
                if DEBUG:
                    assert acc == g^(256^i * j), (i, j, acc)
                gtab[i][j] = acc
                acc *= gi

            # After 256 steps acc == gi^256, which is the base for the next row.
            gi = acc

        # Previously only hash_xor was checked, so passing hash_xor without
        # hash_mod crashed below on `[1]*None`.
        if hash_xor is None or hash_mod is None:
            (hash_xor, hash_mod) = self.find_perfect_hash(gtab[3])
        (self.hash_xor, self.hash_mod) = (hash_xor, hash_mod)

        # Now invert gtab[3].
        invtab = [1]*hash_mod
        for j in range(256):
            h = self.hash(gtab[3][j])
            # 1 is the last value to be assigned, so this ensures there are no collisions.
            assert invtab[h] == 1
            invtab[h] = (256-j) % 256

        # sarkar_sqrt only indexes gtab[3] with values below 128.
        gtab[3] = gtab[3][:128]

        (self.p, self.n, self.m, self.g, self.gtab, self.invtab, self.base_cost) = (
            p, n, m, g, gtab, invtab, base_cost)

    def hash(self, x):
        """Map x to an index into invtab: ((low 32 bits of x) XOR hash_xor) mod hash_mod."""
        return ((int(x) & 0xFFFFFFFF) ^^ self.hash_xor) % self.hash_mod

    def find_perfect_hash(self, gt):
        """Search for (hash_xor, hash_mod) such that `hash` has no collisions on gt.

        gt must hold at least 256 values whose low 32 bits are pairwise
        distinct. Every improvement found is printed, and the pair with the
        smallest modulus found is returned.

        Raises ValueError if no collision-free pair exists in the search range.
        """
        gt = [int(x) & 0xFFFFFFFF for x in gt]
        assert len(set(gt)) == len(gt)

        def is_ok(c_invtab, c_xor, c_mod):
            # c_mod itself is used as the "slot taken" marker, so the table
            # does not need clearing between successive candidate moduli.
            for j in range(256):
                h = (gt[j] ^^ c_xor) % c_mod
                if c_invtab[h] == c_mod:
                    return False
                c_invtab[h] = c_mod

            return True

        hash_xor = None
        hash_mod = 10000
        for c_xor in range(1, 0x200000):
            c_invtab = [0]*hash_mod
            # Only moduli smaller than the best one found so far are tried.
            for c_mod in range(256, hash_mod):
                if is_ok(c_invtab, c_xor, c_mod):
                    (hash_xor, hash_mod) = (c_xor, c_mod)
                    print("0x%X: %d" % (hash_xor, hash_mod))
                    break

        if hash_xor is None:
            # Previously this fell through to a TypeError from "%X" % None.
            raise ValueError("no perfect hash found")

        print("best is hash_xor=0x%X, hash_mod=%d" % (hash_xor, hash_mod))
        return (hash_xor, hash_mod)

    def sarkar_sqrt(self, u):
        """Return (res, cost) for an element u of the integers modulo p.

        res is an element with res^2 == u, or None when the candidate that
        was computed does not square to u. cost is a Cost: a copy of
        base_cost plus the squarings and multiplications counted here.

        NOTE: for u == 0 the DEBUG consistency assertions below fail, because
        x3 is then 0 rather than a power of g.
        """
        if VERBOSE:
            print("u = %r" % (u,))

        # This would actually be done using the addition chain.
        v = u^((self.m-1)/2)
        cost = copy(self.base_cost)

        uv = u * v
        x3 = uv * v
        cost.muls += 2
        if DEBUG:
            assert x3 == u^self.m
        if EXPENSIVE:
            x3_order = x3.multiplicative_order()
            if VERBOSE:
                print("x3_order = %r" % (x3_order,))
            # x3_order is 2^n iff u is nonsquare, otherwise it divides 2^(n-1).
            # (The original asserted x3.divides(...), i.e. tested the field
            # element instead of its order.)
            assert x3_order.divides(2^self.n)

        x2 = x3^(1<<8)
        x1 = x2^(1<<8)
        x0 = x1^(1<<8)
        if DEBUG:
            assert x0 == x3^(1<<(self.n-1-7))
            assert x1 == x3^(1<<(self.n-1-15))
            assert x2 == x3^(1<<(self.n-1-23))

        cost.sqrs += 8+8+8

        # Each step below recovers 8 more bits of t_ by one invtab lookup,
        # maintaining x_k * g^(t_ << shift) == 1 (checked under DEBUG).

        # i = 0, 1
        t_ = self.invtab[self.hash(x0)]  # = t >> 16
        if DEBUG:
            assert 1 == x0 * self.g^(t_ << 24), (x0, t_)
        assert t_ < 0x100, t_
        alpha = x1 * self.gtab[2][t_]
        cost.muls += 1

        # i = 2
        t_ += self.invtab[self.hash(alpha)] << 8  # = t >> 8
        if DEBUG:
            assert 1 == x1 * self.g^(t_ << 16), (x1, t_)
        assert t_ < 0x10000, t_
        alpha = x2 * self.gtab[1][t_ % 256] * self.gtab[2][t_ >> 8]
        cost.muls += 2

        # i = 3
        t_ += self.invtab[self.hash(alpha)] << 16  # = t
        if DEBUG:
            assert 1 == x2 * self.g^(t_ << 8), (x2, t_)
        assert t_ < 0x1000000, t_
        alpha = (x3 * self.gtab[0][t_ % 256]
                    * self.gtab[1][(t_ >> 8) % 256]
                    * self.gtab[2][t_ >> 16])
        cost.muls += 3

        t_ += self.invtab[self.hash(alpha)] << 24  # = t << 1
        if DEBUG:
            assert 1 == x3 * self.g^t_, (x3, t_)
        t_ >>= 1
        assert t_ < 0x80000000, t_
        res = (uv * self.gtab[0][t_ % 256]
                  * self.gtab[1][(t_ >> 8) % 256]
                  * self.gtab[2][(t_ >> 16) % 256]
                  * self.gtab[3][t_ >> 24])
        cost.muls += 4

        # The final squaring is counted whether or not the check succeeds.
        if res^2 != u:
            res = None
        cost.sqrs += 1
        if DEBUG:
            issq = u.is_square()
            assert issq == (res is not None)
            if EXPENSIVE:
                assert issq == (x3_order != 2^self.n), (issq, x3_order)
        return (res, cost)


p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001

# see addchain_sqrt.py for base costs of u^{(m-1)/2}
F_p = SqrtField(p, 5, Cost(223, 23), hash_xor=0x11BE, hash_mod=1098)
F_q = SqrtField(q, 5, Cost(223, 24), hash_xor=0x116A9E, hash_mod=1206)


print("p = %r" % (p,))

x = Mod(0x1234567890123456789012345678901234567890123456789012345678901234, p)
print(F_p.sarkar_sqrt(x))

x = Mod(0x2345678901234567890123456789012345678901234567890123456789012345, p)
print(F_p.sarkar_sqrt(x))

# nonsquare
x = Mod(0x3456789012345678901234567890123456789012345678901234567890123456, p)
print(F_p.sarkar_sqrt(x))

if SUBGROUP_TEST:
    for i in range(33):
        x = F_p.g^(2^i)
        print(F_p.sarkar_sqrt(x))

if OP_COUNT:
    total_cost = Cost(0, 0)
    iters = 1000
    for i in range(iters):
        x = GF(p).random_element()
        (_, cost) = F_p.sarkar_sqrt(x)
        total_cost += cost

    # Was a Python 2 print statement; every other print in this file is a call.
    print(total_cost.divide(iters))