From 1dae1f49df9edc41183988da879fc6674cd9a0b9 Mon Sep 17 00:00:00 2001 From: Alfredo Garcia Date: Mon, 19 Jul 2021 20:05:36 -0300 Subject: [PATCH] implement Sum for Amount (#2500) * implement Sum for Amount * check overflows * add a `zero()` method to `Amount` * impl iter::Sum<&Amount> for Result> And modify the tests so they test both reference and value based sums. * use `try_fold()` * change error doc * use iter::repeat() * fix test Co-authored-by: teor --- zebra-chain/src/amount.rs | 143 ++++++++++++++++++++++++++++++++++---- 1 file changed, 131 insertions(+), 12 deletions(-) diff --git a/zebra-chain/src/amount.rs b/zebra-chain/src/amount.rs index a6de40bb7..8295f17bd 100644 --- a/zebra-chain/src/amount.rs +++ b/zebra-chain/src/amount.rs @@ -47,6 +47,14 @@ impl Amount { LittleEndian::write_i64(&mut buf, self.0); buf } + + /// Create a zero `Amount` + pub fn zero() -> Amount + where + C: Constraint, + { + 0.try_into().expect("an amount of 0 is always valid") + } } impl std::ops::Add> for Amount @@ -293,6 +301,31 @@ impl std::ops::Div for Amount { } } +impl std::iter::Sum> for Result> +where + C: Constraint, +{ + fn sum>>(iter: I) -> Self { + let sum = iter + .map(|a| a.0) + .try_fold(0i64, |acc, amount| acc.checked_add(amount)); + + match sum { + Some(sum) => Amount::try_from(sum), + None => Err(Error::SumOverflow), + } + } +} + +impl<'amt, C> std::iter::Sum<&'amt Amount> for Result> +where + C: Constraint + std::marker::Copy + 'amt, +{ + fn sum>>(iter: I) -> Self { + iter.copied().sum() + } +} + #[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)] #[allow(missing_docs)] /// Errors that can be returned when validating `Amount`s @@ -311,6 +344,9 @@ pub enum Error { MultiplicationOverflow { amount: i64, multiplier: u64 }, /// cannot divide amount {amount} by zero DivideByZero { amount: i64 }, + + /// i64 overflow when summing i64 amounts + SumOverflow, } /// Marker type for `Amount` that allows negative values. @@ -438,7 +474,7 @@ mod test { let one: Amount = 1.try_into()?; let neg_one: Amount = (-1).try_into()?; - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let new_zero = one + neg_one; assert_eq!(zero, new_zero?); @@ -454,7 +490,7 @@ mod test { let one = Ok(one); let neg_one: Amount = (-1).try_into()?; - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let new_zero = one + neg_one; assert_eq!(zero, new_zero?); @@ -470,7 +506,7 @@ mod test { let neg_one: Amount = (-1).try_into()?; let neg_one = Ok(neg_one); - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let new_zero = one + neg_one; assert_eq!(zero, new_zero?); @@ -487,7 +523,7 @@ mod test { let neg_one: Amount = (-1).try_into()?; let neg_one = Ok(neg_one); - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let new_zero = one.and_then(|one| one + neg_one); assert_eq!(zero, new_zero?); @@ -503,7 +539,7 @@ mod test { let neg_one: Amount = (-1).try_into()?; let mut neg_one = Ok(neg_one); - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); neg_one += one; let new_zero = neg_one; @@ -517,7 +553,7 @@ mod test { zebra_test::init(); let one: Amount = 1.try_into()?; - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let neg_one: Amount = (-1).try_into()?; let new_neg_one = zero - one; @@ -533,7 +569,7 @@ mod test { let one: Amount = 1.try_into()?; let one = Ok(one); - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let neg_one: Amount = (-1).try_into()?; let new_neg_one = zero - one; @@ -548,7 +584,7 @@ mod test { zebra_test::init(); let one: Amount = 1.try_into()?; - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let zero = Ok(zero); let neg_one: Amount = (-1).try_into()?; @@ -564,7 +600,7 @@ mod test { zebra_test::init(); let one: Amount = 1.try_into()?; - let zero: Amount = 0.try_into()?; + let zero: Amount = Amount::zero(); let mut zero = Ok(zero); let neg_one: Amount = (-1).try_into()?; @@ -581,7 +617,7 @@ mod test { zebra_test::init(); let one = Amount::::try_from(1)?; - let zero = Amount::::try_from(0)?; + let zero: Amount = Amount::zero(); (zero - one.constrain()).expect("should allow negative"); (zero.constrain() - one).expect_err("shouldn't allow negative"); @@ -630,7 +666,7 @@ mod test { let one = Amount::::try_from(1)?; let another_one = Amount::::try_from(1)?; - let zero = Amount::::try_from(0)?; + let zero: Amount = Amount::zero(); let hash_set: HashSet, RandomState> = [one].iter().cloned().collect(); assert_eq!(hash_set.len(), 1); @@ -672,7 +708,7 @@ mod test { C1: Constraint + Debug, C2: Constraint + Debug, { - let zero = Amount::::try_from(0)?; + let zero: Amount = Amount::zero(); let one = Amount::::try_from(1)?; let another_one = Amount::::try_from(1)?; @@ -701,4 +737,87 @@ mod test { Ok(()) } + + #[test] + fn test_sum() -> Result<()> { + zebra_test::init(); + + let one: Amount = 1.try_into()?; + let neg_one: Amount = (-1).try_into()?; + + let zero: Amount = Amount::zero(); + + // success + let amounts = vec![one, neg_one, zero]; + + let sum_ref: Amount = amounts.iter().sum::>()?; + let sum_value: Amount = amounts.into_iter().sum::>()?; + + assert_eq!(sum_ref, sum_value); + assert_eq!(sum_ref, zero); + + // above max for Amount error + let max: Amount = MAX_MONEY.try_into()?; + let amounts = vec![one, max]; + let integer_sum: i64 = amounts.iter().map(|a| a.0).sum(); + + let sum_ref = amounts.iter().sum::>(); + let sum_value = amounts.into_iter().sum::>(); + + assert_eq!(sum_ref, sum_value); + assert_eq!( + sum_ref, + Err(Error::Contains { + range: -MAX_MONEY..=MAX_MONEY, + value: integer_sum, + }) + ); + + // below min for Amount error + let min: Amount = (-MAX_MONEY).try_into()?; + let amounts = vec![min, neg_one]; + let integer_sum: i64 = amounts.iter().map(|a| a.0).sum(); + + let sum_ref = amounts.iter().sum::>(); + let sum_value = amounts.into_iter().sum::>(); + + assert_eq!(sum_ref, sum_value); + assert_eq!( + sum_ref, + Err(Error::Contains { + range: -MAX_MONEY..=MAX_MONEY, + value: integer_sum, + }) + ); + + // above max of i64 error + let times: usize = (i64::MAX / MAX_MONEY) + .try_into() + .expect("4392 can always be converted to usize"); + let amounts: Vec = std::iter::repeat(MAX_MONEY.try_into()?) + .take(times + 1) + .collect(); + + let sum_ref = amounts.iter().sum::>(); + let sum_value = amounts.into_iter().sum::>(); + + assert_eq!(sum_ref, sum_value); + assert_eq!(sum_ref, Err(Error::SumOverflow)); + + // below min of i64 overflow + let times: usize = (i64::MAX / MAX_MONEY) + .try_into() + .expect("4392 can always be converted to usize"); + let neg_max_money: Amount = (-MAX_MONEY).try_into()?; + let amounts: Vec> = + std::iter::repeat(neg_max_money).take(times + 1).collect(); + + let sum_ref = amounts.iter().sum::>(); + let sum_value = amounts.into_iter().sum::>(); + + assert_eq!(sum_ref, sum_value); + assert_eq!(sum_ref, Err(Error::SumOverflow)); + + Ok(()) + } }