Add some nice assertions and tests to make it clearer what is going on.

Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
Daira Hopwood 2020-11-30 13:12:48 +00:00
parent 7bf9015957
commit bf740d64b8
1 changed files with 25 additions and 6 deletions

View File

@ -12,6 +12,8 @@ DEBUG = True
VERBOSE = False
EXPENSIVE = False
SUBGROUP_TEST = True
OP_COUNT = True
class Cost:
def __init__(self, sqrs, muls):
@ -62,8 +64,8 @@ class SqrtField:
gtab[3] = gtab[3][:128]
(self.p, self.n, self.m, self.gtab, self.invtab, self.base_cost) = (
p, n, m, gtab, invtab, base_cost)
(self.p, self.n, self.m, self.g, self.gtab, self.invtab, self.base_cost) = (
p, n, m, g, gtab, invtab, base_cost)
def hash(self, x):
return ((int(x) & 0xFFFFFFFF) ^^ self.hash_xor) % self.hash_mod
@ -105,7 +107,11 @@ class SqrtField:
x3 = uv * v
cost.muls += 2
if DEBUG: assert x3 == u^self.m
if EXPENSIVE: assert x3.multiplicative_order().divides(2^self.n)
if EXPENSIVE:
x3_order = x3.multiplicative_order()
if VERBOSE: print("x3_order = %r" % (x3_order,))
# x3_order is 2^n iff u is nonsquare, otherwise it divides 2^(n-1).
assert x3.divides(2^self.n)
x2 = x3^(1<<8)
x1 = x2^(1<<8)
@ -119,23 +125,28 @@ class SqrtField:
# i = 0, 1
t_ = self.invtab[self.hash(x0)] # = t >> 16
if DEBUG: assert 1 == x0 * self.g^(t_ << 24), (x0, t_)
assert t_ < 0x100, t_
alpha = x1 * self.gtab[2][t_]
cost.muls += 1
# i = 2
t_ += self.invtab[self.hash(alpha)] << 8 # = t >> 8
if DEBUG: assert 1 == x1 * self.g^(t_ << 16), (x1, t_)
assert t_ < 0x10000, t_
alpha = x2 * self.gtab[1][t_ % 256] * self.gtab[2][t_ >> 8]
cost.muls += 2
# i = 3
t_ += self.invtab[self.hash(alpha)] << 16 # = t
if DEBUG: assert 1 == x2 * self.g^(t_ << 8), (x2, t_)
assert t_ < 0x1000000, t_
alpha = x3 * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][t_ >> 16]
cost.muls += 3
t_ = ((self.invtab[self.hash(alpha)] << 24) + t_) >> 1 # = t
t_ += self.invtab[self.hash(alpha)] << 24 # = t << 1
if DEBUG: assert 1 == x3 * self.g^t_, (x3, t_)
t_ >>= 1
assert t_ < 0x80000000, t_
res = uv * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][(t_ >> 16) % 256] * self.gtab[3][t_ >> 24]
cost.muls += 4
@ -143,7 +154,10 @@ class SqrtField:
if res^2 != u:
res = None
cost.sqrs += 1
if DEBUG: assert u.is_square() == (res is not None)
if DEBUG:
issq = u.is_square()
assert issq == (res is not None)
if EXPENSIVE: assert issq == (x3_order != 2^self.n), (issq, x3_order)
return (res, cost)
@ -167,7 +181,12 @@ print(F_p.sarkar_sqrt(x))
x = Mod(0x3456789012345678901234567890123456789012345678901234567890123456, p)
print(F_p.sarkar_sqrt(x))
if True:
if SUBGROUP_TEST:
for i in range(33):
x = F_p.g^(2^i)
print(F_p.sarkar_sqrt(x))
if OP_COUNT:
total_cost = Cost(0, 0)
iters = 1000
for i in range(iters):