From 9daa1ba3c8c1bdfeeb5a12e1f40e3d822eaa9ebb Mon Sep 17 00:00:00 2001 From: Deirdre Connolly Date: Fri, 17 Apr 2020 02:38:44 -0400 Subject: [PATCH] Impl PartialEq for some Sapling keys --- zebra-chain/src/keys/sapling.rs | 95 ++++++++++++++++++++------- zebra-chain/src/keys/sapling/tests.rs | 16 ++--- 2 files changed, 76 insertions(+), 35 deletions(-) diff --git a/zebra-chain/src/keys/sapling.rs b/zebra-chain/src/keys/sapling.rs index 135d71d3e..6fe12719b 100644 --- a/zebra-chain/src/keys/sapling.rs +++ b/zebra-chain/src/keys/sapling.rs @@ -18,6 +18,7 @@ use std::{ fmt, io::{self, Write}, ops::Deref, + str::FromStr, }; use bech32::{self, FromBase32, ToBase32}; @@ -209,7 +210,7 @@ impl fmt::Display for SpendingKey { } } -impl std::str::FromStr for SpendingKey { +impl FromStr for SpendingKey { type Err = SerializationError; fn from_str(s: &str) -> Result { @@ -365,6 +366,18 @@ impl From for OutgoingViewingKey { } } +impl PartialEq<[u8; 32]> for OutgoingViewingKey { + fn eq(&self, other: &[u8; 32]) -> bool { + self.0 == *other + } +} + +impl PartialEq for [u8; 32] { + fn eq(&self, other: &OutgoingViewingKey) -> bool { + *self == other.0 + } +} + /// An _Authorizing Key_, as described in [protocol specification /// §4.2.2][ps]. /// @@ -403,6 +416,18 @@ impl From for AuthorizingKey { } } +impl PartialEq<[u8; 32]> for AuthorizingKey { + fn eq(&self, other: &[u8; 32]) -> bool { + Into::<[u8; 32]>::into(self.0) == *other + } +} + +impl PartialEq for [u8; 32] { + fn eq(&self, other: &AuthorizingKey) -> bool { + *self == Into::<[u8; 32]>::into(other.0) + } +} + /// A _Nullifier Deriving Key_, as described in [protocol /// specification §4.2.2][ps]. /// @@ -482,8 +507,6 @@ pub struct IncomingViewingKey { scalar: Scalar, } -// TODO: impl a top-level to_bytes or PartialEq between this and [u8; 32] - // TODO: impl a From that accepts a Network? impl Deref for IncomingViewingKey { @@ -494,6 +517,25 @@ impl Deref for IncomingViewingKey { } } +impl fmt::Debug for IncomingViewingKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("IncomingViewingKey") + .field(&hex::encode(self.to_bytes())) + .finish() + } +} + +impl fmt::Display for IncomingViewingKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let hrp = match self.network { + Network::Mainnet => ivk_hrp::MAINNET, + _ => ivk_hrp::TESTNET, + }; + + bech32::encode_to_fmt(f, hrp, &self.scalar.to_bytes().to_base32()).unwrap() + } +} + impl From<[u8; 32]> for IncomingViewingKey { /// Generate an _IncomingViewingKey_ from existing bytes. fn from(mut bytes: [u8; 32]) -> Self { @@ -539,26 +581,7 @@ impl From for [u8; 32] { } } -impl fmt::Debug for IncomingViewingKey { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_tuple("IncomingViewingKey") - .field(&hex::encode(self.to_bytes())) - .finish() - } -} - -impl fmt::Display for IncomingViewingKey { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let hrp = match self.network { - Network::Mainnet => ivk_hrp::MAINNET, - _ => ivk_hrp::TESTNET, - }; - - bech32::encode_to_fmt(f, hrp, &self.scalar.to_bytes().to_base32()).unwrap() - } -} - -impl std::str::FromStr for IncomingViewingKey { +impl FromStr for IncomingViewingKey { type Err = SerializationError; fn from_str(s: &str) -> Result { @@ -582,6 +605,18 @@ impl std::str::FromStr for IncomingViewingKey { } } +impl PartialEq<[u8; 32]> for IncomingViewingKey { + fn eq(&self, other: &[u8; 32]) -> bool { + self.scalar.to_bytes() == *other + } +} + +impl PartialEq for [u8; 32] { + fn eq(&self, other: &IncomingViewingKey) -> bool { + *self == other.scalar.to_bytes() + } +} + /// A _Diversifier_, as described in [protocol specification §4.2.2][ps]. /// /// Combined with an _IncomingViewingKey_, produces a _diversified @@ -629,6 +664,18 @@ impl From for Diversifier { } } +impl PartialEq<[u8; 11]> for Diversifier { + fn eq(&self, other: &[u8; 11]) -> bool { + self.0 == *other + } +} + +impl PartialEq for [u8; 11] { + fn eq(&self, other: &Diversifier) -> bool { + *self == other.0 + } +} + impl Diversifier { /// Generate a new _Diversifier_ that has already been confirmed /// as a preimage to a valid diversified base point when used to @@ -764,7 +811,7 @@ impl fmt::Display for FullViewingKey { } } -impl std::str::FromStr for FullViewingKey { +impl FromStr for FullViewingKey { type Err = SerializationError; fn from_str(s: &str) -> Result { diff --git a/zebra-chain/src/keys/sapling/tests.rs b/zebra-chain/src/keys/sapling/tests.rs index a69333f9d..eb00964a8 100644 --- a/zebra-chain/src/keys/sapling/tests.rs +++ b/zebra-chain/src/keys/sapling/tests.rs @@ -66,24 +66,18 @@ mod tests { let proof_authorizing_key = ProofAuthorizingKey::from(spending_key); assert_eq!(proof_authorizing_key.to_bytes(), test_vector.nsk); let outgoing_viewing_key = OutgoingViewingKey::from(spending_key); - assert_eq!( - Into::<[u8; 32]>::into(outgoing_viewing_key), - test_vector.ovk - ); + assert_eq!(outgoing_viewing_key, test_vector.ovk); let authorizing_key = AuthorizingKey::from(spend_authorizing_key); - assert_eq!(Into::<[u8; 32]>::into(authorizing_key), test_vector.ak); + assert_eq!(authorizing_key, test_vector.ak); let nullifier_deriving_key = NullifierDerivingKey::from(proof_authorizing_key); - assert_eq!( - Into::<[u8; 32]>::into(nullifier_deriving_key), - test_vector.nk - ); + assert_eq!(nullifier_deriving_key.to_bytes(), test_vector.nk); let incoming_viewing_key = IncomingViewingKey::from((authorizing_key, nullifier_deriving_key)); - assert_eq!(incoming_viewing_key.scalar.to_bytes(), test_vector.ivk); + assert_eq!(incoming_viewing_key, test_vector.ivk); let diversifier = Diversifier::from(spending_key); - assert_eq!(diversifier.0, test_vector.default_d); + assert_eq!(diversifier, test_vector.default_d); let transmission_key = TransmissionKey::from(incoming_viewing_key, diversifier); assert_eq!(transmission_key.to_bytes(), test_vector.default_pk_d);