Enforce maximum zip321 payment count in TransactionRequest constructor.

This commit is contained in:
Kris Nuttycombe 2021-10-01 12:22:20 -06:00
parent d43a893c72
commit e30c5cd628
2 changed files with 33 additions and 19 deletions

View File

@ -172,16 +172,15 @@ where
R: Copy + Debug,
D: WalletWrite<Error = E, TxRef = R>,
{
let req = TransactionRequest {
payments: vec![Payment {
recipient_address: to.clone(),
amount,
memo,
label: None,
message: None,
other_params: vec![],
}],
};
let req = TransactionRequest::new(vec![Payment {
recipient_address: to.clone(),
amount,
memo,
label: None,
message: None,
other_params: vec![],
}])
.unwrap();
spend(
wallet_db,
@ -255,7 +254,7 @@ where
.and_then(|x| x.ok_or_else(|| Error::ScanRequired.into()))?;
let value = request
.payments
.payments()
.iter()
.map(|p| p.amount)
.sum::<Option<Amount>>()
@ -297,7 +296,7 @@ where
.map_err(Error::Builder)?;
}
for payment in &request.payments {
for payment in request.payments() {
match &payment.recipient_address {
RecipientAddress::Shielded(to) => builder
.add_sapling_output(
@ -321,7 +320,7 @@ where
let (tx, tx_metadata) = builder.build(&prover).map_err(Error::Builder)?;
let sent_outputs = request.payments.iter().enumerate().map(|(i, payment)| {
let sent_outputs = request.payments().iter().enumerate().map(|(i, payment)| {
let idx = match &payment.recipient_address {
// Sapling outputs are shuffled, so we need to look up where the output ended up.
RecipientAddress::Shielded(_) =>

View File

@ -22,11 +22,12 @@ use std::cmp::Ordering;
use crate::address::RecipientAddress;
/// Errors that may be produced in decoding of memos.
/// Errors that may be produced in decoding of payment requests.
#[derive(Debug)]
pub enum MemoError {
pub enum Zip321Error {
InvalidBase64(base64::DecodeError),
MemoBytesError(memo::Error),
TooManyPayments(usize),
}
/// Converts a [`MemoBytes`] value to a ZIP 321 compatible base64-encoded string.
@ -39,10 +40,10 @@ pub fn memo_to_base64(memo: &MemoBytes) -> String {
/// Parse a [`MemoBytes`] value from a ZIP 321 compatible base64-encoded string.
///
/// [`MemoBytes`]: zcash_primitives::memo::MemoBytes
pub fn memo_from_base64(s: &str) -> Result<MemoBytes, MemoError> {
pub fn memo_from_base64(s: &str) -> Result<MemoBytes, Zip321Error> {
base64::decode_config(s, base64::URL_SAFE_NO_PAD)
.map_err(MemoError::InvalidBase64)
.and_then(|b| MemoBytes::from_bytes(&b).map_err(MemoError::MemoBytesError))
.map_err(Zip321Error::InvalidBase64)
.and_then(|b| MemoBytes::from_bytes(&b).map_err(Zip321Error::MemoBytesError))
}
/// A single payment being requested.
@ -106,10 +107,24 @@ impl Payment {
/// payment value in the request.
#[derive(Debug, PartialEq)]
pub struct TransactionRequest {
pub payments: Vec<Payment>,
payments: Vec<Payment>,
}
impl TransactionRequest {
/// Constructs a new transaction request that obeys the ZIP-321 invariants
pub fn new(payments: Vec<Payment>) -> Result<TransactionRequest, Zip321Error> {
if payments.len() > 2109 {
Err(Zip321Error::TooManyPayments(payments.len()))
} else {
Ok(TransactionRequest { payments })
}
}
/// Returns the slice of payments that make up this request.
pub fn payments(&self) -> &[Payment] {
&self.payments[..]
}
/// A utility for use in tests to help check round-trip serialization properties.
#[cfg(any(test, feature = "test-dependencies"))]
pub(in crate::zip321) fn normalize<P: consensus::Parameters>(&mut self, params: &P) {