mirror of https://github.com/zcash/pasta.git
Add some nice assertions and tests to make it clearer what is going on.
Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
parent
7bf9015957
commit
bf740d64b8
|
@ -12,6 +12,8 @@ DEBUG = True
|
|||
VERBOSE = False
|
||||
EXPENSIVE = False
|
||||
|
||||
SUBGROUP_TEST = True
|
||||
OP_COUNT = True
|
||||
|
||||
class Cost:
|
||||
def __init__(self, sqrs, muls):
|
||||
|
@ -62,8 +64,8 @@ class SqrtField:
|
|||
|
||||
gtab[3] = gtab[3][:128]
|
||||
|
||||
(self.p, self.n, self.m, self.gtab, self.invtab, self.base_cost) = (
|
||||
p, n, m, gtab, invtab, base_cost)
|
||||
(self.p, self.n, self.m, self.g, self.gtab, self.invtab, self.base_cost) = (
|
||||
p, n, m, g, gtab, invtab, base_cost)
|
||||
|
||||
def hash(self, x):
|
||||
return ((int(x) & 0xFFFFFFFF) ^^ self.hash_xor) % self.hash_mod
|
||||
|
@ -105,7 +107,11 @@ class SqrtField:
|
|||
x3 = uv * v
|
||||
cost.muls += 2
|
||||
if DEBUG: assert x3 == u^self.m
|
||||
if EXPENSIVE: assert x3.multiplicative_order().divides(2^self.n)
|
||||
if EXPENSIVE:
|
||||
x3_order = x3.multiplicative_order()
|
||||
if VERBOSE: print("x3_order = %r" % (x3_order,))
|
||||
# x3_order is 2^n iff u is nonsquare, otherwise it divides 2^(n-1).
|
||||
assert x3.divides(2^self.n)
|
||||
|
||||
x2 = x3^(1<<8)
|
||||
x1 = x2^(1<<8)
|
||||
|
@ -119,23 +125,28 @@ class SqrtField:
|
|||
|
||||
# i = 0, 1
|
||||
t_ = self.invtab[self.hash(x0)] # = t >> 16
|
||||
if DEBUG: assert 1 == x0 * self.g^(t_ << 24), (x0, t_)
|
||||
assert t_ < 0x100, t_
|
||||
alpha = x1 * self.gtab[2][t_]
|
||||
cost.muls += 1
|
||||
|
||||
# i = 2
|
||||
t_ += self.invtab[self.hash(alpha)] << 8 # = t >> 8
|
||||
if DEBUG: assert 1 == x1 * self.g^(t_ << 16), (x1, t_)
|
||||
assert t_ < 0x10000, t_
|
||||
alpha = x2 * self.gtab[1][t_ % 256] * self.gtab[2][t_ >> 8]
|
||||
cost.muls += 2
|
||||
|
||||
# i = 3
|
||||
t_ += self.invtab[self.hash(alpha)] << 16 # = t
|
||||
if DEBUG: assert 1 == x2 * self.g^(t_ << 8), (x2, t_)
|
||||
assert t_ < 0x1000000, t_
|
||||
alpha = x3 * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][t_ >> 16]
|
||||
cost.muls += 3
|
||||
|
||||
t_ = ((self.invtab[self.hash(alpha)] << 24) + t_) >> 1 # = t
|
||||
t_ += self.invtab[self.hash(alpha)] << 24 # = t << 1
|
||||
if DEBUG: assert 1 == x3 * self.g^t_, (x3, t_)
|
||||
t_ >>= 1
|
||||
assert t_ < 0x80000000, t_
|
||||
res = uv * self.gtab[0][t_ % 256] * self.gtab[1][(t_ >> 8) % 256] * self.gtab[2][(t_ >> 16) % 256] * self.gtab[3][t_ >> 24]
|
||||
cost.muls += 4
|
||||
|
@ -143,7 +154,10 @@ class SqrtField:
|
|||
if res^2 != u:
|
||||
res = None
|
||||
cost.sqrs += 1
|
||||
if DEBUG: assert u.is_square() == (res is not None)
|
||||
if DEBUG:
|
||||
issq = u.is_square()
|
||||
assert issq == (res is not None)
|
||||
if EXPENSIVE: assert issq == (x3_order != 2^self.n), (issq, x3_order)
|
||||
return (res, cost)
|
||||
|
||||
|
||||
|
@ -167,7 +181,12 @@ print(F_p.sarkar_sqrt(x))
|
|||
x = Mod(0x3456789012345678901234567890123456789012345678901234567890123456, p)
|
||||
print(F_p.sarkar_sqrt(x))
|
||||
|
||||
if True:
|
||||
if SUBGROUP_TEST:
|
||||
for i in range(33):
|
||||
x = F_p.g^(2^i)
|
||||
print(F_p.sarkar_sqrt(x))
|
||||
|
||||
if OP_COUNT:
|
||||
total_cost = Cost(0, 0)
|
||||
iters = 1000
|
||||
for i in range(iters):
|
||||
|
|
Loading…
Reference in New Issue