diff --git a/sapling_jubjub.py b/sapling_jubjub.py index 423ff05..ede62d2 100644 --- a/sapling_jubjub.py +++ b/sapling_jubjub.py @@ -13,7 +13,9 @@ assert (q_j - 1) // 2 == qm1d2 # class FieldElement(object): - def __init__(self, t, s, modulus): + def __init__(self, t, s, modulus, strict=False): + if strict and not (0 <= s and s < modulus): + raise ValueError self.t = t self.s = s % modulus self.m = modulus @@ -47,7 +49,6 @@ class FieldElement(object): return i2lebsp(l, self.s) def __bytes__(self): - # TODO: Check length return i2leosp(256, self.s) def __eq__(self, a): @@ -58,10 +59,10 @@ class FieldElement(object): class Fq(FieldElement): @staticmethod def from_bytes(buf): - return Fq(leos2ip(buf)) + return Fq(leos2ip(buf), strict=True) - def __init__(self, s): - FieldElement.__init__(self, Fq, s, q_j) + def __init__(self, s, strict=False): + FieldElement.__init__(self, Fq, s, q_j, strict=strict) def __str__(self): return 'Fq(%s)' % self.s @@ -106,8 +107,8 @@ class Fq(FieldElement): class Fr(FieldElement): - def __init__(self, s): - FieldElement.__init__(self, Fr, s, r_j) + def __init__(self, s, strict=False): + FieldElement.__init__(self, Fr, s, r_j, strict=strict) def __str__(self): return 'Fr(%s)' % self.s @@ -141,10 +142,12 @@ JUBJUB_COFACTOR = Fr(8) class Point(object): @staticmethod def from_bytes(buf): + assert len(buf) == 32 u_sign = buf[31] >> 7 buf = buf[:31] + bytes([buf[31] & 0b01111111]) - v = Fq.from_bytes(buf) - if bytes(v) != buf: + try: + v = Fq.from_bytes(buf) + except ValueError: return None vv = v * v