Refactor the Orchard FullViewingKey constructor by adding FullViewingKey.from_spending_key.

Signed-off-by: Daira Hopwood <daira@jacaranda.org>
This commit is contained in:
Daira Hopwood 2022-02-01 15:54:50 +00:00
parent 6c2bac7b73
commit 2234fa1242
5 changed files with 12 additions and 12 deletions

View File

@ -42,7 +42,7 @@ def main():
has_o_addr = (not has_s_addr) or rand.bool() has_o_addr = (not has_s_addr) or rand.bool()
if has_o_addr: if has_o_addr:
orchard_sk = orchard_key_components.SpendingKey(rand.b(32)) orchard_sk = orchard_key_components.SpendingKey(rand.b(32))
orchard_fvk = orchard_key_components.FullViewingKey(orchard_sk) orchard_fvk = orchard_key_components.FullViewingKey.from_spending_key(orchard_sk)
orchard_default_d = orchard_fvk.default_d() orchard_default_d = orchard_fvk.default_d()
orchard_default_pk_d = orchard_fvk.default_pkd() orchard_default_pk_d = orchard_fvk.default_pkd()
orchard_raw_addr = b"".join([orchard_default_d[:11], bytes(orchard_default_pk_d)[:32]]) orchard_raw_addr = b"".join([orchard_default_d[:11], bytes(orchard_default_pk_d)[:32]])

View File

@ -58,7 +58,7 @@ def main():
has_o_key = (not has_s_key) or rand.bool() has_o_key = (not has_s_key) or rand.bool()
if has_o_key: if has_o_key:
orchard_sk = orchard_key_components.SpendingKey(rand.b(32)) orchard_sk = orchard_key_components.SpendingKey(rand.b(32))
orchard_fvk = orchard_key_components.FullViewingKey(orchard_sk) orchard_fvk = orchard_key_components.FullViewingKey.from_spending_key(orchard_sk)
orchard_fvk_bytes = b"".join([ orchard_fvk_bytes = b"".join([
bytes(orchard_fvk.ak), bytes(orchard_fvk.ak),
bytes(orchard_fvk.nk), bytes(orchard_fvk.nk),

View File

@ -53,17 +53,17 @@ class SpendingKey:
class FullViewingKey(object): class FullViewingKey(object):
def __init__(self, sk): def __init__(self, rivk, ak, nk):
if isinstance(sk, SpendingKey): (self.rivk, self.ak, self.nk) = (rivk, ak, nk)
(self.rivk, self.ak, self.nk) = (sk.rivk, sk.ak, sk.nk)
else:
(self.rivk, self.ak, self.nk) = sk
K = i2leosp(256, self.rivk.s) K = i2leosp(256, self.rivk.s)
R = prf_expand(K, b'\x82' + i2leosp(256, self.ak.s) + i2leosp(256, self.nk.s)) R = prf_expand(K, b'\x82' + i2leosp(256, self.ak.s) + i2leosp(256, self.nk.s))
self.dk = R[:32] self.dk = R[:32]
self.ovk = R[32:] self.ovk = R[32:]
@classmethod
def from_spending_key(cls, sk):
return cls(sk.rivk, sk.ak, sk.nk)
def ivk(self): def ivk(self):
return commit_ivk(self.rivk, self.ak, self.nk) return commit_ivk(self.rivk, self.ak, self.nk)
@ -80,7 +80,7 @@ class FullViewingKey(object):
def internal(self): def internal(self):
K = i2leosp(256, self.rivk.s) K = i2leosp(256, self.rivk.s)
rivk_internal = to_scalar(prf_expand(K, b'\x83' + i2leosp(256, self.ak.s) + i2leosp(256, self.nk.s))) rivk_internal = to_scalar(prf_expand(K, b'\x83' + i2leosp(256, self.ak.s) + i2leosp(256, self.nk.s)))
return self.__class__((rivk_internal, self.ak, self.nk)) return self.__class__(rivk_internal, self.ak, self.nk)
def main(): def main():
@ -101,7 +101,7 @@ def main():
test_vectors = [] test_vectors = []
for _ in range(0, 10): for _ in range(0, 10):
sk = SpendingKey(rand.b(32)) sk = SpendingKey(rand.b(32))
fvk = FullViewingKey(sk) fvk = FullViewingKey.from_spending_key(sk)
default_d = fvk.default_d() default_d = fvk.default_d()
default_pk_d = fvk.default_pkd() default_pk_d = fvk.default_pkd()

View File

@ -63,7 +63,7 @@ class OrchardNotePlaintext(object):
def dummy_nullifier(self, rand): def dummy_nullifier(self, rand):
sk = SpendingKey(rand.b(32)) sk = SpendingKey(rand.b(32))
fvk = FullViewingKey(sk) fvk = FullViewingKey.from_spending_key(sk)
pk_d = fvk.default_pkd() pk_d = fvk.default_pkd()
d = fvk.default_d() d = fvk.default_d()

View File

@ -214,7 +214,7 @@ def main():
sender_ovk = rand.b(32) sender_ovk = rand.b(32)
receiver_sk = SpendingKey(rand.b(32)) receiver_sk = SpendingKey(rand.b(32))
receiver_fvk = FullViewingKey(receiver_sk) receiver_fvk = FullViewingKey.from_spending_key(receiver_sk)
ivk = receiver_fvk.ivk() ivk = receiver_fvk.ivk()
d = receiver_fvk.default_d() d = receiver_fvk.default_d()
pk_d = receiver_fvk.default_pkd() pk_d = receiver_fvk.default_pkd()