Make amount addition and subtraction traits use checked operations.

This commit is contained in:
Kris Nuttycombe 2021-05-11 09:57:28 -06:00
parent 35023ed8ca
commit 4efb21d1c7
10 changed files with 128 additions and 41 deletions

View File

@ -23,7 +23,7 @@ hex = "0.4"
jubjub = "0.6"
nom = "6.1"
percent-encoding = "2.1.0"
proptest = { version = "0.10.1", optional = true }
proptest = { version = "1.0.0", optional = true }
protobuf = "2.20"
rand_core = "0.6"
subtle = "2.2.3"

View File

@ -24,6 +24,9 @@ pub enum ChainInvalid {
#[derive(Debug)]
pub enum Error<NoteId> {
/// The amount specified exceeds the allowed range.
InvalidAmount,
/// Unable to create a new spend because the wallet balance is not sufficient.
InsufficientBalance(Amount, Amount),
@ -72,6 +75,10 @@ impl ChainInvalid {
impl<N: fmt::Display> fmt::Display for Error<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self {
Error::InvalidAmount => write!(
f,
"The value lies outside the valid range of Zcash amounts."
),
Error::InsufficientBalance(have, need) => write!(
f,
"Insufficient balance (have {}, need {} including fee)",

View File

@ -183,11 +183,15 @@ where
.get_target_and_anchor_heights()
.and_then(|x| x.ok_or_else(|| Error::ScanRequired.into()))?;
let target_value = value + DEFAULT_FEE;
let target_value = (value + DEFAULT_FEE).ok_or_else(|| E::from(Error::InvalidAmount))?;
let spendable_notes = wallet_db.select_spendable_notes(account, target_value, anchor_height)?;
// Confirm we were able to select sufficient value
let selected_value = spendable_notes.iter().map(|n| n.note_value).sum();
let selected_value = spendable_notes
.iter()
.map(|n| n.note_value)
.sum::<Option<_>>()
.ok_or_else(|| E::from(Error::InvalidAmount))?;
if selected_value < target_value {
return Err(E::from(Error::InsufficientBalance(
selected_value,

View File

@ -358,13 +358,19 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect both received notes
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
assert_eq!(
get_balance(&db_data, AccountId(0)).unwrap(),
(value + value2).unwrap()
);
// "Rewind" to height of last scanned block
rewind_to_height(&db_data, sapling_activation_height() + 1).unwrap();
// Account balance should be unaltered
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
assert_eq!(
get_balance(&db_data, AccountId(0)).unwrap(),
(value + value2).unwrap()
);
// Rewind so that one block is dropped
rewind_to_height(&db_data, sapling_activation_height()).unwrap();
@ -376,7 +382,10 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should again reflect both received notes
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
assert_eq!(
get_balance(&db_data, AccountId(0)).unwrap(),
(value + value2).unwrap()
);
}
#[test]
@ -485,7 +494,10 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should reflect both received notes
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value2);
assert_eq!(
get_balance(&db_data, AccountId(0)).unwrap(),
(value + value2).unwrap()
);
}
#[test]
@ -543,6 +555,9 @@ mod tests {
scan_cached_blocks(&tests::network(), &db_cache, &mut db_write, None).unwrap();
// Account balance should equal the change
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value - value2);
assert_eq!(
get_balance(&db_data, AccountId(0)).unwrap(),
(value - value2).unwrap()
);
}
}

View File

@ -701,7 +701,7 @@ mod tests {
let note = Note {
g_d: change_addr.diversifier().g_d().unwrap(),
pk_d: *change_addr.pk_d(),
value: (in_value - value).into(),
value: (in_value - value).unwrap().into(),
rseed,
};
let encryptor = sapling_note_encryption::<_, Network>(

View File

@ -356,7 +356,10 @@ mod tests {
// Verified balance does not include the second note
let (_, anchor_height2) = (&db_data).get_target_and_anchor_heights().unwrap().unwrap();
assert_eq!(get_balance(&db_data, AccountId(0)).unwrap(), value + value);
assert_eq!(
get_balance(&db_data, AccountId(0)).unwrap(),
(value + value).unwrap()
);
assert_eq!(
get_balance_at(&db_data, AccountId(0), anchor_height2).unwrap(),
value

View File

@ -740,7 +740,7 @@ mod tests {
TzeOutPoint::new(tx_a.txid().0, 0),
tx_a.tze_outputs[0].clone(),
);
let value_xfr = value - DEFAULT_FEE;
let value_xfr = (value - DEFAULT_FEE).unwrap();
db_b.demo_transfer_to_close(prevout_a, value_xfr, preimage_1, h2)
.map_err(|e| format!("transfer failure: {:?}", e))
.unwrap();
@ -769,7 +769,7 @@ mod tests {
builder_c
.add_transparent_output(
&TransparentAddress::PublicKey([0; 20]),
value_xfr - DEFAULT_FEE,
(value_xfr - DEFAULT_FEE).unwrap(),
)
.unwrap();

View File

@ -31,7 +31,9 @@ hex = "0.4"
jubjub = "0.6"
lazy_static = "1"
log = "0.4"
proptest = { version = "0.10.1", optional = true }
nonempty = "0.6"
orchard = { git = "https://github.com/zcash/orchard", branch = "main" }
proptest = { version = "1.0.0", optional = true }
rand = "0.8"
rand_core = "0.6"
ripemd160 = { version = "0.9", optional = true }
@ -43,11 +45,16 @@ zcash_note_encryption = { version = "0.0", path = "../components/zcash_note_encr
# Temporary workaround for https://github.com/myrrlyn/funty/issues/3
funty = "=1.1.0"
[dependencies.pasta_curves]
git = "https://github.com/zcash/pasta_curves.git"
rev = "b55a6960dfafd7f767e2820ddf1adaa499322f98"
[dev-dependencies]
criterion = "0.3"
hex-literal = "0.3"
proptest = "0.10.1"
proptest = "1.0.0"
rand_xorshift = "0.3"
orchard = { git = "https://github.com/zcash/orchard", branch = "main", features = ["test-dependencies"] }
[features]
transparent-inputs = ["ripemd160", "secp256k1"]

View File

@ -251,18 +251,18 @@ impl TransparentInputs {
Ok(())
}
fn value_sum(&self) -> Amount {
fn value_sum(&self) -> Option<Amount> {
#[cfg(feature = "transparent-inputs")]
{
self.inputs
.iter()
.map(|input| input.coin.value)
.sum::<Amount>()
.sum::<Option<Amount>>()
}
#[cfg(not(feature = "transparent-inputs"))]
{
Amount::zero()
Some(Amount::zero())
}
}
@ -643,8 +643,18 @@ impl<'a, P: consensus::Parameters, R: RngCore> Builder<'a, P, R> {
//
// Valid change
let change = self.mtx.value_balance - self.fee + self.transparent_inputs.value_sum()
- self.mtx.vout.iter().map(|vo| vo.value).sum::<Amount>();
let change = self.mtx.value_balance - self.fee
+ self
.transparent_inputs
.value_sum()
.ok_or(Error::InvalidAmount)?
- self
.mtx
.vout
.iter()
.map(|vo| vo.value)
.sum::<Option<Amount>>()
.ok_or(Error::InvalidAmount)?;
#[cfg(feature = "zfuture")]
let change = change
@ -653,13 +663,17 @@ impl<'a, P: consensus::Parameters, R: RngCore> Builder<'a, P, R> {
.builders
.iter()
.map(|ein| ein.prevout.value)
.sum::<Amount>()
.sum::<Option<Amount>>()
.ok_or(Error::InvalidAmount)?
- self
.mtx
.tze_outputs
.iter()
.map(|tzo| tzo.value)
.sum::<Amount>();
.sum::<Option<Amount>>()
.ok_or(Error::InvalidAmount)?;
let change = change.ok_or(Error::InvalidAmount)?;
if change.is_negative() {
return Err(Error::ChangeIsNegative(change));
@ -1150,7 +1164,9 @@ mod tests {
let builder = Builder::new(TEST_NETWORK, H0);
assert_eq!(
builder.build(consensus::BranchId::Sapling, &MockTxProver),
Err(Error::ChangeIsNegative(Amount::zero() - DEFAULT_FEE))
Err(Error::ChangeIsNegative(
(Amount::zero() - DEFAULT_FEE).unwrap()
))
);
}
@ -1168,7 +1184,7 @@ mod tests {
assert_eq!(
builder.build(consensus::BranchId::Sapling, &MockTxProver),
Err(Error::ChangeIsNegative(
Amount::from_i64(-50000).unwrap() - DEFAULT_FEE
(Amount::from_i64(-50000).unwrap() - DEFAULT_FEE).unwrap()
))
);
}
@ -1186,7 +1202,7 @@ mod tests {
assert_eq!(
builder.build(consensus::BranchId::Sapling, &MockTxProver),
Err(Error::ChangeIsNegative(
Amount::from_i64(-50000).unwrap() - DEFAULT_FEE
(Amount::from_i64(-50000).unwrap() - DEFAULT_FEE).unwrap()
))
);
}

View File

@ -1,3 +1,4 @@
use std::convert::TryFrom;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Sub, SubAssign};
@ -101,6 +102,14 @@ impl Amount {
}
}
impl TryFrom<i64> for Amount {
type Error = ();
fn try_from(value: i64) -> Result<Self, ()> {
Amount::from_i64(value)
}
}
impl From<Amount> for i64 {
fn from(amount: Amount) -> i64 {
amount.0
@ -114,36 +123,52 @@ impl From<Amount> for u64 {
}
impl Add<Amount> for Amount {
type Output = Amount;
type Output = Option<Amount>;
fn add(self, rhs: Amount) -> Amount {
Amount::from_i64(self.0 + rhs.0).expect("addition should remain in range")
fn add(self, rhs: Amount) -> Option<Amount> {
Amount::from_i64(self.0 + rhs.0).ok()
}
}
impl Add<Amount> for Option<Amount> {
type Output = Self;
fn add(self, rhs: Amount) -> Option<Amount> {
self.and_then(|lhs| lhs + rhs)
}
}
impl AddAssign<Amount> for Amount {
fn add_assign(&mut self, rhs: Amount) {
*self = *self + rhs
*self = (*self + rhs).expect("Addition must produce a valid amount value.")
}
}
impl Sub<Amount> for Amount {
type Output = Amount;
type Output = Option<Amount>;
fn sub(self, rhs: Amount) -> Amount {
Amount::from_i64(self.0 - rhs.0).expect("subtraction should remain in range")
fn sub(self, rhs: Amount) -> Option<Amount> {
Amount::from_i64(self.0 - rhs.0).ok()
}
}
impl Sub<Amount> for Option<Amount> {
type Output = Self;
fn sub(self, rhs: Amount) -> Option<Amount> {
self.and_then(|lhs| lhs - rhs)
}
}
impl SubAssign<Amount> for Amount {
fn sub_assign(&mut self, rhs: Amount) {
*self = *self - rhs
*self = (*self - rhs).expect("Subtraction must produce a valid amount value.")
}
}
impl Sum for Amount {
fn sum<I: Iterator<Item = Amount>>(iter: I) -> Amount {
iter.fold(Amount::zero(), Add::add)
impl Sum<Amount> for Option<Amount> {
fn sum<I: Iterator<Item = Amount>>(iter: I) -> Self {
iter.fold(Some(Amount::zero()), |acc, a| acc? + a)
}
}
@ -153,11 +178,23 @@ pub mod testing {
use super::{Amount, MAX_MONEY};
prop_compose! {
pub fn arb_amount()(amt in -MAX_MONEY..MAX_MONEY) -> Amount {
Amount::from_i64(amt).unwrap()
}
}
prop_compose! {
pub fn arb_nonnegative_amount()(amt in 0i64..MAX_MONEY) -> Amount {
Amount::from_i64(amt).unwrap()
}
}
prop_compose! {
pub fn arb_positive_amount()(amt in 1i64..MAX_MONEY) -> Amount {
Amount::from_i64(amt).unwrap()
}
}
}
#[cfg(test)]
@ -213,10 +250,9 @@ mod tests {
}
#[test]
#[should_panic]
fn add_panics_on_overflow() {
fn add_overflow() {
let v = Amount(MAX_MONEY);
let _sum = v + Amount(1);
assert_eq!(v + Amount(1), None)
}
#[test]
@ -227,10 +263,9 @@ mod tests {
}
#[test]
#[should_panic]
fn sub_panics_on_underflow() {
fn sub_underflow() {
let v = Amount(-MAX_MONEY);
let _diff = v - Amount(1);
assert_eq!(v - Amount(1), None)
}
#[test]