//! Strongly-typed zatoshi amounts that prevent under/overflows. //! //! The [`Amount`] type is parameterized by a [`Constraint`] implementation that //! declares the range of allowed values. In contrast to regular arithmetic //! operations, which return values, arithmetic on [`Amount`]s returns //! [`Result`](std::result::Result)s. use std::{ cmp::Ordering, convert::{TryFrom, TryInto}, hash::{Hash, Hasher}, marker::PhantomData, ops::RangeInclusive, }; use crate::serialization::{ZcashDeserialize, ZcashSerialize}; use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt}; type Result = std::result::Result; /// A runtime validated type for representing amounts of zatoshis #[derive(Clone, Copy, Serialize, Deserialize)] #[serde(try_from = "i64")] #[serde(bound = "C: Constraint")] pub struct Amount(i64, PhantomData); impl std::fmt::Debug for Amount { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple(&format!("Amount<{}>", std::any::type_name::())) .field(&self.0) .finish() } } impl Amount { /// Convert this amount to a different Amount type if it satisfies the new constraint pub fn constrain(self) -> Result> where C2: Constraint, { self.0.try_into() } /// To little endian byte array pub fn to_bytes(&self) -> [u8; 8] { let mut buf: [u8; 8] = [0; 8]; LittleEndian::write_i64(&mut buf, self.0); buf } /// Create a zero `Amount` pub fn zero() -> Amount where C: Constraint, { 0.try_into().expect("an amount of 0 is always valid") } } impl std::ops::Add> for Amount where C: Constraint, { type Output = Result>; fn add(self, rhs: Amount) -> Self::Output { let value = self.0 + rhs.0; value.try_into() } } impl std::ops::Add> for Result> where C: Constraint, { type Output = Result>; fn add(self, rhs: Amount) -> Self::Output { self? + rhs } } impl std::ops::Add>> for Amount where C: Constraint, { type Output = Result>; fn add(self, rhs: Result>) -> Self::Output { self + rhs? } } impl std::ops::AddAssign> for Result> where Amount: Copy, C: Constraint, { fn add_assign(&mut self, rhs: Amount) { if let Ok(lhs) = *self { *self = lhs + rhs; } } } impl std::ops::Sub> for Amount where C: Constraint, { type Output = Result>; fn sub(self, rhs: Amount) -> Self::Output { let value = self.0 - rhs.0; value.try_into() } } impl std::ops::Sub> for Result> where C: Constraint, { type Output = Result>; fn sub(self, rhs: Amount) -> Self::Output { self? - rhs } } impl std::ops::Sub>> for Amount where C: Constraint, { type Output = Result>; fn sub(self, rhs: Result>) -> Self::Output { self - rhs? } } impl std::ops::SubAssign> for Result> where Amount: Copy, C: Constraint, { fn sub_assign(&mut self, rhs: Amount) { if let Ok(lhs) = *self { *self = lhs - rhs; } } } impl From> for i64 { fn from(amount: Amount) -> Self { amount.0 } } impl From> for u64 { fn from(amount: Amount) -> Self { amount.0 as _ } } impl From> for jubjub::Fr { fn from(a: Amount) -> jubjub::Fr { // TODO: this isn't constant time -- does that matter? if a.0 < 0 { jubjub::Fr::from(a.0.abs() as u64).neg() } else { jubjub::Fr::from(a.0 as u64) } } } impl From> for halo2::pasta::pallas::Scalar { fn from(a: Amount) -> halo2::pasta::pallas::Scalar { // TODO: this isn't constant time -- does that matter? if a.0 < 0 { halo2::pasta::pallas::Scalar::from(a.0.abs() as u64).neg() } else { halo2::pasta::pallas::Scalar::from(a.0 as u64) } } } impl TryFrom for Amount where C: Constraint, { type Error = Error; fn try_from(value: i64) -> Result { C::validate(value).map(|v| Self(v, PhantomData)) } } impl TryFrom for Amount where C: Constraint, { type Error = Error; fn try_from(value: i32) -> Result { C::validate(value as _).map(|v| Self(v, PhantomData)) } } impl TryFrom for Amount where C: Constraint, { type Error = Error; fn try_from(value: u64) -> Result { let value = value .try_into() .map_err(|source| Error::Convert { value, source })?; C::validate(value).map(|v| Self(v, PhantomData)) } } impl Hash for Amount { /// Amounts with the same value are equal, even if they have different constraints fn hash(&self, state: &mut H) { self.0.hash(state); } } impl PartialEq> for Amount { fn eq(&self, other: &Amount) -> bool { self.0.eq(&other.0) } } impl PartialEq for Amount { fn eq(&self, other: &i64) -> bool { self.0.eq(other) } } impl PartialEq> for i64 { fn eq(&self, other: &Amount) -> bool { self.eq(&other.0) } } impl Eq for Amount {} impl Eq for Amount {} impl PartialOrd> for Amount { fn partial_cmp(&self, other: &Amount) -> Option { Some(self.0.cmp(&other.0)) } } impl Ord for Amount { fn cmp(&self, other: &Amount) -> Ordering { self.0.cmp(&other.0) } } impl Ord for Amount { fn cmp(&self, other: &Amount) -> Ordering { self.0.cmp(&other.0) } } impl std::ops::Mul for Amount { type Output = Result>; fn mul(self, rhs: u64) -> Self::Output { let value = (self.0 as u64) .checked_mul(rhs) .ok_or(Error::MultiplicationOverflow { amount: self.0, multiplier: rhs, })?; value.try_into() } } impl std::ops::Mul> for u64 { type Output = Result>; fn mul(self, rhs: Amount) -> Self::Output { rhs.mul(self) } } impl std::ops::Div for Amount { type Output = Result>; fn div(self, rhs: u64) -> Self::Output { let quotient = (self.0 as u64) .checked_div(rhs) .ok_or(Error::DivideByZero { amount: self.0 })?; Ok(quotient .try_into() .expect("division by a positive integer always stays within the constraint")) } } impl std::iter::Sum> for Result> where C: Constraint, { fn sum>>(iter: I) -> Self { let sum = iter .map(|a| a.0) .try_fold(0i64, |acc, amount| acc.checked_add(amount)); match sum { Some(sum) => Amount::try_from(sum), None => Err(Error::SumOverflow), } } } impl<'amt, C> std::iter::Sum<&'amt Amount> for Result> where C: Constraint + std::marker::Copy + 'amt, { fn sum>>(iter: I) -> Self { iter.copied().sum() } } impl std::ops::Neg for Amount { type Output = Self; fn neg(self) -> Self::Output { Amount::try_from(-self.0) .expect("a change in sign to any value inside Amount is always valid") } } #[derive(thiserror::Error, Debug, displaydoc::Display, Clone, PartialEq)] #[allow(missing_docs)] /// Errors that can be returned when validating `Amount`s pub enum Error { /// input {value} is outside of valid range for zatoshi Amount, valid_range={range:?} Contains { range: RangeInclusive, value: i64, }, /// u64 {value} could not be converted to an i64 Amount Convert { value: u64, source: std::num::TryFromIntError, }, /// i64 overflow when multiplying i64 non-negative amount {amount} by u64 {multiplier} MultiplicationOverflow { amount: i64, multiplier: u64 }, /// cannot divide amount {amount} by zero DivideByZero { amount: i64 }, /// i64 overflow when summing i64 amounts SumOverflow, } /// Marker type for `Amount` that allows negative values. /// /// ``` /// # use zebra_chain::amount::{Constraint, MAX_MONEY, NegativeAllowed}; /// assert_eq!( /// NegativeAllowed::valid_range(), /// -MAX_MONEY..=MAX_MONEY, /// ); /// ``` #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct NegativeAllowed; impl Constraint for NegativeAllowed { fn valid_range() -> RangeInclusive { -MAX_MONEY..=MAX_MONEY } } /// Marker type for `Amount` that requires nonnegative values. /// /// ``` /// # use zebra_chain::amount::{Constraint, MAX_MONEY, NonNegative}; /// assert_eq!( /// NonNegative::valid_range(), /// 0..=MAX_MONEY, /// ); /// ``` #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] pub struct NonNegative; impl Constraint for NonNegative { fn valid_range() -> RangeInclusive { 0..=MAX_MONEY } } /// Number of zatoshis in 1 ZEC pub const COIN: i64 = 100_000_000; /// The maximum zatoshi amount. pub const MAX_MONEY: i64 = 21_000_000 * COIN; /// A trait for defining constraints on `Amount` pub trait Constraint { /// Returns the range of values that are valid under this constraint fn valid_range() -> RangeInclusive; /// Check if an input value is within the valid range fn validate(value: i64) -> Result { let range = Self::valid_range(); if !range.contains(&value) { Err(Error::Contains { range, value }) } else { Ok(value) } } } impl ZcashSerialize for Amount { fn zcash_serialize(&self, mut writer: W) -> Result<(), std::io::Error> { writer.write_i64::(self.0) } } impl ZcashDeserialize for Amount { fn zcash_deserialize( mut reader: R, ) -> Result { Ok(reader.read_i64::()?.try_into()?) } } impl ZcashSerialize for Amount { fn zcash_serialize(&self, mut writer: W) -> Result<(), std::io::Error> { let amount = self .0 .try_into() .expect("constraint guarantees value is positive"); writer.write_u64::(amount) } } impl ZcashDeserialize for Amount { fn zcash_deserialize( mut reader: R, ) -> Result { Ok(reader.read_u64::()?.try_into()?) } } #[cfg(any(test, feature = "proptest-impl"))] use proptest::prelude::*; #[cfg(any(test, feature = "proptest-impl"))] impl Arbitrary for Amount where C: Constraint + std::fmt::Debug, { type Parameters = (); fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { C::valid_range().prop_map(|v| Self(v, PhantomData)).boxed() } type Strategy = BoxedStrategy; } #[cfg(test)] mod test { use crate::serialization::ZcashDeserializeInto; use super::*; use std::{collections::hash_map::RandomState, collections::HashSet, fmt::Debug}; use color_eyre::eyre::Result; #[test] fn test_add_bare() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let neg_one: Amount = (-1).try_into()?; let zero: Amount = Amount::zero(); let new_zero = one + neg_one; assert_eq!(zero, new_zero?); Ok(()) } #[test] fn test_add_opt_lhs() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let one = Ok(one); let neg_one: Amount = (-1).try_into()?; let zero: Amount = Amount::zero(); let new_zero = one + neg_one; assert_eq!(zero, new_zero?); Ok(()) } #[test] fn test_add_opt_rhs() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let neg_one: Amount = (-1).try_into()?; let neg_one = Ok(neg_one); let zero: Amount = Amount::zero(); let new_zero = one + neg_one; assert_eq!(zero, new_zero?); Ok(()) } #[test] fn test_add_opt_both() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let one = Ok(one); let neg_one: Amount = (-1).try_into()?; let neg_one = Ok(neg_one); let zero: Amount = Amount::zero(); let new_zero = one.and_then(|one| one + neg_one); assert_eq!(zero, new_zero?); Ok(()) } #[test] fn test_add_assign() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let neg_one: Amount = (-1).try_into()?; let mut neg_one = Ok(neg_one); let zero: Amount = Amount::zero(); neg_one += one; let new_zero = neg_one; assert_eq!(Ok(zero), new_zero); Ok(()) } #[test] fn test_sub_bare() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let zero: Amount = Amount::zero(); let neg_one: Amount = (-1).try_into()?; let new_neg_one = zero - one; assert_eq!(Ok(neg_one), new_neg_one); Ok(()) } #[test] fn test_sub_opt_lhs() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let one = Ok(one); let zero: Amount = Amount::zero(); let neg_one: Amount = (-1).try_into()?; let new_neg_one = zero - one; assert_eq!(Ok(neg_one), new_neg_one); Ok(()) } #[test] fn test_sub_opt_rhs() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let zero: Amount = Amount::zero(); let zero = Ok(zero); let neg_one: Amount = (-1).try_into()?; let new_neg_one = zero - one; assert_eq!(Ok(neg_one), new_neg_one); Ok(()) } #[test] fn test_sub_assign() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let zero: Amount = Amount::zero(); let mut zero = Ok(zero); let neg_one: Amount = (-1).try_into()?; zero -= one; let new_neg_one = zero; assert_eq!(Ok(neg_one), new_neg_one); Ok(()) } #[test] fn add_with_diff_constraints() -> Result<()> { zebra_test::init(); let one = Amount::::try_from(1)?; let zero: Amount = Amount::zero(); (zero - one.constrain()).expect("should allow negative"); (zero.constrain() - one).expect_err("shouldn't allow negative"); Ok(()) } #[test] fn deserialize_checks_bounds() -> Result<()> { zebra_test::init(); let big = (MAX_MONEY * 2) .try_into() .expect("unexpectedly large constant: multiplied constant should be within range"); let neg = -10; let mut big_bytes = Vec::new(); (&mut big_bytes) .write_u64::(big) .expect("unexpected serialization failure: vec should be infalliable"); let mut neg_bytes = Vec::new(); (&mut neg_bytes) .write_i64::(neg) .expect("unexpected serialization failure: vec should be infalliable"); Amount::::zcash_deserialize(big_bytes.as_slice()) .expect_err("deserialization should reject too large values"); Amount::::zcash_deserialize(big_bytes.as_slice()) .expect_err("deserialization should reject too large values"); Amount::::zcash_deserialize(neg_bytes.as_slice()) .expect_err("NonNegative deserialization should reject negative values"); let amount: Amount = neg_bytes .zcash_deserialize_into() .expect("NegativeAllowed deserialization should allow negative values"); assert_eq!(amount.0, neg); Ok(()) } #[test] fn hash() -> Result<()> { zebra_test::init(); let one = Amount::::try_from(1)?; let another_one = Amount::::try_from(1)?; let zero: Amount = Amount::zero(); let hash_set: HashSet, RandomState> = [one].iter().cloned().collect(); assert_eq!(hash_set.len(), 1); let hash_set: HashSet, RandomState> = [one, one].iter().cloned().collect(); assert_eq!(hash_set.len(), 1, "Amount hashes are consistent"); let hash_set: HashSet, RandomState> = [one, another_one].iter().cloned().collect(); assert_eq!(hash_set.len(), 1, "Amount hashes are by value"); let hash_set: HashSet, RandomState> = [one, zero].iter().cloned().collect(); assert_eq!( hash_set.len(), 2, "Amount hashes are different for different values" ); Ok(()) } #[test] fn ordering_constraints() -> Result<()> { zebra_test::init(); ordering::()?; ordering::()?; ordering::()?; ordering::()?; Ok(()) } #[allow(clippy::eq_op)] fn ordering() -> Result<()> where C1: Constraint + Debug, C2: Constraint + Debug, { let zero: Amount = Amount::zero(); let one = Amount::::try_from(1)?; let another_one = Amount::::try_from(1)?; assert_eq!(one, one); assert_eq!(one, another_one, "Amount equality is by value"); assert_ne!(one, zero); assert_ne!(zero, one); assert!(one > zero); assert!(zero < one); assert!(zero <= one); let negative_one = Amount::::try_from(-1)?; let negative_two = Amount::::try_from(-2)?; assert_ne!(negative_one, zero); assert_ne!(negative_one, one); assert!(negative_one < zero); assert!(negative_one <= one); assert!(zero > negative_one); assert!(zero >= negative_one); assert!(negative_two < negative_one); assert!(negative_one > negative_two); Ok(()) } #[test] fn test_sum() -> Result<()> { zebra_test::init(); let one: Amount = 1.try_into()?; let neg_one: Amount = (-1).try_into()?; let zero: Amount = Amount::zero(); // success let amounts = vec![one, neg_one, zero]; let sum_ref: Amount = amounts.iter().sum::>()?; let sum_value: Amount = amounts.into_iter().sum::>()?; assert_eq!(sum_ref, sum_value); assert_eq!(sum_ref, zero); // above max for Amount error let max: Amount = MAX_MONEY.try_into()?; let amounts = vec![one, max]; let integer_sum: i64 = amounts.iter().map(|a| a.0).sum(); let sum_ref = amounts.iter().sum::>(); let sum_value = amounts.into_iter().sum::>(); assert_eq!(sum_ref, sum_value); assert_eq!( sum_ref, Err(Error::Contains { range: -MAX_MONEY..=MAX_MONEY, value: integer_sum, }) ); // below min for Amount error let min: Amount = (-MAX_MONEY).try_into()?; let amounts = vec![min, neg_one]; let integer_sum: i64 = amounts.iter().map(|a| a.0).sum(); let sum_ref = amounts.iter().sum::>(); let sum_value = amounts.into_iter().sum::>(); assert_eq!(sum_ref, sum_value); assert_eq!( sum_ref, Err(Error::Contains { range: -MAX_MONEY..=MAX_MONEY, value: integer_sum, }) ); // above max of i64 error let times: usize = (i64::MAX / MAX_MONEY) .try_into() .expect("4392 can always be converted to usize"); let amounts: Vec = std::iter::repeat(MAX_MONEY.try_into()?) .take(times + 1) .collect(); let sum_ref = amounts.iter().sum::>(); let sum_value = amounts.into_iter().sum::>(); assert_eq!(sum_ref, sum_value); assert_eq!(sum_ref, Err(Error::SumOverflow)); // below min of i64 overflow let times: usize = (i64::MAX / MAX_MONEY) .try_into() .expect("4392 can always be converted to usize"); let neg_max_money: Amount = (-MAX_MONEY).try_into()?; let amounts: Vec> = std::iter::repeat(neg_max_money).take(times + 1).collect(); let sum_ref = amounts.iter().sum::>(); let sum_value = amounts.into_iter().sum::>(); assert_eq!(sum_ref, sum_value); assert_eq!(sum_ref, Err(Error::SumOverflow)); Ok(()) } }