Merge pull request #10 from str4d/zip0243-spend-rk

Update ZIP 243 test vectors with valid rk values
This commit is contained in:
str4d 2019-02-22 22:34:21 +00:00 committed by GitHub
commit 38cdeda51c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 21 additions and 11 deletions

View File

@ -23,7 +23,7 @@ def group_hash(D, M):
digest.update(URS) digest.update(URS)
digest.update(M) digest.update(M)
p = Point.from_bytes(digest.digest()) p = Point.from_bytes(digest.digest())
if not p: if p is None:
return None return None
q = p * JUBJUB_COFACTOR q = p * JUBJUB_COFACTOR
if q == Point.ZERO: if q == Point.ZERO:
@ -34,7 +34,7 @@ def find_group_hash(D, M):
i = 0 i = 0
while True: while True:
p = group_hash(D, M + bytes([i])) p = group_hash(D, M + bytes([i]))
if p: if p is not None:
return p return p
i += 1 i += 1
assert i < 256 assert i < 256

View File

@ -142,6 +142,14 @@ JUBJUB_D = Fq(-10240) / Fq(10241)
JUBJUB_COFACTOR = Fr(8) JUBJUB_COFACTOR = Fr(8)
class Point(object): class Point(object):
@staticmethod
def rand(rand):
while True:
data = rand.b(32)
p = Point.from_bytes(data)
if p is not None:
return p
@staticmethod @staticmethod
def from_bytes(buf): def from_bytes(buf):
assert len(buf) == 32 assert len(buf) == 32
@ -156,7 +164,7 @@ class Point(object):
u2 = (vv - Fq.ONE) / (vv * JUBJUB_D - JUBJUB_A) u2 = (vv - Fq.ONE) / (vv * JUBJUB_D - JUBJUB_A)
u = u2.sqrt() u = u2.sqrt()
if not u: if u is None:
return None return None
if u.s % 2 != u_sign: if u.s % 2 != u_sign:

View File

@ -2,7 +2,7 @@
import struct import struct
from sapling_generators import find_group_hash, SPENDING_KEY_BASE from sapling_generators import find_group_hash, SPENDING_KEY_BASE
from sapling_jubjub import Fq from sapling_jubjub import Fq, Point
from sapling_utils import leos2ip from sapling_utils import leos2ip
from zc_utils import write_compact_size from zc_utils import write_compact_size
@ -80,7 +80,7 @@ class SpendDescription(object):
self.cv = find_group_hash(b'TVRandPt', rand.b(32)) self.cv = find_group_hash(b'TVRandPt', rand.b(32))
self.anchor = Fq(leos2ip(rand.b(32))) self.anchor = Fq(leos2ip(rand.b(32)))
self.nullifier = rand.b(32) self.nullifier = rand.b(32)
self.rk = rand.b(32) self.rk = Point.rand(rand)
self.proof = GrothProof(rand) self.proof = GrothProof(rand)
self.spendAuthSig = rand.b(64) # Invalid self.spendAuthSig = rand.b(64) # Invalid
@ -89,7 +89,7 @@ class SpendDescription(object):
bytes(self.cv) + bytes(self.cv) +
bytes(self.anchor) + bytes(self.anchor) +
self.nullifier + self.nullifier +
self.rk + bytes(self.rk) +
bytes(self.proof) + bytes(self.proof) +
self.spendAuthSig self.spendAuthSig
) )
@ -221,6 +221,7 @@ class Transaction(object):
self.nLockTime = rand.u32() self.nLockTime = rand.u32()
self.nExpiryHeight = rand.u32() % TX_EXPIRY_HEIGHT_THRESHOLD self.nExpiryHeight = rand.u32() % TX_EXPIRY_HEIGHT_THRESHOLD
if self.nVersion >= SAPLING_TX_VERSION:
self.valueBalance = rand.u64() % (MAX_MONEY + 1) self.valueBalance = rand.u64() % (MAX_MONEY + 1)
self.vShieldedSpends = [] self.vShieldedSpends = []
@ -239,6 +240,7 @@ class Transaction(object):
self.joinSplitPubKey = rand.b(32) # Potentially invalid self.joinSplitPubKey = rand.b(32) # Potentially invalid
self.joinSplitSig = rand.b(64) # Invalid self.joinSplitSig = rand.b(64) # Invalid
if self.nVersion >= SAPLING_TX_VERSION:
self.bindingSig = rand.b(64) # Invalid self.bindingSig = rand.b(64) # Invalid
def header(self): def header(self):

View File

@ -156,7 +156,7 @@ def main():
'rust_fmt': lambda x: None if x == -1 else Some(x), 'rust_fmt': lambda x: None if x == -1 else Some(x),
}), }),
('hash_type', 'u32'), ('hash_type', 'u32'),
('amount', 'u64'), ('amount', 'i64'),
('consensus_branch_id', 'u32'), ('consensus_branch_id', 'u32'),
('sighash', '[u8; 32]'), ('sighash', '[u8; 32]'),
), ),

View File

@ -31,7 +31,7 @@ def getHashShieldedSpends(tx):
digest.update(bytes(desc.cv)) digest.update(bytes(desc.cv))
digest.update(bytes(desc.anchor)) digest.update(bytes(desc.anchor))
digest.update(desc.nullifier) digest.update(desc.nullifier)
digest.update(desc.rk) digest.update(bytes(desc.rk))
digest.update(bytes(desc.proof)) digest.update(bytes(desc.proof))
return digest.digest() return digest.digest()
@ -163,7 +163,7 @@ def main():
'rust_fmt': lambda x: None if x == -1 else Some(x), 'rust_fmt': lambda x: None if x == -1 else Some(x),
}), }),
('hash_type', 'u32'), ('hash_type', 'u32'),
('amount', 'u64'), ('amount', 'i64'),
('consensus_branch_id', 'u32'), ('consensus_branch_id', 'u32'),
('sighash', '[u8; 32]'), ('sighash', '[u8; 32]'),
), ),