diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index abad10487..6a0fbb5e9 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; use zcash_primitives::{ block::BlockHash, consensus::BlockHeight, - memo::Memo, + memo::{Memo, MemoBytes}, merkle_tree::{CommitmentTree, IncrementalWitness}, primitives::{Nullifier, PaymentAddress}, sapling::Node, @@ -137,11 +137,11 @@ pub trait WalletRead { anchor_height: BlockHeight, ) -> Result; - /// Returns the memo for a note, if it is known and a valid UTF-8 string. + /// Returns the memo for a note. /// - /// This will return `Ok(None)` if the note identifier does not appear in the - /// database as a known note ID. - fn get_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result, Self::Error>; + /// Implementations of this method must return an error if the note identifier + /// does not appear in the backing data store. + fn get_memo(&self, id_note: Self::NoteRef) -> Result; /// Returns the note commitment tree at the specified block height. fn get_commitment_tree( @@ -199,7 +199,7 @@ pub struct SentTransaction<'a> { pub account: AccountId, pub recipient_address: &'a RecipientAddress, pub value: Amount, - pub memo: Option, + pub memo: Option, } /// This trait encapsulates the write capabilities required to update stored @@ -259,6 +259,7 @@ pub mod testing { use zcash_primitives::{ block::BlockHash, consensus::BlockHeight, + memo::Memo, merkle_tree::{CommitmentTree, IncrementalWitness}, primitives::{Nullifier, PaymentAddress}, sapling::Node, @@ -342,8 +343,8 @@ pub mod testing { Ok(Amount::zero()) } - fn get_memo_as_utf8(&self, _id_note: Self::NoteRef) -> Result, Self::Error> { - Ok(None) + fn get_memo(&self, _id_note: Self::NoteRef) -> Result { + Ok(Memo::Empty) } fn get_commitment_tree( diff --git a/zcash_client_backend/src/data_api/wallet.rs b/zcash_client_backend/src/data_api/wallet.rs index 49f8f4297..0c472e000 100644 --- a/zcash_client_backend/src/data_api/wallet.rs +++ b/zcash_client_backend/src/data_api/wallet.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use zcash_primitives::{ consensus::{self, BranchId, NetworkUpgrade}, - memo::Memo, + memo::MemoBytes, prover::TxProver, transaction::{ builder::Builder, @@ -155,7 +155,7 @@ pub fn create_spend_to_address( extsk: &ExtendedSpendingKey, to: &RecipientAddress, value: Amount, - memo: Option, + memo: Option, ovk_policy: OvkPolicy, ) -> Result where diff --git a/zcash_client_backend/src/decrypt.rs b/zcash_client_backend/src/decrypt.rs index 17793235a..4b82bea92 100644 --- a/zcash_client_backend/src/decrypt.rs +++ b/zcash_client_backend/src/decrypt.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use zcash_primitives::{ consensus::{self, BlockHeight}, - memo::Memo, + memo::MemoBytes, note_encryption::{try_sapling_note_decryption, try_sapling_output_recovery}, primitives::{Note, PaymentAddress}, transaction::Transaction, @@ -23,8 +23,8 @@ pub struct DecryptedOutput { pub account: AccountId, /// The address the note was sent to. pub to: PaymentAddress, - /// The memo included with the note. - pub memo: Memo, + /// The memo bytes included with the note. + pub memo: MemoBytes, /// True if this output was recovered using an [`OutgoingViewingKey`], meaning that /// this is a logical output of the transaction. /// diff --git a/zcash_client_backend/src/welding_rig.rs b/zcash_client_backend/src/welding_rig.rs index ce18914d4..4dfb4be8e 100644 --- a/zcash_client_backend/src/welding_rig.rs +++ b/zcash_client_backend/src/welding_rig.rs @@ -275,7 +275,7 @@ mod tests { use zcash_primitives::{ consensus::{BlockHeight, Network}, constants::SPENDING_KEY_GENERATOR, - memo::Memo, + memo::MemoBytes, merkle_tree::CommitmentTree, note_encryption::SaplingNoteEncryption, primitives::{Note, Nullifier, SaplingIvk}, @@ -345,7 +345,7 @@ mod tests { Some(extfvk.fvk.ovk), note.clone(), to, - Memo::default(), + MemoBytes::default(), &mut rng, ); let cmu = note.cmu().to_repr().as_ref().to_owned(); diff --git a/zcash_client_sqlite/src/error.rs b/zcash_client_sqlite/src/error.rs index 0ae1fb5b2..0cc1754f8 100644 --- a/zcash_client_sqlite/src/error.rs +++ b/zcash_client_sqlite/src/error.rs @@ -36,7 +36,7 @@ pub enum SqliteClientError { Io(std::io::Error), /// A received memo cannot be interpreted as a UTF-8 string. - InvalidMemo(std::str::Utf8Error), + InvalidMemo(zcash_primitives::memo::Error), /// Wrapper for errors from zcash_client_backend BackendError(data_api::error::Error), @@ -98,6 +98,12 @@ impl From for SqliteClientError { } } +impl From for SqliteClientError { + fn from(e: zcash_primitives::memo::Error) -> Self { + SqliteClientError::InvalidMemo(e) + } +} + impl From> for SqliteClientError { fn from(e: data_api::error::Error) -> Self { SqliteClientError::BackendError(e) diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 9fc3d1849..5d411a22b 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -33,6 +33,7 @@ use rusqlite::{Connection, Statement, NO_PARAMS}; use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight}, + memo::Memo, merkle_tree::{CommitmentTree, IncrementalWitness}, primitives::{Nullifier, PaymentAddress}, sapling::Node, @@ -205,10 +206,10 @@ impl WalletRead for WalletDB

{ wallet::get_balance_at(self, account, anchor_height) } - fn get_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result, Self::Error> { + fn get_memo(&self, id_note: Self::NoteRef) -> Result { match id_note { - NoteId::SentNoteId(id_note) => wallet::get_sent_memo_as_utf8(self, id_note), - NoteId::ReceivedNoteId(id_note) => wallet::get_received_memo_as_utf8(self, id_note), + NoteId::SentNoteId(id_note) => wallet::get_sent_memo(self, id_note), + NoteId::ReceivedNoteId(id_note) => wallet::get_received_memo(self, id_note), } } @@ -317,8 +318,8 @@ impl<'a, P: consensus::Parameters> WalletRead for DataConnStmtCache<'a, P> { self.wallet_db.get_balance_at(account, anchor_height) } - fn get_memo_as_utf8(&self, id_note: Self::NoteRef) -> Result, Self::Error> { - self.wallet_db.get_memo_as_utf8(id_note) + fn get_memo(&self, id_note: Self::NoteRef) -> Result { + self.wallet_db.get_memo(id_note) } fn get_commitment_tree( @@ -545,7 +546,7 @@ mod tests { use zcash_primitives::{ block::BlockHash, consensus::{BlockHeight, Network, NetworkUpgrade, Parameters}, - memo::Memo, + memo::MemoBytes, note_encryption::SaplingNoteEncryption, primitives::{Note, Nullifier, PaymentAddress}, transaction::components::Amount, @@ -602,7 +603,7 @@ mod tests { Some(extfvk.fvk.ovk), note.clone(), to, - Memo::default(), + MemoBytes::default(), &mut rng, ); let cmu = note.cmu().to_repr().as_ref().to_vec(); @@ -662,7 +663,7 @@ mod tests { Some(extfvk.fvk.ovk), note.clone(), to, - Memo::default(), + MemoBytes::default(), &mut rng, ); let cmu = note.cmu().to_repr().as_ref().to_vec(); @@ -690,7 +691,7 @@ mod tests { Some(extfvk.fvk.ovk), note.clone(), change_addr, - Memo::default(), + MemoBytes::default(), &mut rng, ); let cmu = note.cmu().to_repr().as_ref().to_vec(); diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index 755de07d8..ff6b0c283 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -3,11 +3,12 @@ use ff::PrimeField; use rusqlite::{params, OptionalExtension, ToSql, NO_PARAMS}; use std::collections::HashMap; +use std::convert::TryFrom; use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight, NetworkUpgrade}, - memo::Memo, + memo::{Memo, MemoBytes}, merkle_tree::{CommitmentTree, IncrementalWitness}, primitives::{Note, Nullifier, PaymentAddress}, sapling::Node, @@ -41,7 +42,7 @@ pub trait ShieldedOutput { fn account(&self) -> AccountId; fn to(&self) -> &PaymentAddress; fn note(&self) -> &Note; - fn memo(&self) -> Option<&Memo>; + fn memo(&self) -> Option<&MemoBytes>; fn is_change(&self) -> Option; fn nullifier(&self) -> Option; } @@ -59,7 +60,7 @@ impl ShieldedOutput for WalletShieldedOutput { fn note(&self) -> &Note { &self.note } - fn memo(&self) -> Option<&Memo> { + fn memo(&self) -> Option<&MemoBytes> { None } fn is_change(&self) -> Option { @@ -84,7 +85,7 @@ impl ShieldedOutput for DecryptedOutput { fn note(&self) -> &Note { &self.note } - fn memo(&self) -> Option<&Memo> { + fn memo(&self) -> Option<&MemoBytes> { Some(&self.memo) } fn is_change(&self) -> Option { @@ -273,32 +274,24 @@ pub fn get_balance_at

( /// use zcash_client_sqlite::{ /// NoteId, /// WalletDB, -/// wallet::get_received_memo_as_utf8, +/// wallet::get_received_memo, /// }; /// /// let data_file = NamedTempFile::new().unwrap(); /// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap(); -/// let memo = get_received_memo_as_utf8(&db, 27); +/// let memo = get_received_memo(&db, 27); /// ``` -pub fn get_received_memo_as_utf8

( - wdb: &WalletDB

, - id_note: i64, -) -> Result, SqliteClientError> { - let memo: Vec<_> = wdb.conn.query_row( +pub fn get_received_memo

(wdb: &WalletDB

, id_note: i64) -> Result { + let memo_bytes: Vec<_> = wdb.conn.query_row( "SELECT memo FROM received_notes WHERE id_note = ?", &[id_note], |row| row.get(0), )?; - match Memo::from_bytes(&memo) { - Some(memo) => match memo.to_utf8() { - Some(Ok(res)) => Ok(Some(res)), - Some(Err(e)) => Err(SqliteClientError::InvalidMemo(e)), - None => Ok(None), - }, - None => Ok(None), - } + MemoBytes::from_bytes(&memo_bytes) + .and_then(Memo::try_from) + .map_err(SqliteClientError::from) } /// Returns the memo for a sent note, if it is known and a valid UTF-8 string. @@ -314,32 +307,24 @@ pub fn get_received_memo_as_utf8

( /// use zcash_client_sqlite::{ /// NoteId, /// WalletDB, -/// wallet::get_sent_memo_as_utf8, +/// wallet::get_sent_memo, /// }; /// /// let data_file = NamedTempFile::new().unwrap(); /// let db = WalletDB::for_path(data_file, Network::TestNetwork).unwrap(); -/// let memo = get_sent_memo_as_utf8(&db, 12); +/// let memo = get_sent_memo(&db, 12); /// ``` -pub fn get_sent_memo_as_utf8

( - wdb: &WalletDB

, - id_note: i64, -) -> Result, SqliteClientError> { - let memo: Vec<_> = wdb.conn.query_row( +pub fn get_sent_memo

(wdb: &WalletDB

, id_note: i64) -> Result { + let memo_bytes: Vec<_> = wdb.conn.query_row( "SELECT memo FROM sent_notes WHERE id_note = ?", &[id_note], |row| row.get(0), )?; - match Memo::from_bytes(&memo) { - Some(memo) => match memo.to_utf8() { - Some(Ok(res)) => Ok(Some(res)), - Some(Err(e)) => Err(SqliteClientError::InvalidMemo(e)), - None => Ok(None), - }, - None => Ok(None), - } + MemoBytes::from_bytes(&memo_bytes) + .and_then(Memo::try_from) + .map_err(SqliteClientError::from) } pub fn block_height_extrema

( @@ -605,7 +590,7 @@ pub fn put_received_note<'a, P, T: ShieldedOutput>( let diversifier = output.to().diversifier().0.to_vec(); let value = output.note().value as i64; let rcm = rcm.as_ref(); - let memo = output.memo().map(|m| m.as_bytes()); + let memo = output.memo().map(|m| m.as_slice()); let is_change = output.is_change(); let tx = tx_ref; let output_index = output.index() as i64; @@ -694,7 +679,7 @@ pub fn put_sent_note<'a, P: consensus::Parameters>( account, to_str, value, - &output.memo.as_bytes(), + &output.memo.as_slice(), tx_ref, output_index ])? == 0 @@ -722,7 +707,7 @@ pub fn insert_sent_note<'a, P: consensus::Parameters>( account: AccountId, to: &RecipientAddress, value: Amount, - memo: &Option, + memo: &Option, ) -> Result<(), SqliteClientError> { let to_str = to.encode(&stmts.wallet_db.params); let ivalue: i64 = value.into(); @@ -732,7 +717,7 @@ pub fn insert_sent_note<'a, P: consensus::Parameters>( account.0, to_str, ivalue, - memo.as_ref().map(|m| m.as_bytes().to_vec()), + memo.as_ref().map(|m| m.as_slice().to_vec()), ])?; Ok(()) diff --git a/zcash_primitives/benches/note_decryption.rs b/zcash_primitives/benches/note_decryption.rs index 255359b03..6eb00bba3 100644 --- a/zcash_primitives/benches/note_decryption.rs +++ b/zcash_primitives/benches/note_decryption.rs @@ -3,7 +3,7 @@ use ff::Field; use rand_core::OsRng; use zcash_primitives::{ consensus::{NetworkUpgrade::Canopy, Parameters, TEST_NETWORK}, - memo::Memo, + memo::MemoBytes, note_encryption::{try_sapling_note_decryption, SaplingNoteEncryption}, primitives::{Diversifier, PaymentAddress, SaplingIvk, ValueCommitment}, transaction::components::{OutputDescription, GROTH_PROOF_SIZE}, @@ -36,7 +36,7 @@ fn bench_note_decryption(c: &mut Criterion) { let note = pa.create_note(value, rseed).unwrap(); let cmu = note.cmu(); - let mut ne = SaplingNoteEncryption::new(None, note, pa, Memo::default(), &mut rng); + let mut ne = SaplingNoteEncryption::new(None, note, pa, MemoBytes::default(), &mut rng); let ephemeral_key = ne.epk().clone().into(); let enc_ciphertext = ne.encrypt_note_plaintext(); let out_ciphertext = ne.encrypt_outgoing_plaintext(&cv, &cmu); diff --git a/zcash_primitives/src/memo.rs b/zcash_primitives/src/memo.rs index f2ba1f97c..600a94efd 100644 --- a/zcash_primitives/src/memo.rs +++ b/zcash_primitives/src/memo.rs @@ -1,6 +1,10 @@ //! Structs for handling encrypted memos. +use std::cmp::Ordering; +use std::convert::{TryFrom, TryInto}; +use std::error; use std::fmt; +use std::ops::Deref; use std::str; /// Format a byte array as a colon-delimited hex string. @@ -24,96 +28,259 @@ where Ok(()) } +/// Errors that may result from attempting to construct an invalid memo. +#[derive(Debug, PartialEq)] +pub enum Error { + InvalidUtf8(std::str::Utf8Error), + TooLong(usize), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::InvalidUtf8(e) => write!(f, "Invalid UTF-8: {}", e), + Error::TooLong(n) => write!(f, "Memo length {} is larger than maximum of 512", n), + } + } +} + +impl error::Error for Error {} + +/// The unencrypted memo bytes received alongside a shielded note in a Zcash transaction. +#[derive(Clone)] +pub struct MemoBytes(pub(crate) Box<[u8; 512]>); + +impl fmt::Debug for MemoBytes { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MemoBytes(")?; + fmt_colon_delimited_hex(f, &self.0[..])?; + write!(f, ")") + } +} + +impl Default for MemoBytes { + fn default() -> Self { + let mut bytes = [0u8; 512]; + bytes[0] = 0xF6; + MemoBytes(Box::new(bytes)) + } +} + +impl PartialEq for MemoBytes { + fn eq(&self, rhs: &MemoBytes) -> bool { + self.0[..] == rhs.0[..] + } +} + +impl Eq for MemoBytes {} + +impl PartialOrd for MemoBytes { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for MemoBytes { + fn cmp(&self, rhs: &Self) -> Ordering { + self.0[..].cmp(&rhs.0[..]) + } +} + +impl MemoBytes { + /// Creates a `MemoBytes` from a slice. + /// + /// Returns an error if the provided slice is longer than 512 bytes. Slices shorter + /// than 512 bytes are padded with null bytes. + pub fn from_bytes(bytes: &[u8]) -> Result { + if bytes.len() > 512 { + return Err(Error::TooLong(bytes.len())); + } + + let mut memo = [0u8; 512]; + memo[..bytes.len()].copy_from_slice(bytes); + Ok(MemoBytes(Box::new(memo))) + } + + /// Returns the raw byte array containing the memo bytes, including null padding. + pub fn as_array(&self) -> &[u8; 512] { + &self.0 + } + + /// Returns a slice of the raw bytes, excluding null padding. + pub fn as_slice(&self) -> &[u8] { + let first_null = self + .0 + .iter() + .enumerate() + .rev() + .find(|(_, &b)| b != 0) + .map(|(i, _)| i + 1) + .unwrap_or_default(); + + &self.0[..first_null] + } +} + +/// Type-safe wrapper around String to enforce memo length requirements. +#[derive(Clone, PartialEq)] +pub struct TextMemo(String); + +impl From for String { + fn from(memo: TextMemo) -> String { + memo.0 + } +} + +impl Deref for TextMemo { + type Target = str; + + #[inline] + fn deref(&self) -> &str { + self.0.deref() + } +} + /// An unencrypted memo received alongside a shielded note in a Zcash transaction. #[derive(Clone)] -pub struct Memo(pub(crate) [u8; 512]); +pub enum Memo { + /// An empty memo field. + Empty, + /// A memo field containing a UTF-8 string. + Text(TextMemo), + /// Some unknown memo format from ✨*the future*✨ that we can't parse. + Future(MemoBytes), + /// A memo field containing arbitrary bytes. + Arbitrary(Box<[u8; 511]>), +} impl fmt::Debug for Memo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Memo(")?; - match self.to_utf8() { - Some(Ok(memo)) => write!(f, "\"{}\"", memo)?, - _ => fmt_colon_delimited_hex(f, &self.0[..])?, + match self { + Memo::Empty => write!(f, "Memo::Empty"), + Memo::Text(memo) => write!(f, "Memo::Text(\"{}\")", memo.0), + Memo::Future(bytes) => write!(f, "Memo::Future({:0x})", bytes.0[0]), + Memo::Arbitrary(bytes) => { + write!(f, "Memo::Arbitrary(")?; + fmt_colon_delimited_hex(f, &bytes[..])?; + write!(f, ")") + } } - write!(f, ")") } } impl Default for Memo { fn default() -> Self { - // Empty memo field indication per ZIP 302 - let mut memo = [0u8; 512]; - memo[0] = 0xF6; - Memo(memo) + Memo::Empty } } impl PartialEq for Memo { fn eq(&self, rhs: &Memo) -> bool { - self.0[..] == rhs.0[..] + match (self, rhs) { + (Memo::Empty, Memo::Empty) => true, + (Memo::Text(a), Memo::Text(b)) => a == b, + (Memo::Future(a), Memo::Future(b)) => a.0[..] == b.0[..], + (Memo::Arbitrary(a), Memo::Arbitrary(b)) => a[..] == b[..], + _ => false, + } + } +} + +impl TryFrom for Memo { + type Error = Error; + + /// Parses a `Memo` from its ZIP 302 serialization. + /// + /// Returns an error if the provided slice does not represent a valid `Memo` (for + /// example, if the slice is not 512 bytes, or the encoded `Memo` is non-canonical). + fn try_from(bytes: MemoBytes) -> Result { + match bytes.0[0] { + 0xF6 if bytes.0.iter().skip(1).all(|&b| b == 0) => Ok(Memo::Empty), + 0xFF => Ok(Memo::Arbitrary(Box::new(bytes.0[1..].try_into().unwrap()))), + b if b <= 0xF4 => str::from_utf8(bytes.as_slice()) + .map(|r| Memo::Text(TextMemo(r.to_owned()))) + .map_err(Error::InvalidUtf8), + _ => Ok(Memo::Future(bytes)), + } + } +} + +impl From for MemoBytes { + /// Serializes the `Memo` per ZIP 302. + fn from(memo: Memo) -> Self { + match memo { + // Small optimisation to avoid a clone + Memo::Future(memo) => memo, + memo => (&memo).into(), + } + } +} + +impl From<&Memo> for MemoBytes { + /// Serializes the `Memo` per ZIP 302. + fn from(memo: &Memo) -> Self { + match memo { + Memo::Empty => MemoBytes::default(), + Memo::Text(s) => { + let mut bytes = [0u8; 512]; + let s_bytes = s.0.as_bytes(); + // s_bytes.len() is guaranteed to be <= 512 + bytes[..s_bytes.len()].copy_from_slice(s_bytes); + MemoBytes(Box::new(bytes)) + } + Memo::Future(memo) => memo.clone(), + Memo::Arbitrary(arb) => { + let mut bytes = [0u8; 512]; + bytes[0] = 0xFF; + bytes[1..].copy_from_slice(arb.as_ref()); + MemoBytes(Box::new(bytes)) + } + } } } impl Memo { - /// Returns a `Memo` containing the given slice, appending with zero bytes if - /// necessary, or `None` if the slice is too long. If the slice is empty, - /// `Memo::default` is returned. - pub fn from_bytes(memo: &[u8]) -> Option { - if memo.is_empty() { - Some(Memo::default()) - } else if memo.len() <= 512 { - let mut data = [0; 512]; - data[0..memo.len()].copy_from_slice(memo); - Some(Memo(data)) - } else { - // memo is too long - None - } + /// Parses a `Memo` from its ZIP 302 serialization. + /// + /// Returns an error if the provided slice does not represent a valid `Memo` (for + /// example, if the slice is not 512 bytes, or the encoded `Memo` is non-canonical). + pub fn from_bytes(bytes: &[u8]) -> Result { + MemoBytes::from_bytes(bytes).and_then(TryFrom::try_from) } - /// Returns the underlying bytes of the `Memo`. - pub fn as_bytes(&self) -> &[u8] { - &self.0[..] - } - - /// Returns: - /// - `None` if the memo is not text - /// - `Some(Ok(memo))` if the memo contains a valid UTF-8 string - /// - `Some(Err(e))` if the memo contains invalid UTF-8 - pub fn to_utf8(&self) -> Option> { - // Check if it is a text or binary memo - if self.0[0] < 0xF5 { - // Check if it is valid UTF8 - Some(str::from_utf8(&self.0).map(|memo| { - // Drop trailing zeroes - memo.trim_end_matches(char::from(0)).to_owned() - })) - } else { - None - } + /// Serializes the `Memo` per ZIP 302. + pub fn encode(&self) -> MemoBytes { + self.into() } } impl str::FromStr for Memo { - type Err = (); + type Err = Error; /// Returns a `Memo` containing the given string, or an error if the string is too long. fn from_str(memo: &str) -> Result { - Memo::from_bytes(memo.as_bytes()).ok_or(()) + if memo.is_empty() { + Ok(Memo::Empty) + } else if memo.len() <= 512 { + Ok(Memo::Text(TextMemo(memo.to_owned()))) + } else { + Err(Error::TooLong(memo.len())) + } } } #[cfg(test)] mod tests { + use std::convert::TryInto; use std::str::FromStr; - use super::Memo; + use super::{Error, Memo, MemoBytes}; #[test] fn memo_from_str() { assert_eq!( - Memo::from_str("").unwrap(), - Memo([ + Memo::from_str("").unwrap().encode(), + MemoBytes(Box::new([ 0xf6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -151,7 +318,7 @@ mod tests { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 - ]) + ])) ); assert_eq!( Memo::from_str( @@ -163,8 +330,9 @@ mod tests { meeeeeeeeeeeeeeeeeeemooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo \ but it's just short enough" ) - .unwrap(), - Memo([ + .unwrap() + .encode(), + MemoBytes(Box::new([ 0x74, 0x68, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, 0x69, @@ -202,7 +370,7 @@ mod tests { 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x20, 0x62, 0x75, 0x74, 0x20, 0x69, 0x74, 0x27, 0x73, 0x20, 0x6a, 0x75, 0x73, 0x74, 0x20, 0x73, 0x68, 0x6f, 0x72, 0x74, 0x20, 0x65, 0x6e, 0x6f, 0x75, 0x67, 0x68 - ]) + ])) ); assert_eq!( Memo::from_str( @@ -214,14 +382,27 @@ mod tests { meeeeeeeeeeeeeeeeeeemooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo \ but it's now a bit too long" ), - Err(()) + Err(Error::TooLong(513)) ); } #[test] - fn memo_to_utf8() { - let memo = Memo::from_str("Test memo").unwrap(); - assert_eq!(memo.to_utf8(), Some(Ok("Test memo".to_owned()))); - assert_eq!(Memo::default().to_utf8(), None); + fn future_memo() { + let bytes = [0xFE; 512]; + assert_eq!( + MemoBytes::from_bytes(&bytes).unwrap().try_into(), + Ok(Memo::Future(MemoBytes(Box::new(bytes)))) + ); + } + + #[test] + fn arbitrary_memo() { + let bytes = [42; 511]; + let memo = Memo::Arbitrary(Box::new(bytes)); + let raw = memo.encode(); + let encoded = raw.as_array(); + assert_eq!(encoded[0], 0xFF); + assert_eq!(encoded[1..], bytes[..]); + assert_eq!(MemoBytes::from_bytes(encoded).unwrap().try_into(), Ok(memo)); } } diff --git a/zcash_primitives/src/note_encryption.rs b/zcash_primitives/src/note_encryption.rs index 818413f6d..7dfa57e19 100644 --- a/zcash_primitives/src/note_encryption.rs +++ b/zcash_primitives/src/note_encryption.rs @@ -2,7 +2,7 @@ use crate::{ consensus::{self, BlockHeight, NetworkUpgrade::Canopy, ZIP212_GRACE_PERIOD}, - memo::Memo, + memo::MemoBytes, primitives::{Diversifier, Note, PaymentAddress, Rseed, SaplingIvk}, }; use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams}; @@ -113,7 +113,7 @@ pub fn prf_ock( /// use rand_core::OsRng; /// use zcash_primitives::{ /// keys::{OutgoingViewingKey, prf_expand}, -/// memo::Memo, +/// memo::MemoBytes, /// note_encryption::SaplingNoteEncryption, /// primitives::{Diversifier, PaymentAddress, Rseed, ValueCommitment}, /// }; @@ -135,7 +135,7 @@ pub fn prf_ock( /// let note = to.create_note(value, Rseed::BeforeZip212(rcm)).unwrap(); /// let cmu = note.cmu(); /// -/// let mut enc = SaplingNoteEncryption::new(ovk, note, to, Memo::default(), &mut rng); +/// let mut enc = SaplingNoteEncryption::new(ovk, note, to, MemoBytes::default(), &mut rng); /// let encCiphertext = enc.encrypt_note_plaintext(); /// let outCiphertext = enc.encrypt_outgoing_plaintext(&cv.commitment().into(), &cmu); /// ``` @@ -144,7 +144,7 @@ pub struct SaplingNoteEncryption { esk: jubjub::Fr, note: Note, to: PaymentAddress, - memo: Memo, + memo: MemoBytes, /// `None` represents the `ovk = ⊥` case. ovk: Option, rng: R, @@ -159,7 +159,7 @@ impl SaplingNoteEncryption { ovk: Option, note: Note, to: PaymentAddress, - memo: Memo, + memo: MemoBytes, rng: R, ) -> Self { Self::new_internal(ovk, note, to, memo, rng) @@ -171,7 +171,7 @@ impl SaplingNoteEncryption { ovk: Option, note: Note, to: PaymentAddress, - memo: Memo, + memo: MemoBytes, mut rng: R, ) -> Self { let esk = note.generate_or_derive_esk_internal(&mut rng); @@ -222,7 +222,7 @@ impl SaplingNoteEncryption { input[20..COMPACT_NOTE_SIZE].copy_from_slice(&rseed); } } - input[COMPACT_NOTE_SIZE..NOTE_PLAINTEXT_SIZE].copy_from_slice(&self.memo.0); + input[COMPACT_NOTE_SIZE..NOTE_PLAINTEXT_SIZE].copy_from_slice(self.memo.as_array()); let mut output = [0u8; ENC_CIPHERTEXT_SIZE]; assert_eq!( @@ -362,7 +362,7 @@ pub fn try_sapling_note_decryption( epk: &jubjub::ExtendedPoint, cmu: &bls12_381::Scalar, enc_ciphertext: &[u8], -) -> Option<(Note, PaymentAddress, Memo)> { +) -> Option<(Note, PaymentAddress, MemoBytes)> { assert_eq!(enc_ciphertext.len(), ENC_CIPHERTEXT_SIZE); let shared_secret = sapling_ka_agree(&ivk.0, &epk); @@ -384,10 +384,10 @@ pub fn try_sapling_note_decryption( let (note, to) = parse_note_plaintext_without_memo(params, height, ivk, epk, cmu, &plaintext)?; - let mut memo = [0u8; 512]; - memo.copy_from_slice(&plaintext[COMPACT_NOTE_SIZE..NOTE_PLAINTEXT_SIZE]); + // Memo is the correct length by definition. + let memo = MemoBytes::from_bytes(&plaintext[COMPACT_NOTE_SIZE..NOTE_PLAINTEXT_SIZE]).unwrap(); - Some((note, to, Memo(memo))) + Some((note, to, memo)) } /// Trial decryption of the compact note plaintext by the recipient for light clients. @@ -436,7 +436,7 @@ pub fn try_sapling_output_recovery_with_ock( epk: &jubjub::ExtendedPoint, enc_ciphertext: &[u8], out_ciphertext: &[u8], -) -> Option<(Note, PaymentAddress, Memo)> { +) -> Option<(Note, PaymentAddress, MemoBytes)> { assert_eq!(enc_ciphertext.len(), ENC_CIPHERTEXT_SIZE); assert_eq!(out_ciphertext.len(), OUT_CIPHERTEXT_SIZE); @@ -502,8 +502,7 @@ pub fn try_sapling_output_recovery_with_ock( Rseed::AfterZip212(r) }; - let mut memo = [0u8; 512]; - memo.copy_from_slice(&plaintext[COMPACT_NOTE_SIZE..NOTE_PLAINTEXT_SIZE]); + let memo = MemoBytes::from_bytes(&plaintext[COMPACT_NOTE_SIZE..NOTE_PLAINTEXT_SIZE]).unwrap(); let diversifier = Diversifier(d); if (diversifier.g_d()? * esk).to_bytes() != epk.to_bytes() { @@ -525,7 +524,7 @@ pub fn try_sapling_output_recovery_with_ock( } } - Some((note, to, Memo(memo))) + Some((note, to, memo)) } /// Recovery of the full note plaintext by the sender. @@ -545,7 +544,7 @@ pub fn try_sapling_output_recovery( epk: &jubjub::ExtendedPoint, enc_ciphertext: &[u8], out_ciphertext: &[u8], -) -> Option<(Note, PaymentAddress, Memo)> { +) -> Option<(Note, PaymentAddress, MemoBytes)> { try_sapling_output_recovery_with_ock::

( params, height, @@ -582,7 +581,7 @@ mod tests { Parameters, TEST_NETWORK, ZIP212_GRACE_PERIOD, }, keys::OutgoingViewingKey, - memo::Memo, + memo::MemoBytes, primitives::{Diversifier, PaymentAddress, Rseed, SaplingIvk, ValueCommitment}, util::generate_random_rseed, }; @@ -682,7 +681,8 @@ mod tests { let cmu = note.cmu(); let ovk = OutgoingViewingKey([0; 32]); - let mut ne = SaplingNoteEncryption::new(Some(ovk), note, pa, Memo([0; 512]), &mut rng); + let mut ne = + SaplingNoteEncryption::new(Some(ovk), note, pa, MemoBytes::default(), &mut rng); let epk = ne.epk().clone().into(); let enc_ciphertext = ne.encrypt_note_plaintext(); let out_ciphertext = ne.encrypt_outgoing_plaintext(&cv, &cmu); @@ -1671,7 +1671,7 @@ mod tests { Some((decrypted_note, decrypted_to, decrypted_memo)) => { assert_eq!(decrypted_note, note); assert_eq!(decrypted_to, to); - assert_eq!(&decrypted_memo.0[..], &tv.memo[..]); + assert_eq!(&decrypted_memo.as_array()[..], &tv.memo[..]); } None => panic!("Note decryption failed"), } @@ -1704,7 +1704,7 @@ mod tests { Some((decrypted_note, decrypted_to, decrypted_memo)) => { assert_eq!(decrypted_note, note); assert_eq!(decrypted_to, to); - assert_eq!(&decrypted_memo.0[..], &tv.memo[..]); + assert_eq!(&decrypted_memo.as_array()[..], &tv.memo[..]); } None => panic!("Output recovery failed"), } @@ -1713,7 +1713,13 @@ mod tests { // Test encryption // - let mut ne = SaplingNoteEncryption::new(Some(ovk), note, to, Memo(tv.memo), OsRng); + let mut ne = SaplingNoteEncryption::new( + Some(ovk), + note, + to, + MemoBytes::from_bytes(&tv.memo).unwrap(), + OsRng, + ); // Swap in the ephemeral keypair from the test vectors ne.esk = esk; ne.epk = epk.into_subgroup().unwrap(); diff --git a/zcash_primitives/src/transaction/builder.rs b/zcash_primitives/src/transaction/builder.rs index 99943f35d..06c2595a5 100644 --- a/zcash_primitives/src/transaction/builder.rs +++ b/zcash_primitives/src/transaction/builder.rs @@ -14,7 +14,7 @@ use crate::{ consensus::{self, BlockHeight}, keys::OutgoingViewingKey, legacy::TransparentAddress, - memo::Memo, + memo::MemoBytes, merkle_tree::MerklePath, note_encryption::SaplingNoteEncryption, primitives::{Diversifier, Note, PaymentAddress}, @@ -99,7 +99,7 @@ pub struct SaplingOutput { ovk: Option, to: PaymentAddress, note: Note, - memo: Memo, + memo: MemoBytes, } impl SaplingOutput { @@ -110,7 +110,7 @@ impl SaplingOutput { ovk: Option, to: PaymentAddress, value: Amount, - memo: Option, + memo: Option, ) -> Result { Self::new_internal(params, height, rng, ovk, to, value, memo) } @@ -122,7 +122,7 @@ impl SaplingOutput { ovk: Option, to: PaymentAddress, value: Amount, - memo: Option, + memo: Option, ) -> Result { let g_d = to.g_d().ok_or(Error::InvalidAddress)?; if value.is_negative() { @@ -521,7 +521,7 @@ impl<'a, P: consensus::Parameters, R: RngCore> Builder<'a, P, R> { ovk: Option, to: PaymentAddress, value: Amount, - memo: Option, + memo: Option, ) -> Result<(), Error> { let output = SaplingOutput::new_internal( &self.params,