implement Sum for Amount (#2500)
* implement Sum for Amount * check overflows * add a `zero()` method to `Amount` * impl iter::Sum<&Amount<C>> for Result<Amount<C>> And modify the tests so they test both reference and value based sums. * use `try_fold()` * change error doc * use iter::repeat() * fix test Co-authored-by: teor <teor@riseup.net>
This commit is contained in:
parent
bfc3e4a46c
commit
1dae1f49df
|
@ -47,6 +47,14 @@ impl<C> Amount<C> {
|
||||||
LittleEndian::write_i64(&mut buf, self.0);
|
LittleEndian::write_i64(&mut buf, self.0);
|
||||||
buf
|
buf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a zero `Amount`
|
||||||
|
pub fn zero() -> Amount<C>
|
||||||
|
where
|
||||||
|
C: Constraint,
|
||||||
|
{
|
||||||
|
0.try_into().expect("an amount of 0 is always valid")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C> std::ops::Add<Amount<C>> for Amount<C>
|
impl<C> std::ops::Add<Amount<C>> for Amount<C>
|
||||||
|
@ -293,6 +301,31 @@ impl std::ops::Div<u64> for Amount<NonNegative> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
|
||||||
|
where
|
||||||
|
C: Constraint,
|
||||||
|
{
|
||||||
|
fn sum<I: Iterator<Item = Amount<C>>>(iter: I) -> Self {
|
||||||
|
let sum = iter
|
||||||
|
.map(|a| a.0)
|
||||||
|
.try_fold(0i64, |acc, amount| acc.checked_add(amount));
|
||||||
|
|
||||||
|
match sum {
|
||||||
|
Some(sum) => Amount::try_from(sum),
|
||||||
|
None => Err(Error::SumOverflow),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'amt, C> std::iter::Sum<&'amt Amount<C>> for Result<Amount<C>>
|
||||||
|
where
|
||||||
|
C: Constraint + std::marker::Copy + 'amt,
|
||||||
|
{
|
||||||
|
fn sum<I: Iterator<Item = &'amt Amount<C>>>(iter: I) -> Self {
|
||||||
|
iter.copied().sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)]
|
#[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)]
|
||||||
#[allow(missing_docs)]
|
#[allow(missing_docs)]
|
||||||
/// Errors that can be returned when validating `Amount`s
|
/// Errors that can be returned when validating `Amount`s
|
||||||
|
@ -311,6 +344,9 @@ pub enum Error {
|
||||||
MultiplicationOverflow { amount: i64, multiplier: u64 },
|
MultiplicationOverflow { amount: i64, multiplier: u64 },
|
||||||
/// cannot divide amount {amount} by zero
|
/// cannot divide amount {amount} by zero
|
||||||
DivideByZero { amount: i64 },
|
DivideByZero { amount: i64 },
|
||||||
|
|
||||||
|
/// i64 overflow when summing i64 amounts
|
||||||
|
SumOverflow,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Marker type for `Amount` that allows negative values.
|
/// Marker type for `Amount` that allows negative values.
|
||||||
|
@ -438,7 +474,7 @@ mod test {
|
||||||
let one: Amount = 1.try_into()?;
|
let one: Amount = 1.try_into()?;
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
|
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
let new_zero = one + neg_one;
|
let new_zero = one + neg_one;
|
||||||
|
|
||||||
assert_eq!(zero, new_zero?);
|
assert_eq!(zero, new_zero?);
|
||||||
|
@ -454,7 +490,7 @@ mod test {
|
||||||
let one = Ok(one);
|
let one = Ok(one);
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
|
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
let new_zero = one + neg_one;
|
let new_zero = one + neg_one;
|
||||||
|
|
||||||
assert_eq!(zero, new_zero?);
|
assert_eq!(zero, new_zero?);
|
||||||
|
@ -470,7 +506,7 @@ mod test {
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
let neg_one = Ok(neg_one);
|
let neg_one = Ok(neg_one);
|
||||||
|
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
let new_zero = one + neg_one;
|
let new_zero = one + neg_one;
|
||||||
|
|
||||||
assert_eq!(zero, new_zero?);
|
assert_eq!(zero, new_zero?);
|
||||||
|
@ -487,7 +523,7 @@ mod test {
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
let neg_one = Ok(neg_one);
|
let neg_one = Ok(neg_one);
|
||||||
|
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
let new_zero = one.and_then(|one| one + neg_one);
|
let new_zero = one.and_then(|one| one + neg_one);
|
||||||
|
|
||||||
assert_eq!(zero, new_zero?);
|
assert_eq!(zero, new_zero?);
|
||||||
|
@ -503,7 +539,7 @@ mod test {
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
let mut neg_one = Ok(neg_one);
|
let mut neg_one = Ok(neg_one);
|
||||||
|
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
neg_one += one;
|
neg_one += one;
|
||||||
let new_zero = neg_one;
|
let new_zero = neg_one;
|
||||||
|
|
||||||
|
@ -517,7 +553,7 @@ mod test {
|
||||||
zebra_test::init();
|
zebra_test::init();
|
||||||
|
|
||||||
let one: Amount = 1.try_into()?;
|
let one: Amount = 1.try_into()?;
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
|
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
let new_neg_one = zero - one;
|
let new_neg_one = zero - one;
|
||||||
|
@ -533,7 +569,7 @@ mod test {
|
||||||
|
|
||||||
let one: Amount = 1.try_into()?;
|
let one: Amount = 1.try_into()?;
|
||||||
let one = Ok(one);
|
let one = Ok(one);
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
|
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
let new_neg_one = zero - one;
|
let new_neg_one = zero - one;
|
||||||
|
@ -548,7 +584,7 @@ mod test {
|
||||||
zebra_test::init();
|
zebra_test::init();
|
||||||
|
|
||||||
let one: Amount = 1.try_into()?;
|
let one: Amount = 1.try_into()?;
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
let zero = Ok(zero);
|
let zero = Ok(zero);
|
||||||
|
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
|
@ -564,7 +600,7 @@ mod test {
|
||||||
zebra_test::init();
|
zebra_test::init();
|
||||||
|
|
||||||
let one: Amount = 1.try_into()?;
|
let one: Amount = 1.try_into()?;
|
||||||
let zero: Amount = 0.try_into()?;
|
let zero: Amount = Amount::zero();
|
||||||
let mut zero = Ok(zero);
|
let mut zero = Ok(zero);
|
||||||
|
|
||||||
let neg_one: Amount = (-1).try_into()?;
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
|
@ -581,7 +617,7 @@ mod test {
|
||||||
zebra_test::init();
|
zebra_test::init();
|
||||||
|
|
||||||
let one = Amount::<NonNegative>::try_from(1)?;
|
let one = Amount::<NonNegative>::try_from(1)?;
|
||||||
let zero = Amount::<NegativeAllowed>::try_from(0)?;
|
let zero: Amount<NegativeAllowed> = Amount::zero();
|
||||||
|
|
||||||
(zero - one.constrain()).expect("should allow negative");
|
(zero - one.constrain()).expect("should allow negative");
|
||||||
(zero.constrain() - one).expect_err("shouldn't allow negative");
|
(zero.constrain() - one).expect_err("shouldn't allow negative");
|
||||||
|
@ -630,7 +666,7 @@ mod test {
|
||||||
|
|
||||||
let one = Amount::<NonNegative>::try_from(1)?;
|
let one = Amount::<NonNegative>::try_from(1)?;
|
||||||
let another_one = Amount::<NonNegative>::try_from(1)?;
|
let another_one = Amount::<NonNegative>::try_from(1)?;
|
||||||
let zero = Amount::<NonNegative>::try_from(0)?;
|
let zero: Amount<NonNegative> = Amount::zero();
|
||||||
|
|
||||||
let hash_set: HashSet<Amount<NonNegative>, RandomState> = [one].iter().cloned().collect();
|
let hash_set: HashSet<Amount<NonNegative>, RandomState> = [one].iter().cloned().collect();
|
||||||
assert_eq!(hash_set.len(), 1);
|
assert_eq!(hash_set.len(), 1);
|
||||||
|
@ -672,7 +708,7 @@ mod test {
|
||||||
C1: Constraint + Debug,
|
C1: Constraint + Debug,
|
||||||
C2: Constraint + Debug,
|
C2: Constraint + Debug,
|
||||||
{
|
{
|
||||||
let zero = Amount::<C1>::try_from(0)?;
|
let zero: Amount<C1> = Amount::zero();
|
||||||
let one = Amount::<C2>::try_from(1)?;
|
let one = Amount::<C2>::try_from(1)?;
|
||||||
let another_one = Amount::<C1>::try_from(1)?;
|
let another_one = Amount::<C1>::try_from(1)?;
|
||||||
|
|
||||||
|
@ -701,4 +737,87 @@ mod test {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sum() -> Result<()> {
|
||||||
|
zebra_test::init();
|
||||||
|
|
||||||
|
let one: Amount = 1.try_into()?;
|
||||||
|
let neg_one: Amount = (-1).try_into()?;
|
||||||
|
|
||||||
|
let zero: Amount = Amount::zero();
|
||||||
|
|
||||||
|
// success
|
||||||
|
let amounts = vec![one, neg_one, zero];
|
||||||
|
|
||||||
|
let sum_ref: Amount = amounts.iter().sum::<Result<Amount, Error>>()?;
|
||||||
|
let sum_value: Amount = amounts.into_iter().sum::<Result<Amount, Error>>()?;
|
||||||
|
|
||||||
|
assert_eq!(sum_ref, sum_value);
|
||||||
|
assert_eq!(sum_ref, zero);
|
||||||
|
|
||||||
|
// above max for Amount error
|
||||||
|
let max: Amount = MAX_MONEY.try_into()?;
|
||||||
|
let amounts = vec![one, max];
|
||||||
|
let integer_sum: i64 = amounts.iter().map(|a| a.0).sum();
|
||||||
|
|
||||||
|
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
|
||||||
|
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
|
||||||
|
|
||||||
|
assert_eq!(sum_ref, sum_value);
|
||||||
|
assert_eq!(
|
||||||
|
sum_ref,
|
||||||
|
Err(Error::Contains {
|
||||||
|
range: -MAX_MONEY..=MAX_MONEY,
|
||||||
|
value: integer_sum,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
// below min for Amount error
|
||||||
|
let min: Amount = (-MAX_MONEY).try_into()?;
|
||||||
|
let amounts = vec![min, neg_one];
|
||||||
|
let integer_sum: i64 = amounts.iter().map(|a| a.0).sum();
|
||||||
|
|
||||||
|
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
|
||||||
|
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
|
||||||
|
|
||||||
|
assert_eq!(sum_ref, sum_value);
|
||||||
|
assert_eq!(
|
||||||
|
sum_ref,
|
||||||
|
Err(Error::Contains {
|
||||||
|
range: -MAX_MONEY..=MAX_MONEY,
|
||||||
|
value: integer_sum,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
// above max of i64 error
|
||||||
|
let times: usize = (i64::MAX / MAX_MONEY)
|
||||||
|
.try_into()
|
||||||
|
.expect("4392 can always be converted to usize");
|
||||||
|
let amounts: Vec<Amount> = std::iter::repeat(MAX_MONEY.try_into()?)
|
||||||
|
.take(times + 1)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
|
||||||
|
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
|
||||||
|
|
||||||
|
assert_eq!(sum_ref, sum_value);
|
||||||
|
assert_eq!(sum_ref, Err(Error::SumOverflow));
|
||||||
|
|
||||||
|
// below min of i64 overflow
|
||||||
|
let times: usize = (i64::MAX / MAX_MONEY)
|
||||||
|
.try_into()
|
||||||
|
.expect("4392 can always be converted to usize");
|
||||||
|
let neg_max_money: Amount<NegativeAllowed> = (-MAX_MONEY).try_into()?;
|
||||||
|
let amounts: Vec<Amount<NegativeAllowed>> =
|
||||||
|
std::iter::repeat(neg_max_money).take(times + 1).collect();
|
||||||
|
|
||||||
|
let sum_ref = amounts.iter().sum::<Result<Amount, Error>>();
|
||||||
|
let sum_value = amounts.into_iter().sum::<Result<Amount, Error>>();
|
||||||
|
|
||||||
|
assert_eq!(sum_ref, sum_value);
|
||||||
|
assert_eq!(sum_ref, Err(Error::SumOverflow));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue