Merge pull request #1 from str4d/to_scalar

Implement ToScalar from spec, and small refactor to match spec more closely
This commit is contained in:
str4d 2018-06-05 09:05:46 +12:00 committed by GitHub
commit 41d250ed0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 14 deletions

View File

@ -106,10 +106,6 @@ class Fq(FieldElement):
class Fr(FieldElement):
@staticmethod
def from_bytes(buf):
return Fr(leos2ip(buf))
def __init__(self, s):
FieldElement.__init__(self, Fr, s, r_j)

View File

@ -6,7 +6,15 @@ from sapling_generators import PROVING_KEY_BASE, SPENDING_KEY_BASE, group_hash
from sapling_jubjub import Fr
from sapling_merkle_tree import MERKLE_DEPTH
from sapling_notes import note_commit, note_nullifier
from sapling_utils import chunk, leos2bsp
from sapling_utils import chunk, leos2bsp, leos2ip
#
# Utilities
#
def to_scalar(buf):
return Fr(leos2ip(buf))
#
# PRFs and hashes
@ -23,8 +31,7 @@ def crh_ivk(ak, nk):
digest.update(ak)
digest.update(nk)
ivk = digest.digest()
ivk = ivk[:31] + bytes([ivk[31] & 0b00000111])
return ivk
return leos2ip(ivk) % 2**251
#
@ -46,11 +53,11 @@ class SpendingKey(object):
@cached
def ask(self):
return Fr.from_bytes(prf_expand(self.data, b'\0'))
return to_scalar(prf_expand(self.data, b'\0'))
@cached
def nsk(self):
return Fr.from_bytes(prf_expand(self.data, b'\1'))
return to_scalar(prf_expand(self.data, b'\1'))
@cached
def ovk(self):
@ -66,7 +73,7 @@ class SpendingKey(object):
@cached
def ivk(self):
return Fr.from_bytes(crh_ivk(bytes(self.ak()), bytes(self.nk())))
return Fr(crh_ivk(bytes(self.ak()), bytes(self.nk())))
@cached
def default_d(self):

View File

@ -4,7 +4,8 @@ import os
from pyblake2 import blake2b
from sapling_generators import SPENDING_KEY_BASE
from sapling_jubjub import Fr, Point
from sapling_jubjub import Fr, Point, r_j
from sapling_key_components import to_scalar
from sapling_utils import cldiv, chunk, leos2ip
@ -29,7 +30,7 @@ class RedJubjub(object):
self._random = random
def gen_private(self):
return self.Private.from_bytes(self._random(64))
return to_scalar(self._random(64))
def derive_public(self, sk):
return self.P_g * sk
@ -58,9 +59,9 @@ class RedJubjub(object):
mid = cldiv(self.l_G, 8)
(Rbar, Sbar) = (sig[:mid], sig[mid:]) # TODO: bitlength(r_j)
R = Point.from_bytes(Rbar)
S = Fr.from_bytes(Sbar)
S = leos2ip(Sbar)
c = h_star(Rbar + M)
return R and S.s == leos2ip(Sbar) and self.P_g * S == R + vk * c
return R and S < r_j and self.P_g * Fr(S) == R + vk * c
def main():