implement Sum for Amount (#2500)

* implement Sum for Amount

* check overflows

* add a `zero()` method to `Amount`

* impl iter::Sum<&Amount<C>> for Result<Amount<C>>

And modify the tests so they test both reference and value based sums.

* use `try_fold()`

* change error doc

* use iter::repeat()

* fix test

Co-authored-by: teor <teor@riseup.net>
This commit is contained in:
Alfredo Garcia 2021-07-19 20:05:36 -03:00 committed by GitHub
parent bfc3e4a46c
commit 1dae1f49df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 131 additions and 12 deletions

View File

@ -47,6 +47,14 @@ impl<C> Amount<C> {
LittleEndian::write_i64(&mut buf, self.0); LittleEndian::write_i64(&mut buf, self.0);
buf buf
} }
/// Create a zero `Amount`
pub fn zero() -> Amount<C>
where
C: Constraint,
{
0.try_into().expect("an amount of 0 is always valid")
}
} }
impl<C> std::ops::Add<Amount<C>> for Amount<C> impl<C> std::ops::Add<Amount<C>> for Amount<C>
@ -293,6 +301,31 @@ impl std::ops::Div<u64> for Amount<NonNegative> {
} }
} }
impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
where
C: Constraint,
{
fn sum<I: Iterator<Item = Amount<C>>>(iter: I) -> Self {
let sum = iter
.map(|a| a.0)
.try_fold(0i64, |acc, amount| acc.checked_add(amount));
match sum {
Some(sum) => Amount::try_from(sum),
None => Err(Error::SumOverflow),
}
}
}
impl<'amt, C> std::iter::Sum<&'amt Amount<C>> for Result<Amount<C>>
where
C: Constraint + std::marker::Copy + 'amt,
{
fn sum<I: Iterator<Item = &'amt Amount<C>>>(iter: I) -> Self {
iter.copied().sum()
}
}
#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)] #[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)]
#[allow(missing_docs)] #[allow(missing_docs)]
/// Errors that can be returned when validating `Amount`s /// Errors that can be returned when validating `Amount`s
@ -311,6 +344,9 @@ pub enum Error {
MultiplicationOverflow { amount: i64, multiplier: u64 }, MultiplicationOverflow { amount: i64, multiplier: u64 },
/// cannot divide amount {amount} by zero /// cannot divide amount {amount} by zero
DivideByZero { amount: i64 }, DivideByZero { amount: i64 },
/// i64 overflow when summing i64 amounts
SumOverflow,
} }
/// Marker type for `Amount` that allows negative values. /// Marker type for `Amount` that allows negative values.
@ -438,7 +474,7 @@ mod test {
let one: Amount = 1.try_into()?; let one: Amount = 1.try_into()?;
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let new_zero = one + neg_one; let new_zero = one + neg_one;
assert_eq!(zero, new_zero?); assert_eq!(zero, new_zero?);
@ -454,7 +490,7 @@ mod test {
let one = Ok(one); let one = Ok(one);
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let new_zero = one + neg_one; let new_zero = one + neg_one;
assert_eq!(zero, new_zero?); assert_eq!(zero, new_zero?);
@ -470,7 +506,7 @@ mod test {
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
let neg_one = Ok(neg_one); let neg_one = Ok(neg_one);
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let new_zero = one + neg_one; let new_zero = one + neg_one;
assert_eq!(zero, new_zero?); assert_eq!(zero, new_zero?);
@ -487,7 +523,7 @@ mod test {
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
let neg_one = Ok(neg_one); let neg_one = Ok(neg_one);
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let new_zero = one.and_then(|one| one + neg_one); let new_zero = one.and_then(|one| one + neg_one);
assert_eq!(zero, new_zero?); assert_eq!(zero, new_zero?);
@ -503,7 +539,7 @@ mod test {
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
let mut neg_one = Ok(neg_one); let mut neg_one = Ok(neg_one);
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
neg_one += one; neg_one += one;
let new_zero = neg_one; let new_zero = neg_one;
@ -517,7 +553,7 @@ mod test {
zebra_test::init(); zebra_test::init();
let one: Amount = 1.try_into()?; let one: Amount = 1.try_into()?;
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
let new_neg_one = zero - one; let new_neg_one = zero - one;
@ -533,7 +569,7 @@ mod test {
let one: Amount = 1.try_into()?; let one: Amount = 1.try_into()?;
let one = Ok(one); let one = Ok(one);
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
let new_neg_one = zero - one; let new_neg_one = zero - one;
@ -548,7 +584,7 @@ mod test {
zebra_test::init(); zebra_test::init();
let one: Amount = 1.try_into()?; let one: Amount = 1.try_into()?;
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let zero = Ok(zero); let zero = Ok(zero);
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
@ -564,7 +600,7 @@ mod test {
zebra_test::init(); zebra_test::init();
let one: Amount = 1.try_into()?; let one: Amount = 1.try_into()?;
let zero: Amount = 0.try_into()?; let zero: Amount = Amount::zero();
let mut zero = Ok(zero); let mut zero = Ok(zero);
let neg_one: Amount = (-1).try_into()?; let neg_one: Amount = (-1).try_into()?;
@ -581,7 +617,7 @@ mod test {
zebra_test::init(); zebra_test::init();
let one = Amount::<NonNegative>::try_from(1)?; let one = Amount::<NonNegative>::try_from(1)?;
let zero = Amount::<NegativeAllowed>::try_from(0)?; let zero: Amount<NegativeAllowed> = Amount::zero();
(zero - one.constrain()).expect("should allow negative"); (zero - one.constrain()).expect("should allow negative");
(zero.constrain() - one).expect_err("shouldn't allow negative"); (zero.constrain() - one).expect_err("shouldn't allow negative");
@ -630,7 +666,7 @@ mod test {
let one = Amount::<NonNegative>::try_from(1)?; let one = Amount::<NonNegative>::try_from(1)?;
let another_one = Amount::<NonNegative>::try_from(1)?; let another_one = Amount::<NonNegative>::try_from(1)?;
let zero = Amount::<NonNegative>::try_from(0)?; let zero: Amount<NonNegative> = Amount::zero();
let hash_set: HashSet<Amount<NonNegative>, RandomState> = [one].iter().cloned().collect(); let hash_set: HashSet<Amount<NonNegative>, RandomState> = [one].iter().cloned().collect();
assert_eq!(hash_set.len(), 1); assert_eq!(hash_set.len(), 1);
@ -672,7 +708,7 @@ mod test {
C1: Constraint + Debug, C1: Constraint + Debug,
C2: Constraint + Debug, C2: Constraint + Debug,
{ {
let zero = Amount::<C1>::try_from(0)?; let zero: Amount<C1> = Amount::zero();
let one = Amount::<C2>::try_from(1)?; let one = Amount::<C2>::try_from(1)?;
let another_one = Amount::<C1>::try_from(1)?; let another_one = Amount::<C1>::try_from(1)?;
@ -701,4 +737,87 @@ mod test {
Ok(()) Ok(())
} }
#[test]
fn test_sum() -> Result<()> {
zebra_test::init();
let one: Amount = 1.try_into()?;
let neg_one: Amount = (-1).try_into()?;
let zero: Amount = Amount::zero();
// success
let amounts = vec![one, neg_one, zero];
let sum_ref: Amount = amounts.iter().sum::<Result<Amount, Error>>()?;
let sum_value: Amount = amounts.into_iter().sum::<Result<Amount, Error>>()?;
assert_eq!(sum_ref, sum_value);
assert_eq!(sum_ref, zero);
// above max for Amount error
let max: Amount = MAX_MONEY.try_into()?;
let amounts = vec![one, max];
let integer_sum: i64 = amounts.iter().map(|a| a.0).sum();
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value);
assert_eq!(
sum_ref,
Err(Error::Contains {
range: -MAX_MONEY..=MAX_MONEY,
value: integer_sum,
})
);
// below min for Amount error
let min: Amount = (-MAX_MONEY).try_into()?;
let amounts = vec![min, neg_one];
let integer_sum: i64 = amounts.iter().map(|a| a.0).sum();
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value);
assert_eq!(
sum_ref,
Err(Error::Contains {
range: -MAX_MONEY..=MAX_MONEY,
value: integer_sum,
})
);
// above max of i64 error
let times: usize = (i64::MAX / MAX_MONEY)
.try_into()
.expect("4392 can always be converted to usize");
let amounts: Vec<Amount> = std::iter::repeat(MAX_MONEY.try_into()?)
.take(times + 1)
.collect();
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value);
assert_eq!(sum_ref, Err(Error::SumOverflow));
// below min of i64 overflow
let times: usize = (i64::MAX / MAX_MONEY)
.try_into()
.expect("4392 can always be converted to usize");
let neg_max_money: Amount<NegativeAllowed> = (-MAX_MONEY).try_into()?;
let amounts: Vec<Amount<NegativeAllowed>> =
std::iter::repeat(neg_max_money).take(times + 1).collect();
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
assert_eq!(sum_ref, sum_value);
assert_eq!(sum_ref, Err(Error::SumOverflow));
Ok(())
}
} }