diff --git a/zebra-chain/src/keys/sprout.rs b/zebra-chain/src/keys/sprout.rs index 3ff470dba..83ab9bf7b 100644 --- a/zebra-chain/src/keys/sprout.rs +++ b/zebra-chain/src/keys/sprout.rs @@ -22,12 +22,94 @@ use crate::{ Network, }; +/// Magic numbers used to identify what networks Sprout Shielded +/// Addresses are associated with. +mod sk_magics { + pub const MAINNET: [u8; 2] = [0xAB, 0x36]; + pub const TESTNET: [u8; 2] = [0xAC, 0x08]; +} + /// Our root secret key of the Sprout key derivation tree. /// /// All other Sprout key types derive from the SpendingKey value. /// Actually 252 bits. #[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct SpendingKey(pub [u8; 32]); +#[cfg_attr(test, derive(Arbitrary))] +pub struct SpendingKey { + /// + pub bytes: [u8; 32], + /// + pub network: Network, +} + +impl ZcashSerialize for SpendingKey { + fn zcash_serialize(&self, mut writer: W) -> Result<(), io::Error> { + match self.network { + Network::Mainnet => writer.write_all(&sk_magics::MAINNET[..])?, + _ => writer.write_all(&sk_magics::TESTNET[..])?, + } + writer.write_all(&self.bytes[..])?; + + Ok(()) + } +} + +impl ZcashDeserialize for SpendingKey { + fn zcash_deserialize(mut reader: R) -> Result { + let mut version_bytes = [0; 2]; + reader.read_exact(&mut version_bytes)?; + + let network = match version_bytes { + sk_magics::MAINNET => Network::Mainnet, + sk_magics::TESTNET => Network::Testnet, + _ => panic!(SerializationError::Parse( + "bad sprout shielded addr version/type", + )), + }; + + Ok(SpendingKey { + network, + bytes: reader.read_32_bytes()?, + }) + } +} + +impl fmt::Display for SpendingKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut bytes = io::Cursor::new(Vec::new()); + + let _ = self.zcash_serialize(&mut bytes); + + f.write_str(&bs58::encode(bytes.get_ref()).with_check().into_string()) + } +} + +impl std::str::FromStr for SpendingKey { + type Err = SerializationError; + + fn from_str(s: &str) -> Result { + let result = &bs58::decode(s).with_check(None).into_vec(); + + match result { + Ok(bytes) => Self::zcash_deserialize(&bytes[..]), + Err(_) => Err(SerializationError::Parse("bs58 decoding error")), + } + } +} + +impl From<[u8; 32]> for SpendingKey { + /// Generate a _SpendingKey_ from existing bytes, with the high 4 + /// bits of the first byte set to zero (ie, 256 bits clamped to + /// 252). + fn from(mut bytes: [u8; 32]) -> SpendingKey { + bytes[0] &= 0b0000_1111; // Force the 4 high-order bits to zero. + + SpendingKey { + bytes, + network: Network::default(), + } + } +} impl SpendingKey { /// Generate a new _SpendingKey_ with the high 4 bits of the first @@ -43,16 +125,6 @@ impl SpendingKey { } } -impl From<[u8; 32]> for SpendingKey { - /// Generate a _SpendingKey_ from existing bytes, with the high 4 - /// bits of the first byte set to zero (ie, 256 bits clamped to - /// 252). - fn from(mut bytes: [u8; 32]) -> SpendingKey { - bytes[0] &= 0b0000_1111; // Force the 4 high-order bits to zero. - SpendingKey(bytes) - } -} - /// Derived from a _SpendingKey_. pub type ReceivingKey = x25519_dalek::StaticSecret; @@ -67,7 +139,7 @@ impl From for ReceivingKey { let mut state = [0u32; 8]; let mut block = [0u8; 64]; // Thus, t = 0 - block[0..32].copy_from_slice(&spending_key.0[..]); + block[0..32].copy_from_slice(&spending_key.bytes[..]); block[0] |= 0b1100_0000; sha2::compress256(&mut state, &block); @@ -100,7 +172,7 @@ impl From for PayingKey { let mut state = [0u32; 8]; let mut block = [0u8; 64]; - block[0..32].copy_from_slice(&spending_key.0[..]); + block[0..32].copy_from_slice(&spending_key.bytes[..]); block[0] |= 0b1100_0000; block[32] = 1u8; // t = 1 @@ -117,8 +189,8 @@ impl From for PayingKey { /// Derived from a _ReceivingKey_. pub type TransmissionKey = x25519_dalek::PublicKey; -/// Magic numbers used to identify what networks Sprout Shielded -/// Addresses are associated with. +/// Magic numbers used to identify with what networks Sprout Incoming +/// Viewing Keys are associated. mod ivk_magics { pub const MAINNET: [u8; 3] = [0xA8, 0xAB, 0xD3]; pub const TESTNET: [u8; 3] = [0xA8, 0xAC, 0x0C]; @@ -177,7 +249,7 @@ impl ZcashDeserialize for IncomingViewingKey { ivk_magics::MAINNET => Network::Mainnet, ivk_magics::TESTNET => Network::Testnet, _ => panic!(SerializationError::Parse( - "bad sprout shielded addr version/type", + "bad sprout incoming viewing key network", )), }; @@ -251,13 +323,37 @@ mod tests { let receiving_key = ReceivingKey::from(spending_key); - let transmission_key = TransmissionKey::from(&receiving_key); + let _transmission_key = TransmissionKey::from(&receiving_key); } } #[cfg(test)] proptest! { + #[test] + fn spending_key_roundtrip(sk in any::()) { + + let mut data = Vec::new(); + + sk.zcash_serialize(&mut data).expect("sprout spending keyshould serialize"); + + let sk2 = SpendingKey::zcash_deserialize(&data[..]).expect("randomized sprout spending key should deserialize"); + + prop_assert_eq![sk, sk2]; + + } + + #[test] + fn spending_key_string_roundtrip(sk in any::()) { + + let string = sk.to_string(); + + let sk2 = string.parse::().unwrap(); + + prop_assert_eq![sk, sk2]; + + } + #[test] fn incoming_viewing_key_roundtrip(ivk in any::()) {