Address review comments.

Co-authored-by: Jack Grigg <jack@electriccoin.co>
This commit is contained in:
therealyingtong 2021-05-26 14:08:08 +08:00
parent 406747099a
commit c182edabd4
2 changed files with 53 additions and 54 deletions

View File

@ -8,7 +8,7 @@ from orchard_utils import to_base, to_scalar
from utils import leos2bsp from utils import leos2bsp
class OrchardNote(object): class OrchardNote(object):
def __init__(self, d, pk_d, v: Scalar, rho, rseed): def __init__(self, d, pk_d, v, rho, rseed):
self.d = d self.d = d
self.pk_d = pk_d self.pk_d = pk_d
self.v = v self.v = v
@ -17,25 +17,27 @@ class OrchardNote(object):
self.rcm = self.rcm(rho) self.rcm = self.rcm(rho)
self.psi = self.psi(rho) self.psi = self.psi(rho)
def __bytes__(self): def __eq__(self, other):
if other is None:
return False
return ( return (
self.d + self.d == other.d and
bytes(self.pk_d) + self.pk_d == other.pk_d and
struct.pack('<Q', self.v.s) + self.v == other.v and
bytes(self.rho) + self.rho == other.rho and
bytes(self.rcm) + self.rcm == other.rcm and
bytes(self.psi) self.psi == other.psi
) )
def rcm(self, rho): def rcm(self, rho):
return to_scalar(prf_expand(bytes(self.rseed), b'\x05' + bytes(rho))) return to_scalar(prf_expand(self.rseed, b'\x05' + bytes(rho)))
def psi(self, rho): def psi(self, rho):
return to_base(prf_expand(bytes(self.rseed), b'\x09' + bytes(rho))) return to_base(prf_expand(self.rseed, b'\x09' + bytes(rho)))
def note_commitment(self): def note_commitment(self):
g_d = diversify_hash(self.d) g_d = diversify_hash(self.d)
return note_commit(self.rcm, leos2bsp(bytes(g_d)), leos2bsp(bytes(self.pk_d)), self.v.s, self.rho, self.psi) return note_commit(self.rcm, leos2bsp(bytes(g_d)), leos2bsp(bytes(self.pk_d)), self.v, self.rho, self.psi)
def note_plaintext(self, memo): def note_plaintext(self, memo):
return OrchardNotePlaintext(self.d, self.v, self.rseed, memo) return OrchardNotePlaintext(self.d, self.v, self.rseed, memo)
@ -53,7 +55,7 @@ class OrchardNotePlaintext(object):
return ( return (
self.leadbyte + self.leadbyte +
self.d + self.d +
struct.pack('<Q', self.v.s) + struct.pack('<Q', self.v) +
bytes(self.rseed) + bytes(self.rseed) +
self.memo self.memo
) )
@ -62,14 +64,13 @@ class OrchardNotePlaintext(object):
sk = SpendingKey(rand.b(32)) sk = SpendingKey(rand.b(32))
fvk = FullViewingKey(sk) fvk = FullViewingKey(sk)
pk_d = fvk.default_pkd() pk_d = fvk.default_pkd()
g_d = diversify_hash(fvk.default_d()) d = fvk.default_d()
v = Scalar.ZERO v = 0
rho = Point.rand(rand) rseed = rand.b(32)
rho_bytes = bytes(rho) rho = Point.rand(rand).extract()
rho = rho.extract()
note = OrchardNote(fvk.default_d(), pk_d, v, rho, rho_bytes) note = OrchardNote(d, pk_d, v, rho, rseed)
cm = note_commit(note.rcm, leos2bsp(bytes(g_d)), leos2bsp(bytes(pk_d)), v.s, rho, note.psi) cm = note.note_commitment()
return derive_nullifier(fvk.nk, rho, note.psi, cm) return derive_nullifier(fvk.nk, rho, note.psi, cm)

View File

@ -10,7 +10,7 @@ from tv_rand import Rand
from orchard_generators import VALUE_COMMITMENT_VALUE_BASE, VALUE_COMMITMENT_RANDOMNESS_BASE from orchard_generators import VALUE_COMMITMENT_VALUE_BASE, VALUE_COMMITMENT_RANDOMNESS_BASE
from orchard_pallas import Point, Scalar from orchard_pallas import Point, Scalar
from orchard_commitments import rcv_trapdoor from orchard_commitments import rcv_trapdoor, value_commit
from orchard_key_components import diversify_hash, prf_expand, FullViewingKey, SpendingKey from orchard_key_components import diversify_hash, prf_expand, FullViewingKey, SpendingKey
from orchard_note import OrchardNote, OrchardNotePlaintext from orchard_note import OrchardNote, OrchardNotePlaintext
from orchard_utils import to_scalar from orchard_utils import to_scalar
@ -20,7 +20,7 @@ from utils import leos2bsp
def kdf_orchard(shared_secret, ephemeral_key): def kdf_orchard(shared_secret, ephemeral_key):
digest = blake2b(digest_size=32, person=b'Zcash_OrchardKDF') digest = blake2b(digest_size=32, person=b'Zcash_OrchardKDF')
digest.update(bytes(shared_secret)) digest.update(bytes(shared_secret))
digest.update(bytes(ephemeral_key)) digest.update(ephemeral_key)
return digest.digest() return digest.digest()
# https://zips.z.cash/protocol/nu5.pdf#concreteprfs # https://zips.z.cash/protocol/nu5.pdf#concreteprfs
@ -36,7 +36,7 @@ def prf_ock_orchard(ovk, cv, cmx, ephemeral_key):
class OrchardKeyAgreement(object): class OrchardKeyAgreement(object):
@staticmethod @staticmethod
def esk(rseed, rho): def esk(rseed, rho):
return to_scalar(prf_expand(bytes(rseed), b'\x04' + bytes(rho))) return to_scalar(prf_expand(rseed, b'\x04' + bytes(rho)))
@staticmethod @staticmethod
def derive_public(esk, g_d): def derive_public(esk, g_d):
@ -49,8 +49,8 @@ class OrchardKeyAgreement(object):
# https://zips.z.cash/protocol/nu5.pdf#concretesym # https://zips.z.cash/protocol/nu5.pdf#concretesym
class OrchardSym(object): class OrchardSym(object):
@staticmethod @staticmethod
def k(random): def k(rand):
return random(32) return rand.b(32)
@staticmethod @staticmethod
def encrypt(key, plaintext): def encrypt(key, plaintext):
@ -64,11 +64,8 @@ class OrchardSym(object):
# https://zips.z.cash/protocol/nu5.pdf#saplingandorchardencrypt # https://zips.z.cash/protocol/nu5.pdf#saplingandorchardencrypt
class OrchardNoteEncryption(object): class OrchardNoteEncryption(object):
def __init__(self, random=os.urandom): def __init__(self, rand):
self._random = random self._rand = rand
def rseed(self):
return self._random.b(32)
def encrypt(self, note: OrchardNote, memo, pk_d_new, g_d_new, cv_new, cm_new, ovk=None): def encrypt(self, note: OrchardNote, memo, pk_d_new, g_d_new, cv_new, cm_new, ovk=None):
np = note.note_plaintext(memo) np = note.note_plaintext(memo)
@ -78,12 +75,12 @@ class OrchardNoteEncryption(object):
epk = OrchardKeyAgreement.derive_public(esk, g_d_new) epk = OrchardKeyAgreement.derive_public(esk, g_d_new)
ephemeral_key = bytes(epk) ephemeral_key = bytes(epk)
shared_secret = OrchardKeyAgreement.agree(esk, pk_d_new) shared_secret = OrchardKeyAgreement.agree(esk, pk_d_new)
k_enc = kdf_orchard(shared_secret, epk) k_enc = kdf_orchard(shared_secret, ephemeral_key)
c_enc = OrchardSym.encrypt(k_enc, p_enc) c_enc = OrchardSym.encrypt(k_enc, p_enc)
if not ovk: if not ovk:
ock = OrchardSym.k(self._random) ock = OrchardSym.k(self._rand)
op = self._random.b(64) op = self._rand.b(64)
else: else:
cv = bytes(cv_new) cv = bytes(cv_new)
cmx = bytes(cm_new.extract()) cmx = bytes(cm_new.extract())
@ -115,7 +112,8 @@ class TransmittedNoteCipherText(object):
return None return None
shared_secret = OrchardKeyAgreement.agree(ivk, epk) shared_secret = OrchardKeyAgreement.agree(ivk, epk)
k_enc = kdf_orchard(shared_secret, epk) ephemeral_key = bytes(epk)
k_enc = kdf_orchard(shared_secret, ephemeral_key)
p_enc = OrchardSym.decrypt(k_enc, self.c_enc) p_enc = OrchardSym.decrypt(k_enc, self.c_enc)
if not p_enc: if not p_enc:
return None return None
@ -131,7 +129,7 @@ class TransmittedNoteCipherText(object):
g_d = diversify_hash(np.d) g_d = diversify_hash(np.d)
pk_d = OrchardKeyAgreement.derive_public(ivk, g_d) pk_d = OrchardKeyAgreement.derive_public(ivk, g_d)
note = OrchardNote(np.d, pk_d, np.v, rho, np.rseed) note = OrchardNote(np.d, pk_d, np.v.s, rho, np.rseed)
esk = OrchardKeyAgreement.esk(np.rseed, rho) esk = OrchardKeyAgreement.esk(np.rseed, rho)
if OrchardKeyAgreement.derive_public(esk, g_d) != epk: if OrchardKeyAgreement.derive_public(esk, g_d) != epk:
@ -145,8 +143,8 @@ class TransmittedNoteCipherText(object):
return (note, np.memo) return (note, np.memo)
def decrypt_using_fvk(self, fvk, rseed, rho, cv, cm_star): def decrypt_using_ovk(self, ovk, rseed, rho, cv, cm_star):
ock = prf_ock_orchard(fvk.ovk, bytes(cv), bytes(cm_star.extract()), bytes(self.epk)) ock = prf_ock_orchard(ovk, bytes(cv), bytes(cm_star.extract()), bytes(self.epk))
op = OrchardSym.decrypt(ock, self.c_out) op = OrchardSym.decrypt(ock, self.c_out)
if not op: if not op:
return None return None
@ -160,7 +158,8 @@ class TransmittedNoteCipherText(object):
return None return None
shared_secret = OrchardKeyAgreement.agree(esk, pk_d) shared_secret = OrchardKeyAgreement.agree(esk, pk_d)
k_enc = kdf_orchard(shared_secret, self.epk) ephemeral_key = bytes(self.epk)
k_enc = kdf_orchard(shared_secret, ephemeral_key)
p_enc = OrchardSym.decrypt(k_enc, self.c_enc) p_enc = OrchardSym.decrypt(k_enc, self.c_enc)
if not p_enc: if not p_enc:
return None return None
@ -174,7 +173,7 @@ class TransmittedNoteCipherText(object):
p_enc[52:564], # memo p_enc[52:564], # memo
) )
g_d = diversify_hash(np.d) g_d = diversify_hash(np.d)
note = OrchardNote(np.d, pk_d, np.v, rho, np.rseed) note = OrchardNote(np.d, pk_d, np.v.s, rho, np.rseed)
cm = note.note_commitment() cm = note.note_commitment()
if not cm: if not cm:
@ -199,12 +198,9 @@ def main():
return bytes(ret) return bytes(ret)
rand = Rand(randbytes) rand = Rand(randbytes)
ne = OrchardNoteEncryption(rand)
test_vectors = [] test_vectors = []
for _ in range(0, 10): for _ in range(0, 10):
sender_sk = SpendingKey(rand.b(32)) sender_ovk = rand.b(32)
sender_fvk = FullViewingKey(sender_sk)
receiver_sk = SpendingKey(rand.b(32)) receiver_sk = SpendingKey(rand.b(32))
receiver_fvk = FullViewingKey(receiver_sk) receiver_fvk = FullViewingKey(receiver_sk)
@ -213,42 +209,44 @@ def main():
pk_d = receiver_fvk.default_pkd() pk_d = receiver_fvk.default_pkd()
g_d = diversify_hash(d) g_d = diversify_hash(d)
rseed = ne.rseed() rseed = rand.b(32)
memo = rand.b(512) memo = b'\xff' + rand.b(511)
np = OrchardNotePlaintext( np = OrchardNotePlaintext(
d, d,
Scalar(rand.u64() % (MAX_MONEY + 1)), rand.u64(),
rseed, rseed,
memo memo
) )
rcv = rcv_trapdoor(rand) rcv = rcv_trapdoor(rand)
cv = VALUE_COMMITMENT_VALUE_BASE * np.v + VALUE_COMMITMENT_RANDOMNESS_BASE * rcv cv = value_commit(rcv, Scalar(np.v))
rho = np.dummy_nullifier(rand) rho = np.dummy_nullifier(rand)
note = OrchardNote(d, pk_d, np.v, rho, rseed) note = OrchardNote(d, pk_d, np.v, rho, rseed)
cm = note.note_commitment() cm = note.note_commitment()
transmitted_note_ciphertext = ne.encrypt(note, memo, pk_d, g_d, cv, cm, sender_fvk.ovk) ne = OrchardNoteEncryption(rand)
transmitted_note_ciphertext = ne.encrypt(note, memo, pk_d, g_d, cv, cm, sender_ovk)
(note_using_ivk, memo_using_ivk) = transmitted_note_ciphertext.decrypt_using_ivk( (note_using_ivk, memo_using_ivk) = transmitted_note_ciphertext.decrypt_using_ivk(
Scalar(ivk.s), rho, cm Scalar(ivk.s), rho, cm
) )
(note_using_fvk, memo_using_fvk) = transmitted_note_ciphertext.decrypt_using_fvk( (note_using_ovk, memo_using_ovk) = transmitted_note_ciphertext.decrypt_using_ovk(
sender_fvk, rseed, rho, cv, cm sender_ovk, rseed, rho, cv, cm
) )
assert(bytes(note_using_ivk) == bytes(note_using_fvk)) assert(note_using_ivk == note_using_ovk)
assert(memo_using_ivk == memo_using_fvk) assert(memo_using_ivk == memo_using_ovk)
assert(bytes(note_using_ivk) == bytes(note)) assert(note_using_ivk == note)
assert(memo_using_ivk == memo) assert(memo_using_ivk == memo)
test_vectors.append({ test_vectors.append({
'ovk': sender_fvk.ovk, 'ovk': sender_ovk,
'ivk': bytes(ivk), 'ivk': bytes(ivk),
'default_d': d, 'default_d': d,
'default_pk_d': bytes(pk_d), 'default_pk_d': bytes(pk_d),
'v': np.v.s, 'v': np.v,
'rcm': bytes(note.rcm), 'rcm': bytes(note.rcm),
'memo': np.memo, 'memo': np.memo,
'cv': bytes(cv), 'cv': bytes(cv),