Make amount addition and subtraction traits use checked operations.
This commit is contained in:
parent
35023ed8ca
commit
4efb21d1c7
|
@ -23,7 +23,7 @@ hex = "0.4"
|
|||
jubjub = "0.6"
|
||||
nom = "6.1"
|
||||
percent-encoding = "2.1.0"
|
||||
proptest = { version = "0.10.1", optional = true }
|
||||
proptest = { version = "1.0.0", optional = true }
|
||||
protobuf = "2.20"
|
||||
rand_core = "0.6"
|
||||
subtle = "2.2.3"
|
||||
|
|
|
@ -24,6 +24,9 @@ pub enum ChainInvalid {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub enum Error<NoteId> {
|
||||
/// The amount specified exceeds the allowed range.
|
||||
InvalidAmount,
|
||||
|
||||
/// Unable to create a new spend because the wallet balance is not sufficient.
|
||||
InsufficientBalance(Amount, Amount),
|
||||
|
||||
|
@ -72,6 +75,10 @@ impl ChainInvalid {
|
|||
impl<N: fmt::Display> fmt::Display for Error<N> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match &self {
|
||||
Error::InvalidAmount => write!(
|
||||
f,
|
||||
"The value lies outside the valid range of Zcash amounts."
|
||||
),
|
||||
Error::InsufficientBalance(have, need) => write!(
|
||||
f,
|
||||
"Insufficient balance (have {}, need {} including fee)",
|
||||
|
|
|
@ -183,11 +183,15 @@ where
|
|||
.get_target_and_anchor_heights()
|
||||
.and_then(|x| x.ok_or_else(|| Error::ScanRequired.into()))?;
|
||||
|
||||
let target_value = value + DEFAULT_FEE;
|
||||
let target_value = (value + DEFAULT_FEE).ok_or_else(|| E::from(Error::InvalidAmount))?;
|
||||
let spendable_notes = wallet_db.select_spendable_notes(account, target_value, anchor_height)?;
|
||||
|
||||
// Confirm we were able to select sufficient value
|
||||
let selected_value = spendable_notes.iter().map(|n| n.note_value).sum();
|
||||
let selected_value = spendable_notes
|
||||
.iter()
|
||||
.map(|n| n.note_value)
|
||||
.sum::<Option<_>>()
|
||||
.ok_or_else(|| E::from(Error::InvalidAmount))?;
|
||||
if selected_value < target_value {
|
||||
return Err(E::from(Error::InsufficientBalance(
|
||||
selected_value,
|
||||
|
|
|
@ -358,13 +358,19 @@ mod tests {
|
|||
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
|
||||
|
||||
// Account balance should reflect both received notes
|
||||
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
|
||||
assert_eq!(
|
||||
get_balance(&db_data, AccountId(0)).unwrap(),
|
||||
(value + value2).unwrap()
|
||||
);
|
||||
|
||||
// "Rewind" to height of last scanned block
|
||||
rewind_to_height(&db_data, sapling_activation_height() + 1).unwrap();
|
||||
|
||||
// Account balance should be unaltered
|
||||
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
|
||||
assert_eq!(
|
||||
get_balance(&db_data, AccountId(0)).unwrap(),
|
||||
(value + value2).unwrap()
|
||||
);
|
||||
|
||||
// Rewind so that one block is dropped
|
||||
rewind_to_height(&db_data, sapling_activation_height()).unwrap();
|
||||
|
@ -376,7 +382,10 @@ mod tests {
|
|||
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
|
||||
|
||||
// Account balance should again reflect both received notes
|
||||
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
|
||||
assert_eq!(
|
||||
get_balance(&db_data, AccountId(0)).unwrap(),
|
||||
(value + value2).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -485,7 +494,10 @@ mod tests {
|
|||
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
|
||||
|
||||
// Account balance should reflect both received notes
|
||||
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
|
||||
assert_eq!(
|
||||
get_balance(&db_data, AccountId(0)).unwrap(),
|
||||
(value + value2).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -543,6 +555,9 @@ mod tests {
|
|||
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
|
||||
|
||||
// Account balance should equal the change
|
||||
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value - value2);
|
||||
assert_eq!(
|
||||
get_balance(&db_data, AccountId(0)).unwrap(),
|
||||
(value - value2).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -701,7 +701,7 @@ mod tests {
|
|||
let note = Note {
|
||||
g_d: change_addr.diversifier().g_d().unwrap(),
|
||||
pk_d: *change_addr.pk_d(),
|
||||
value: (in_value - value).into(),
|
||||
value: (in_value - value).unwrap().into(),
|
||||
rseed,
|
||||
};
|
||||
let encryptor = sapling_note_encryption::<_, Network>(
|
||||
|
|
|
@ -356,7 +356,10 @@ mod tests {
|
|||
|
||||
// Verified balance does not include the second note
|
||||
let (_, anchor_height2) = (&db_data).get_target_and_anchor_heights().unwrap().unwrap();
|
||||
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value);
|
||||
assert_eq!(
|
||||
get_balance(&db_data, AccountId(0)).unwrap(),
|
||||
(value + value).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
get_balance_at(&db_data, AccountId(0), anchor_height2).unwrap(),
|
||||
value
|
||||
|
|
|
@ -740,7 +740,7 @@ mod tests {
|
|||
TzeOutPoint::new(tx_a.txid().0, 0),
|
||||
tx_a.tze_outputs[0].clone(),
|
||||
);
|
||||
let value_xfr = value - DEFAULT_FEE;
|
||||
let value_xfr = (value - DEFAULT_FEE).unwrap();
|
||||
db_b.demo_transfer_to_close(prevout_a, value_xfr, preimage_1, h2)
|
||||
.map_err(|e| format!("transfer failure: {:?}", e))
|
||||
.unwrap();
|
||||
|
@ -769,7 +769,7 @@ mod tests {
|
|||
builder_c
|
||||
.add_transparent_output(
|
||||
&TransparentAddress::PublicKey([0; 20]),
|
||||
value_xfr - DEFAULT_FEE,
|
||||
(value_xfr - DEFAULT_FEE).unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
|
@ -31,7 +31,9 @@ hex = "0.4"
|
|||
jubjub = "0.6"
|
||||
lazy_static = "1"
|
||||
log = "0.4"
|
||||
proptest = { version = "0.10.1", optional = true }
|
||||
nonempty = "0.6"
|
||||
orchard = { git = "https://github.com/zcash/orchard", branch = "main" }
|
||||
proptest = { version = "1.0.0", optional = true }
|
||||
rand = "0.8"
|
||||
rand_core = "0.6"
|
||||
ripemd160 = { version = "0.9", optional = true }
|
||||
|
@ -43,11 +45,16 @@ zcash_note_encryption = { version = "0.0", path = "../components/zcash_note_encr
|
|||
# Temporary workaround for https://github.com/myrrlyn/funty/issues/3
|
||||
funty = "=1.1.0"
|
||||
|
||||
[dependencies.pasta_curves]
|
||||
git = "https://github.com/zcash/pasta_curves.git"
|
||||
rev = "b55a6960dfafd7f767e2820ddf1adaa499322f98"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
hex-literal = "0.3"
|
||||
proptest = "0.10.1"
|
||||
proptest = "1.0.0"
|
||||
rand_xorshift = "0.3"
|
||||
orchard = { git = "https://github.com/zcash/orchard", branch = "main", features = ["test-dependencies"] }
|
||||
|
||||
[features]
|
||||
transparent-inputs = ["ripemd160", "secp256k1"]
|
||||
|
|
|
@ -251,18 +251,18 @@ impl TransparentInputs {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn value_sum(&self) -> Amount {
|
||||
fn value_sum(&self) -> Option<Amount> {
|
||||
#[cfg(feature = "transparent-inputs")]
|
||||
{
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|input| input.coin.value)
|
||||
.sum::<Amount>()
|
||||
.sum::<Option<Amount>>()
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "transparent-inputs"))]
|
||||
{
|
||||
Amount::zero()
|
||||
Some(Amount::zero())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -643,8 +643,18 @@ impl<'a, P: consensus::Parameters, R: RngCore> Builder<'a, P, R> {
|
|||
//
|
||||
|
||||
// Valid change
|
||||
let change = self.mtx.value_balance - self.fee + self.transparent_inputs.value_sum()
|
||||
- self.mtx.vout.iter().map(|vo| vo.value).sum::<Amount>();
|
||||
let change = self.mtx.value_balance - self.fee
|
||||
+ self
|
||||
.transparent_inputs
|
||||
.value_sum()
|
||||
.ok_or(Error::InvalidAmount)?
|
||||
- self
|
||||
.mtx
|
||||
.vout
|
||||
.iter()
|
||||
.map(|vo| vo.value)
|
||||
.sum::<Option<Amount>>()
|
||||
.ok_or(Error::InvalidAmount)?;
|
||||
|
||||
#[cfg(feature = "zfuture")]
|
||||
let change = change
|
||||
|
@ -653,13 +663,17 @@ impl<'a, P: consensus::Parameters, R: RngCore> Builder<'a, P, R> {
|
|||
.builders
|
||||
.iter()
|
||||
.map(|ein| ein.prevout.value)
|
||||
.sum::<Amount>()
|
||||
.sum::<Option<Amount>>()
|
||||
.ok_or(Error::InvalidAmount)?
|
||||
- self
|
||||
.mtx
|
||||
.tze_outputs
|
||||
.iter()
|
||||
.map(|tzo| tzo.value)
|
||||
.sum::<Amount>();
|
||||
.sum::<Option<Amount>>()
|
||||
.ok_or(Error::InvalidAmount)?;
|
||||
|
||||
let change = change.ok_or(Error::InvalidAmount)?;
|
||||
|
||||
if change.is_negative() {
|
||||
return Err(Error::ChangeIsNegative(change));
|
||||
|
@ -1150,7 +1164,9 @@ mod tests {
|
|||
let builder = Builder::new(TEST_NETWORK, H0);
|
||||
assert_eq!(
|
||||
builder.build(consensus::BranchId::Sapling, &MockTxProver),
|
||||
Err(Error::ChangeIsNegative(Amount::zero() - DEFAULT_FEE))
|
||||
Err(Error::ChangeIsNegative(
|
||||
(Amount::zero() - DEFAULT_FEE).unwrap()
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -1168,7 +1184,7 @@ mod tests {
|
|||
assert_eq!(
|
||||
builder.build(consensus::BranchId::Sapling, &MockTxProver),
|
||||
Err(Error::ChangeIsNegative(
|
||||
Amount::from_i64(-50000).unwrap() - DEFAULT_FEE
|
||||
(Amount::from_i64(-50000).unwrap() - DEFAULT_FEE).unwrap()
|
||||
))
|
||||
);
|
||||
}
|
||||
|
@ -1186,7 +1202,7 @@ mod tests {
|
|||
assert_eq!(
|
||||
builder.build(consensus::BranchId::Sapling, &MockTxProver),
|
||||
Err(Error::ChangeIsNegative(
|
||||
Amount::from_i64(-50000).unwrap() - DEFAULT_FEE
|
||||
(Amount::from_i64(-50000).unwrap() - DEFAULT_FEE).unwrap()
|
||||
))
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::convert::TryFrom;
|
||||
use std::iter::Sum;
|
||||
use std::ops::{Add, AddAssign, Sub, SubAssign};
|
||||
|
||||
|
@ -101,6 +102,14 @@ impl Amount {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<i64> for Amount {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: i64) -> Result<Self, ()> {
|
||||
Amount::from_i64(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Amount> for i64 {
|
||||
fn from(amount: Amount) -> i64 {
|
||||
amount.0
|
||||
|
@ -114,36 +123,52 @@ impl From<Amount> for u64 {
|
|||
}
|
||||
|
||||
impl Add<Amount> for Amount {
|
||||
type Output = Amount;
|
||||
type Output = Option<Amount>;
|
||||
|
||||
fn add(self, rhs: Amount) -> Amount {
|
||||
Amount::from_i64(self.0 + rhs.0).expect("addition should remain in range")
|
||||
fn add(self, rhs: Amount) -> Option<Amount> {
|
||||
Amount::from_i64(self.0 + rhs.0).ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl Add<Amount> for Option<Amount> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Amount) -> Option<Amount> {
|
||||
self.and_then(|lhs| lhs + rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign<Amount> for Amount {
|
||||
fn add_assign(&mut self, rhs: Amount) {
|
||||
*self = *self + rhs
|
||||
*self = (*self + rhs).expect("Addition must produce a valid amount value.")
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Amount> for Amount {
|
||||
type Output = Amount;
|
||||
type Output = Option<Amount>;
|
||||
|
||||
fn sub(self, rhs: Amount) -> Amount {
|
||||
Amount::from_i64(self.0 - rhs.0).expect("subtraction should remain in range")
|
||||
fn sub(self, rhs: Amount) -> Option<Amount> {
|
||||
Amount::from_i64(self.0 - rhs.0).ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Amount> for Option<Amount> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Amount) -> Option<Amount> {
|
||||
self.and_then(|lhs| lhs - rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign<Amount> for Amount {
|
||||
fn sub_assign(&mut self, rhs: Amount) {
|
||||
*self = *self - rhs
|
||||
*self = (*self - rhs).expect("Subtraction must produce a valid amount value.")
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for Amount {
|
||||
fn sum<I: Iterator<Item = Amount>>(iter: I) -> Amount {
|
||||
iter.fold(Amount::zero(), Add::add)
|
||||
impl Sum<Amount> for Option<Amount> {
|
||||
fn sum<I: Iterator<Item = Amount>>(iter: I) -> Self {
|
||||
iter.fold(Some(Amount::zero()), |acc, a| acc? + a)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -153,11 +178,23 @@ pub mod testing {
|
|||
|
||||
use super::{Amount, MAX_MONEY};
|
||||
|
||||
prop_compose! {
|
||||
pub fn arb_amount()(amt in -MAX_MONEY..MAX_MONEY) -> Amount {
|
||||
Amount::from_i64(amt).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
prop_compose! {
|
||||
pub fn arb_nonnegative_amount()(amt in 0i64..MAX_MONEY) -> Amount {
|
||||
Amount::from_i64(amt).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
prop_compose! {
|
||||
pub fn arb_positive_amount()(amt in 1i64..MAX_MONEY) -> Amount {
|
||||
Amount::from_i64(amt).unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -213,10 +250,9 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn add_panics_on_overflow() {
|
||||
fn add_overflow() {
|
||||
let v = Amount(MAX_MONEY);
|
||||
let _sum = v + Amount(1);
|
||||
assert_eq!(v + Amount(1), None)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -227,10 +263,9 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn sub_panics_on_underflow() {
|
||||
fn sub_underflow() {
|
||||
let v = Amount(-MAX_MONEY);
|
||||
let _diff = v - Amount(1);
|
||||
assert_eq!(v - Amount(1), None)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Loading…
Reference in New Issue