Merge pull request #2 from daira/refactor-field-errors

Refactor error handling
This commit is contained in:
str4d 2018-06-05 16:06:47 +12:00 committed by GitHub
commit 9edd16e17a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 9 deletions

View File

@ -13,7 +13,9 @@ assert (q_j - 1) // 2 == qm1d2
#
class FieldElement(object):
def __init__(self, t, s, modulus):
def __init__(self, t, s, modulus, strict=False):
if strict and not (0 <= s and s < modulus):
raise ValueError
self.t = t
self.s = s % modulus
self.m = modulus
@ -47,7 +49,6 @@ class FieldElement(object):
return i2lebsp(l, self.s)
def __bytes__(self):
# TODO: Check length
return i2leosp(256, self.s)
def __eq__(self, a):
@ -58,10 +59,10 @@ class FieldElement(object):
class Fq(FieldElement):
@staticmethod
def from_bytes(buf):
return Fq(leos2ip(buf))
return Fq(leos2ip(buf), strict=True)
def __init__(self, s):
FieldElement.__init__(self, Fq, s, q_j)
def __init__(self, s, strict=False):
FieldElement.__init__(self, Fq, s, q_j, strict=strict)
def __str__(self):
return 'Fq(%s)' % self.s
@ -106,8 +107,8 @@ class Fq(FieldElement):
class Fr(FieldElement):
def __init__(self, s):
FieldElement.__init__(self, Fr, s, r_j)
def __init__(self, s, strict=False):
FieldElement.__init__(self, Fr, s, r_j, strict=strict)
def __str__(self):
return 'Fr(%s)' % self.s
@ -141,10 +142,12 @@ JUBJUB_COFACTOR = Fr(8)
class Point(object):
@staticmethod
def from_bytes(buf):
assert len(buf) == 32
u_sign = buf[31] >> 7
buf = buf[:31] + bytes([buf[31] & 0b01111111])
v = Fq.from_bytes(buf)
if bytes(v) != buf:
try:
v = Fq.from_bytes(buf)
except ValueError:
return None
vv = v * v