diff --git a/sapling_jubjub.py b/sapling_jubjub.py index 4f0d579..423ff05 100644 --- a/sapling_jubjub.py +++ b/sapling_jubjub.py @@ -106,10 +106,6 @@ class Fq(FieldElement): class Fr(FieldElement): - @staticmethod - def from_bytes(buf): - return Fr(leos2ip(buf)) - def __init__(self, s): FieldElement.__init__(self, Fr, s, r_j) diff --git a/sapling_key_components.py b/sapling_key_components.py index bf9cfc5..659c5f0 100644 --- a/sapling_key_components.py +++ b/sapling_key_components.py @@ -6,7 +6,15 @@ from sapling_generators import PROVING_KEY_BASE, SPENDING_KEY_BASE, group_hash from sapling_jubjub import Fr from sapling_merkle_tree import MERKLE_DEPTH from sapling_notes import note_commit, note_nullifier -from sapling_utils import chunk, leos2bsp +from sapling_utils import chunk, leos2bsp, leos2ip + +# +# Utilities +# + +def to_scalar(buf): + return Fr(leos2ip(buf)) + # # PRFs and hashes @@ -23,8 +31,7 @@ def crh_ivk(ak, nk): digest.update(ak) digest.update(nk) ivk = digest.digest() - ivk = ivk[:31] + bytes([ivk[31] & 0b00000111]) - return ivk + return leos2ip(ivk) % 2**251 # @@ -46,11 +53,11 @@ class SpendingKey(object): @cached def ask(self): - return Fr.from_bytes(prf_expand(self.data, b'\0')) + return to_scalar(prf_expand(self.data, b'\0')) @cached def nsk(self): - return Fr.from_bytes(prf_expand(self.data, b'\1')) + return to_scalar(prf_expand(self.data, b'\1')) @cached def ovk(self): @@ -66,7 +73,7 @@ class SpendingKey(object): @cached def ivk(self): - return Fr.from_bytes(crh_ivk(bytes(self.ak()), bytes(self.nk()))) + return Fr(crh_ivk(bytes(self.ak()), bytes(self.nk()))) @cached def default_d(self): diff --git a/sapling_signatures.py b/sapling_signatures.py index 424be53..f249f4e 100644 --- a/sapling_signatures.py +++ b/sapling_signatures.py @@ -4,7 +4,8 @@ import os from pyblake2 import blake2b from sapling_generators import SPENDING_KEY_BASE -from sapling_jubjub import Fr, Point +from sapling_jubjub import Fr, Point, r_j +from sapling_key_components import to_scalar from sapling_utils import cldiv, chunk, leos2ip @@ -29,7 +30,7 @@ class RedJubjub(object): self._random = random def gen_private(self): - return self.Private.from_bytes(self._random(64)) + return to_scalar(self._random(64)) def derive_public(self, sk): return self.P_g * sk @@ -58,9 +59,9 @@ class RedJubjub(object): mid = cldiv(self.l_G, 8) (Rbar, Sbar) = (sig[:mid], sig[mid:]) # TODO: bitlength(r_j) R = Point.from_bytes(Rbar) - S = Fr.from_bytes(Sbar) + S = leos2ip(Sbar) c = h_star(Rbar + M) - return R and S.s == leos2ip(Sbar) and self.P_g * S == R + vk * c + return R and S < r_j and self.P_g * Fr(S) == R + vk * c def main():