From 7abd1324dedca54f69873d3d1886f768b776e55b Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Tue, 5 Sep 2023 11:58:24 -0600 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Daira Emma Hopwood Co-authored-by: Jack Grigg --- zcash_client_backend/src/data_api.rs | 24 ++-- zcash_client_sqlite/src/chain.rs | 115 ++++++++---------- zcash_client_sqlite/src/testing.rs | 115 +++++++++--------- zcash_client_sqlite/src/wallet.rs | 70 +++-------- .../init/migrations/wallet_summaries.rs | 2 +- zcash_client_sqlite/src/wallet/sapling.rs | 61 ++++------ zcash_client_sqlite/src/wallet/scanning.rs | 6 +- .../src/transaction/components/amount.rs | 16 +++ 8 files changed, 178 insertions(+), 231 deletions(-) diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index df923d756..d42b0da95 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -51,8 +51,8 @@ pub enum NullifierQuery { All, } -/// Balance information for a value within a single shielded pool in an account. -#[derive(Debug, Clone, Copy)] +/// Balance information for a value within a single pool in an account. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Balance { /// The value in the account that may currently be spent; it is possible to compute witnesses /// for all the notes that comprise this value, and all of this value is confirmed to the @@ -86,7 +86,7 @@ impl Balance { /// Balance information for a single account. The sum of this struct's fields is the total balance /// of the wallet. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct AccountBalance { /// The value of unspent Sapling outputs belonging to the account. pub sapling_balance: Balance, @@ -144,12 +144,12 @@ impl Ratio { /// can only be certain to be unspent in the case that [`Self::is_synced`] is true, and even in /// this circumstance it is possible that a newly created transaction could conflict with a /// not-yet-mined transaction in the mempool. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct WalletSummary { account_balances: BTreeMap, chain_tip_height: BlockHeight, fully_scanned_height: BlockHeight, - sapling_scan_progress: Option>, + scan_progress: Option>, } impl WalletSummary { @@ -158,13 +158,13 @@ impl WalletSummary { account_balances: BTreeMap, chain_tip_height: BlockHeight, fully_scanned_height: BlockHeight, - sapling_scan_progress: Option>, + scan_progress: Option>, ) -> Self { Self { account_balances, chain_tip_height, fully_scanned_height, - sapling_scan_progress, + scan_progress, } } @@ -178,20 +178,20 @@ impl WalletSummary { self.chain_tip_height } - /// Returns the height below which all blocks wallet have been scanned, ignoring blocks below - /// the wallet birthday. + /// Returns the height below which all blocks have been scanned by the wallet, ignoring blocks + /// below the wallet birthday. pub fn fully_scanned_height(&self) -> BlockHeight { self.fully_scanned_height } - /// Returns the progress of scanning Sapling outputs, in terms of the ratio between notes + /// Returns the progress of scanning shielded outputs, in terms of the ratio between notes /// scanned and the total number of notes added to the chain since the wallet birthday. /// /// This ratio should only be used to compute progress percentages, and the numerator and /// denominator should not be treated as authoritative note counts. Returns `None` if the /// wallet is unable to determine the size of the note commitment tree. - pub fn sapling_scan_progress(&self) -> Option> { - self.sapling_scan_progress + pub fn scan_progress(&self) -> Option> { + self.scan_progress } /// Returns whether or not wallet scanning is complete. diff --git a/zcash_client_sqlite/src/chain.rs b/zcash_client_sqlite/src/chain.rs index a69bcce22..950d2e2b3 100644 --- a/zcash_client_sqlite/src/chain.rs +++ b/zcash_client_sqlite/src/chain.rs @@ -326,7 +326,10 @@ mod tests { use zcash_primitives::{ block::BlockHash, - transaction::{components::Amount, fees::zip317::FeeRule}, + transaction::{ + components::{amount::NonNegativeAmount, Amount}, + fees::zip317::FeeRule, + }, zip32::ExtendedSpendingKey, }; @@ -344,7 +347,7 @@ mod tests { use crate::{ testing::{AddressType, TestBuilder}, - wallet::{get_balance, truncate_to_height}, + wallet::truncate_to_height, AccountId, }; @@ -441,24 +444,21 @@ mod tests { let dfvk = st.test_account_sapling().unwrap(); - // Account balance should be zero - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::zero() - ); + // Wallet summary is not yet available + assert_eq!(st.get_wallet_summary(0), None); // Create fake CompactBlocks sending value to the address - let value = Amount::from_u64(5).unwrap(); - let value2 = Amount::from_u64(7).unwrap(); - let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); - st.generate_next_block(&dfvk, AddressType::DefaultExternal, value2); + let value = NonNegativeAmount::from_u64(5).unwrap(); + let value2 = NonNegativeAmount::from_u64(7).unwrap(); + let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); + st.generate_next_block(&dfvk, AddressType::DefaultExternal, value2.into()); // Scan the cache st.scan_cached_blocks(h, 2); // Account balance should reflect both received notes assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), + st.get_total_balance(AccountId::from(0)), (value + value2).unwrap() ); @@ -469,7 +469,7 @@ mod tests { // Account balance should be unaltered assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), + st.get_total_balance(AccountId::from(0)), (value + value2).unwrap() ); @@ -479,17 +479,14 @@ mod tests { .unwrap(); // Account balance should only contain the first received note - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - value - ); + assert_eq!(st.get_total_balance(AccountId::from(0)), value); // Scan the cache again st.scan_cached_blocks(h, 2); // Account balance should again reflect both received notes assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), + st.get_total_balance(AccountId::from(0)), (value + value2).unwrap() ); } @@ -505,17 +502,14 @@ mod tests { let dfvk = st.test_account_sapling().unwrap(); // Create a block with height SAPLING_ACTIVATION_HEIGHT - let value = Amount::from_u64(50000).unwrap(); - let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + let value = NonNegativeAmount::from_u64(50000).unwrap(); + let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); st.scan_cached_blocks(h1, 1); - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - value - ); + assert_eq!(st.get_total_balance(AccountId::from(0)), value); // Create blocks to reach SAPLING_ACTIVATION_HEIGHT + 2 - let (h2, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); - let (h3, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + let (h2, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); + let (h3, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); // Scan the later block first st.scan_cached_blocks(h3, 1); @@ -523,8 +517,8 @@ mod tests { // Now scan the block of height SAPLING_ACTIVATION_HEIGHT + 1 st.scan_cached_blocks(h2, 1); assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::from_u64(150_000).unwrap() + st.get_total_balance(AccountId::from(0)), + NonNegativeAmount::from_u64(150_000).unwrap() ); // We can spend the received notes @@ -562,35 +556,29 @@ mod tests { let dfvk = st.test_account_sapling().unwrap(); - // Account balance should be zero - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::zero() - ); + // Wallet summary is not yet available + assert_eq!(st.get_wallet_summary(0), None); // Create a fake CompactBlock sending value to the address - let value = Amount::from_u64(5).unwrap(); - let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + let value = NonNegativeAmount::from_u64(5).unwrap(); + let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); // Scan the cache st.scan_cached_blocks(h1, 1); // Account balance should reflect the received note - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - value - ); + assert_eq!(st.get_total_balance(AccountId::from(0)), value); // Create a second fake CompactBlock sending more value to the address - let value2 = Amount::from_u64(7).unwrap(); - let (h2, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value2); + let value2 = NonNegativeAmount::from_u64(7).unwrap(); + let (h2, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value2.into()); // Scan the cache again st.scan_cached_blocks(h2, 1); // Account balance should reflect both received notes assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), + st.get_total_balance(AccountId::from(0)), (value + value2).unwrap() ); } @@ -603,38 +591,33 @@ mod tests { .build(); let dfvk = st.test_account_sapling().unwrap(); - // Account balance should be zero - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::zero() - ); + // Wallet summary is not yet available + assert_eq!(st.get_wallet_summary(0), None); // Create a fake CompactBlock sending value to the address - let value = Amount::from_u64(5).unwrap(); + let value = NonNegativeAmount::from_u64(5).unwrap(); let (received_height, _, nf) = - st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); // Scan the cache st.scan_cached_blocks(received_height, 1); // Account balance should reflect the received note - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - value - ); + assert_eq!(st.get_total_balance(AccountId::from(0)), value); // Create a second fake CompactBlock spending value from the address let extsk2 = ExtendedSpendingKey::master(&[0]); let to2 = extsk2.default_address().1; - let value2 = Amount::from_u64(2).unwrap(); - let (spent_height, _) = st.generate_next_block_spending(&dfvk, (nf, value), to2, value2); + let value2 = NonNegativeAmount::from_u64(2).unwrap(); + let (spent_height, _) = + st.generate_next_block_spending(&dfvk, (nf, value.into()), to2, value2.into()); // Scan the cache again st.scan_cached_blocks(spent_height, 1); // Account balance should equal the change assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), + st.get_total_balance(AccountId::from(0)), (value - value2).unwrap() ); } @@ -648,29 +631,27 @@ mod tests { let dfvk = st.test_account_sapling().unwrap(); - // Account balance should be zero - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::zero() - ); + // Wallet summary is not yet available + assert_eq!(st.get_wallet_summary(0), None); // Create a fake CompactBlock sending value to the address - let value = Amount::from_u64(5).unwrap(); + let value = NonNegativeAmount::from_u64(5).unwrap(); let (received_height, _, nf) = - st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); // Create a second fake CompactBlock spending value from the address let extsk2 = ExtendedSpendingKey::master(&[0]); let to2 = extsk2.default_address().1; - let value2 = Amount::from_u64(2).unwrap(); - let (spent_height, _) = st.generate_next_block_spending(&dfvk, (nf, value), to2, value2); + let value2 = NonNegativeAmount::from_u64(2).unwrap(); + let (spent_height, _) = + st.generate_next_block_spending(&dfvk, (nf, value.into()), to2, value2.into()); // Scan the spending block first. st.scan_cached_blocks(spent_height, 1); // Account balance should equal the change assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), + st.get_total_balance(AccountId::from(0)), (value - value2).unwrap() ); @@ -679,7 +660,7 @@ mod tests { // Account balance should be the same. assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), + st.get_total_balance(AccountId::from(0)), (value - value2).unwrap() ); } diff --git a/zcash_client_sqlite/src/testing.rs b/zcash_client_sqlite/src/testing.rs index 7fbb46e19..0034ff4bb 100644 --- a/zcash_client_sqlite/src/testing.rs +++ b/zcash_client_sqlite/src/testing.rs @@ -14,6 +14,7 @@ use tempfile::NamedTempFile; #[cfg(feature = "unstable")] use tempfile::TempDir; +use zcash_client_backend::data_api::AccountBalance; #[allow(deprecated)] use zcash_client_backend::{ address::RecipientAddress, @@ -290,66 +291,6 @@ where limit, ) } - - pub(crate) fn get_total_balance(&self, account: AccountId) -> NonNegativeAmount { - get_wallet_summary(&self.wallet().conn, 0, &SubtreeScanProgress) - .unwrap() - .unwrap() - .account_balances() - .get(&account) - .unwrap() - .total() - } - - pub(crate) fn get_spendable_balance( - &self, - account: AccountId, - min_confirmations: u32, - ) -> NonNegativeAmount { - let binding = - get_wallet_summary(&self.wallet().conn, min_confirmations, &SubtreeScanProgress) - .unwrap() - .unwrap(); - let balance = binding.account_balances().get(&account).unwrap(); - - balance.sapling_balance.spendable_value - } - - pub(crate) fn get_pending_shielded_balance( - &self, - account: AccountId, - min_confirmations: u32, - ) -> NonNegativeAmount { - let binding = - get_wallet_summary(&self.wallet().conn, min_confirmations, &SubtreeScanProgress) - .unwrap() - .unwrap(); - let balance = binding.account_balances().get(&account).unwrap(); - - (balance.sapling_balance.value_pending_spendability - + balance.sapling_balance.change_pending_confirmation) - .unwrap() - } - - pub(crate) fn get_pending_change( - &self, - account: AccountId, - min_confirmations: u32, - ) -> NonNegativeAmount { - let binding = - get_wallet_summary(&self.wallet().conn, min_confirmations, &SubtreeScanProgress) - .unwrap() - .unwrap(); - let balance = binding.account_balances().get(&account).unwrap(); - - balance.sapling_balance.change_pending_confirmation - } - - pub(crate) fn get_wallet_summary(&self, min_confirmations: u32) -> WalletSummary { - get_wallet_summary(&self.wallet().conn, min_confirmations, &SubtreeScanProgress) - .unwrap() - .unwrap() - } } impl TestState { @@ -594,6 +535,60 @@ impl TestState { min_confirmations, ) } + + fn with_account_balance T>( + &self, + account: AccountId, + min_confirmations: u32, + f: F, + ) -> T { + let binding = + get_wallet_summary(&self.wallet().conn, min_confirmations, &SubtreeScanProgress) + .unwrap() + .unwrap(); + + f(binding.account_balances().get(&account).unwrap()) + } + + pub(crate) fn get_total_balance(&self, account: AccountId) -> NonNegativeAmount { + self.with_account_balance(account, 0, |balance| balance.total()) + } + + pub(crate) fn get_spendable_balance( + &self, + account: AccountId, + min_confirmations: u32, + ) -> NonNegativeAmount { + self.with_account_balance(account, min_confirmations, |balance| { + balance.sapling_balance.spendable_value + }) + } + + pub(crate) fn get_pending_shielded_balance( + &self, + account: AccountId, + min_confirmations: u32, + ) -> NonNegativeAmount { + self.with_account_balance(account, min_confirmations, |balance| { + balance.sapling_balance.value_pending_spendability + + balance.sapling_balance.change_pending_confirmation + }) + .unwrap() + } + + pub(crate) fn get_pending_change( + &self, + account: AccountId, + min_confirmations: u32, + ) -> NonNegativeAmount { + self.with_account_balance(account, min_confirmations, |balance| { + balance.sapling_balance.change_pending_confirmation + }) + } + + pub(crate) fn get_wallet_summary(&self, min_confirmations: u32) -> Option { + get_wallet_summary(&self.wallet().conn, min_confirmations, &SubtreeScanProgress).unwrap() + } } #[allow(dead_code)] diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index 41b775102..d929bbb46 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -493,34 +493,6 @@ pub(crate) fn is_valid_account_extfvk( }) } -/// Returns the balance for the account, including all mined unspent notes that we know -/// about. -/// -/// WARNING: This balance is potentially unreliable, as mined notes may become unmined due -/// to chain reorgs. You should generally not show this balance to users without some -/// caveat. Use [`get_balance_at`] where you need a more reliable indication of the -/// wallet balance. -#[cfg(test)] -pub(crate) fn get_balance( - conn: &rusqlite::Connection, - account: AccountId, -) -> Result { - let balance = conn.query_row( - "SELECT SUM(value) FROM sapling_received_notes - INNER JOIN transactions ON transactions.id_tx = sapling_received_notes.tx - WHERE account = ? AND spent IS NULL AND transactions.block IS NOT NULL", - [u32::from(account)], - |row| row.get(0).or(Ok(0)), - )?; - - match Amount::from_i64(balance) { - Ok(amount) if !amount.is_negative() => Ok(amount), - _ => Err(SqliteClientError::CorruptedData( - "Sum of values in sapling_received_notes is out of range".to_string(), - )), - } -} - pub(crate) trait ScanProgress { fn sapling_scan_progress( &self, @@ -629,7 +601,7 @@ pub(crate) fn get_wallet_summary( let fully_scanned_height = block_fully_scanned(conn)?.map_or(birthday_height - 1, |m| m.block_height()); - let summary_height = chain_tip_height + 1 - min_confirmations; + let summary_height = (chain_tip_height + 1).saturating_sub(std::cmp::max(min_confirmations, 1)); let sapling_scan_progress = progress.sapling_scan_progress( conn, @@ -638,10 +610,10 @@ pub(crate) fn get_wallet_summary( chain_tip_height, )?; - // If the shard containing the anchor is contains any unscanned ranges below the summary - // height, none of our balance is currently spendable. + // If the shard containing the summary height contains any unscanned ranges that start below or + // including that height, none of our balance is currently spendable. let any_spendable = conn.query_row( - "SELECT EXISTS( + "SELECT NOT EXISTS( SELECT 1 FROM v_sapling_shard_unscanned_ranges WHERE :summary_height BETWEEN subtree_start_height @@ -649,11 +621,11 @@ pub(crate) fn get_wallet_summary( AND block_range_start <= :summary_height )", named_params![":summary_height": u32::from(summary_height)], - |row| row.get::<_, bool>(0).map(|b| !b), + |row| row.get::<_, bool>(0), )?; let mut stmt_select_notes = conn.prepare_cached( - "SELECT n.account, n.value, n.is_change, scan_state.max_priority, t.block, t.expiry_height + "SELECT n.account, n.value, n.is_change, scan_state.max_priority, t.block FROM sapling_received_notes n JOIN transactions t ON t.id_tx = n.tx LEFT OUTER JOIN v_sapling_shards_scan_state scan_state @@ -696,9 +668,7 @@ pub(crate) fn get_wallet_summary( }, )?; - let received_height = row - .get::<_, Option>(4) - .map(|opt| opt.map(BlockHeight::from))?; + let received_height = row.get::<_, Option>(4)?.map(BlockHeight::from); let is_spendable = any_spendable && received_height.iter().any(|h| h <= &summary_height) @@ -763,7 +733,10 @@ pub(crate) fn get_wallet_summary( account_balances .entry(account) - .and_modify(|bal| bal.unshielded = value) + .and_modify(|bal| { + bal.unshielded = + (bal.unshielded + value).expect("Unshielded value cannot overflow") + }) .or_insert(AccountBalance { sapling_balance: Balance::ZERO, unshielded: value, @@ -1924,8 +1897,6 @@ mod tests { use crate::{testing::TestBuilder, AccountId}; - use super::get_balance; - #[cfg(feature = "transparent-inputs")] use { incrementalmerkletree::frontier::Frontier, @@ -1945,11 +1916,8 @@ mod tests { .with_test_account(AccountBirthday::from_sapling_activation) .build(); - // The account should be empty - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::zero() - ); + // The account should have no summary information + assert_eq!(st.get_wallet_summary(0), None); // We can't get an anchor height, as we have not scanned any blocks. assert_eq!( @@ -1959,15 +1927,17 @@ mod tests { None ); - // An invalid account has zero balance + // The default address is set for the test account + assert_matches!( + st.wallet().get_current_address(AccountId::from(0)), + Ok(Some(_)) + ); + + // No default address is set for an un-initialized account assert_matches!( st.wallet().get_current_address(AccountId::from(1)), Ok(None) ); - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::zero() - ); } #[test] diff --git a/zcash_client_sqlite/src/wallet/init/migrations/wallet_summaries.rs b/zcash_client_sqlite/src/wallet/init/migrations/wallet_summaries.rs index 5807bae65..2be9ca15a 100644 --- a/zcash_client_sqlite/src/wallet/init/migrations/wallet_summaries.rs +++ b/zcash_client_sqlite/src/wallet/init/migrations/wallet_summaries.rs @@ -43,7 +43,7 @@ impl RusqliteMigration for Migration { )?; transaction.execute_batch( - // set the number of outputs everywhere that we have sequential Sapling blocks + // set the number of outputs everywhere that we have sequential blocks "CREATE TEMPORARY TABLE block_deltas AS SELECT cur.height AS height, diff --git a/zcash_client_sqlite/src/wallet/sapling.rs b/zcash_client_sqlite/src/wallet/sapling.rs index 4e1f3702a..d2a5271c1 100644 --- a/zcash_client_sqlite/src/wallet/sapling.rs +++ b/zcash_client_sqlite/src/wallet/sapling.rs @@ -465,7 +465,7 @@ pub(crate) mod tests { use crate::{ error::SqliteClientError, testing::{AddressType, BlockCache, TestBuilder, TestState}, - wallet::{commitment_tree, get_balance}, + wallet::commitment_tree, AccountId, NoteId, ReceivedNoteId, }; @@ -495,19 +495,12 @@ pub(crate) mod tests { let dfvk = st.test_account_sapling().unwrap(); // Add funds to the wallet in a single note - let value = Amount::from_u64(60000).unwrap(); - let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + let value = NonNegativeAmount::from_u64(60000).unwrap(); + let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); st.scan_cached_blocks(h, 1); // Verified balance matches total balance - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - value - ); - assert_eq!( - st.get_total_balance(account), - NonNegativeAmount::try_from(value).unwrap() - ); + assert_eq!(st.get_total_balance(account), value); let to_extsk = ExtendedSpendingKey::master(&[]); let to: RecipientAddress = to_extsk.default_address().1.into(); @@ -664,11 +657,8 @@ pub(crate) mod tests { let dfvk = st.test_account_sapling().unwrap(); let to = dfvk.default_address().1.into(); - // Account balance should be zero - assert_eq!( - get_balance(&st.wallet().conn, AccountId::from(0)).unwrap(), - Amount::zero() - ); + // Wallet summary is not yet available + assert_eq!(st.get_wallet_summary(0), None); // We cannot do anything if we aren't synchronised assert_matches!( @@ -700,10 +690,6 @@ pub(crate) mod tests { st.scan_cached_blocks(h1, 1); // Verified balance matches total balance - assert_eq!( - get_balance(&st.wallet().conn, account).unwrap(), - value.into() - ); assert_eq!(st.get_total_balance(account), value); // Value is considered pending @@ -711,7 +697,10 @@ pub(crate) mod tests { // Wallet is fully scanned let summary = st.get_wallet_summary(1); - assert_eq!(summary.sapling_scan_progress(), Some(Ratio::new(1, 1))); + assert_eq!( + summary.and_then(|s| s.scan_progress()), + Some(Ratio::new(1, 1)) + ); // Add more funds to the wallet in a second note let (h2, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); @@ -719,17 +708,16 @@ pub(crate) mod tests { // Verified balance does not include the second note let total = (value + value).unwrap(); - assert_eq!( - get_balance(&st.wallet().conn, account).unwrap(), - total.into() - ); assert_eq!(st.get_spendable_balance(account, 2), value); assert_eq!(st.get_pending_shielded_balance(account, 2), value); assert_eq!(st.get_total_balance(account), total); // Wallet is still fully scanned let summary = st.get_wallet_summary(1); - assert_eq!(summary.sapling_scan_progress(), Some(Ratio::new(2, 2))); + assert_eq!( + summary.and_then(|s| s.scan_progress()), + Some(Ratio::new(2, 2)) + ); // Spend fails because there are insufficient verified notes let extsk2 = ExtendedSpendingKey::master(&[]); @@ -805,10 +793,10 @@ pub(crate) mod tests { let dfvk = st.test_account_sapling().unwrap(); // Add funds to the wallet in a single note - let value = Amount::from_u64(50000).unwrap(); - let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + let value = NonNegativeAmount::from_u64(50000).unwrap(); + let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); st.scan_cached_blocks(h1, 1); - assert_eq!(get_balance(&st.wallet().conn, account).unwrap(), value); + assert_eq!(st.get_total_balance(account), value); // Send some of the funds to another address let extsk2 = ExtendedSpendingKey::master(&[]); @@ -848,7 +836,7 @@ pub(crate) mod tests { st.generate_next_block( &ExtendedSpendingKey::master(&[i as u8]).to_diversifiable_full_viewing_key(), AddressType::DefaultExternal, - value, + value.into(), ); } st.scan_cached_blocks(h1 + 1, 41); @@ -874,7 +862,7 @@ pub(crate) mod tests { let (h43, _, _) = st.generate_next_block( &ExtendedSpendingKey::master(&[42]).to_diversifiable_full_viewing_key(), AddressType::DefaultExternal, - value, + value.into(), ); st.scan_cached_blocks(h43, 1); @@ -901,10 +889,10 @@ pub(crate) mod tests { let dfvk = st.test_account_sapling().unwrap(); // Add funds to the wallet in a single note - let value = Amount::from_u64(50000).unwrap(); - let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); + let value = NonNegativeAmount::from_u64(50000).unwrap(); + let (h1, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value.into()); st.scan_cached_blocks(h1, 1); - assert_eq!(get_balance(&st.wallet().conn, account).unwrap(), value); + assert_eq!(st.get_total_balance(account), value); let extsk2 = ExtendedSpendingKey::master(&[]); let addr2 = extsk2.default_address().1; @@ -975,7 +963,7 @@ pub(crate) mod tests { st.generate_next_block( &ExtendedSpendingKey::master(&[i as u8]).to_diversifiable_full_viewing_key(), AddressType::DefaultExternal, - value, + value.into(), ); } st.scan_cached_blocks(h1 + 1, 42); @@ -1004,7 +992,6 @@ pub(crate) mod tests { st.scan_cached_blocks(h, 1); // Verified balance matches total balance - assert_eq!(get_balance(&st.wallet().conn, account).unwrap(), value); assert_eq!( st.get_total_balance(account), NonNegativeAmount::try_from(value).unwrap() @@ -1040,7 +1027,6 @@ pub(crate) mod tests { st.scan_cached_blocks(h, 1); // Verified balance matches total balance - assert_eq!(get_balance(&st.wallet().conn, account).unwrap(), value); assert_eq!( st.get_total_balance(account), NonNegativeAmount::try_from(value).unwrap() @@ -1096,7 +1082,6 @@ pub(crate) mod tests { // Verified balance matches total balance let total = Amount::from_u64(60000).unwrap(); - assert_eq!(get_balance(&st.wallet().conn, account).unwrap(), total); assert_eq!( st.get_total_balance(account), NonNegativeAmount::try_from(total).unwrap() diff --git a/zcash_client_sqlite/src/wallet/scanning.rs b/zcash_client_sqlite/src/wallet/scanning.rs index fba6cd4d9..47066e162 100644 --- a/zcash_client_sqlite/src/wallet/scanning.rs +++ b/zcash_client_sqlite/src/wallet/scanning.rs @@ -1692,7 +1692,7 @@ mod tests { // We have scan ranges and a subtree, but have scanned no blocks. let summary = st.get_wallet_summary(1); - assert_eq!(summary.sapling_scan_progress(), None); + assert_eq!(summary.and_then(|s| s.scan_progress()), None); // Set up prior chain state. This simulates us having imported a wallet // with a birthday 520 blocks below the chain tip. @@ -1732,7 +1732,7 @@ mod tests { // wallet birthday but before the end of the shard. let summary = st.get_wallet_summary(1); assert_eq!( - summary.sapling_scan_progress(), + summary.and_then(|s| s.scan_progress()), Some(Ratio::new(1, 0x1 << SAPLING_SHARD_HEIGHT)) ); @@ -1776,7 +1776,7 @@ mod tests { // shards worth of notes to scan. let summary = st.get_wallet_summary(1); assert_eq!( - summary.sapling_scan_progress(), + summary.and_then(|s| s.scan_progress()), Some(Ratio::new(1, 0x1 << (SAPLING_SHARD_HEIGHT + 1))) ); } diff --git a/zcash_primitives/src/transaction/components/amount.rs b/zcash_primitives/src/transaction/components/amount.rs index fc7e53a6e..dbb3e85d7 100644 --- a/zcash_primitives/src/transaction/components/amount.rs +++ b/zcash_primitives/src/transaction/components/amount.rs @@ -290,6 +290,22 @@ impl Add for Option { } } +impl Sub for NonNegativeAmount { + type Output = Option; + + fn sub(self, rhs: NonNegativeAmount) -> Option { + (self.0 - rhs.0).and_then(|amt| NonNegativeAmount::try_from(amt).ok()) + } +} + +impl Sub for Option { + type Output = Self; + + fn sub(self, rhs: NonNegativeAmount) -> Option { + self.and_then(|lhs| lhs - rhs) + } +} + /// A type for balance violations in amount addition and subtraction /// (overflow and underflow of allowed ranges) #[derive(Copy, Clone, Debug, PartialEq, Eq)]