diff --git a/zcash_client_backend/src/wallet.rs b/zcash_client_backend/src/wallet.rs index c120ba599..22df1cba1 100644 --- a/zcash_client_backend/src/wallet.rs +++ b/zcash_client_backend/src/wallet.rs @@ -25,6 +25,7 @@ pub struct WalletTx { pub struct WalletShieldedSpend { pub index: usize, pub nf: Vec, + pub account: usize, } /// A subset of an [`OutputDescription`] relevant to wallets and light clients. @@ -37,4 +38,5 @@ pub struct WalletShieldedOutput { pub account: usize, pub note: Note, pub to: PaymentAddress, + pub is_change: bool, } diff --git a/zcash_client_backend/src/welding_rig.rs b/zcash_client_backend/src/welding_rig.rs index 0b5a01f5f..b8310e35f 100644 --- a/zcash_client_backend/src/welding_rig.rs +++ b/zcash_client_backend/src/welding_rig.rs @@ -2,6 +2,7 @@ use ff::{PrimeField, PrimeFieldRepr}; use pairing::bls12_381::{Bls12, Fr, FrRepr}; +use std::collections::HashSet; use zcash_primitives::{ jubjub::{edwards, fs::Fs}, merkle_tree::{CommitmentTree, IncrementalWitness}, @@ -25,6 +26,7 @@ use crate::wallet::{WalletShieldedOutput, WalletShieldedSpend, WalletTx}; fn scan_output( (index, output): (usize, CompactOutput), ivks: &[Fs], + spent_from_accounts: &HashSet, tree: &mut CommitmentTree, existing_witnesses: &mut [&mut IncrementalWitness], new_witnesses: &mut [IncrementalWitness], @@ -64,6 +66,14 @@ fn scan_output( None => continue, }; + // A note is marked as "change" if the account that received it + // also spent notes in the same transaction. This will catch, + // for instance: + // - Change created by spending fractions of notes. + // - Notes created by consolidation transactions. + // - Notes sent from one account to itself. + let is_change = spent_from_accounts.contains(&account); + return Some(( WalletShieldedOutput { index, @@ -72,6 +82,7 @@ fn scan_output( account, note, to, + is_change, }, IncrementalWitness::from_tree(tree), )); @@ -89,7 +100,7 @@ fn scan_output( pub fn scan_block( block: CompactBlock, extfvks: &[ExtendedFullViewingKey], - nullifiers: &[&[u8]], + nullifiers: &[(&[u8], usize)], tree: &mut CommitmentTree, existing_witnesses: &mut [&mut IncrementalWitness], ) -> Vec<(WalletTx, Vec>)> { @@ -101,29 +112,45 @@ pub fn scan_block( let num_outputs = tx.outputs.len(); // Check for spent notes - let shielded_spends: Vec<_> = tx - .spends - .into_iter() - .enumerate() - .filter_map(|(index, spend)| { - if nullifiers.contains(&&spend.nf[..]) { - Some(WalletShieldedSpend { - index, - nf: spend.nf, - }) - } else { - None - } - }) - .collect(); + let shielded_spends: Vec<_> = + tx.spends + .into_iter() + .enumerate() + .filter_map(|(index, spend)| { + if let Some(account) = nullifiers.iter().find_map(|&(nf, acc)| { + if nf == &spend.nf[..] { + Some(acc) + } else { + None + } + }) { + Some(WalletShieldedSpend { + index, + nf: spend.nf, + account, + }) + } else { + None + } + }) + .collect(); + + // Collect the set of accounts that were spent from in this transaction + let spent_from_accounts: HashSet<_> = + shielded_spends.iter().map(|spend| spend.account).collect(); // Check for incoming notes while incrementing tree and witnesses let mut shielded_outputs = vec![]; let mut new_witnesses = vec![]; for to_scan in tx.outputs.into_iter().enumerate() { - if let Some((output, new_witness)) = - scan_output(to_scan, &ivks, tree, existing_witnesses, &mut new_witnesses) - { + if let Some((output, new_witness)) = scan_output( + to_scan, + &ivks, + &spent_from_accounts, + tree, + existing_witnesses, + &mut new_witnesses, + ) { shielded_outputs.push(output); new_witnesses.push(new_witness); } @@ -292,12 +319,13 @@ mod tests { let extsk = ExtendedSpendingKey::master(&[]); let extfvk = ExtendedFullViewingKey::from(&extsk); let nf = [7; 32]; + let account = 12; let cb = fake_compact_block(1, nf, extfvk, Amount::from_u64(5).unwrap()); assert_eq!(cb.vtx.len(), 2); let mut tree = CommitmentTree::new(); - let txs = scan_block(cb, &[], &[&nf], &mut tree, &mut []); + let txs = scan_block(cb, &[], &[(&nf, account)], &mut tree, &mut []); assert_eq!(txs.len(), 1); let (tx, new_witnesses) = &txs[0]; @@ -307,6 +335,7 @@ mod tests { assert_eq!(tx.shielded_outputs.len(), 0); assert_eq!(tx.shielded_spends[0].index, 0); assert_eq!(tx.shielded_spends[0].nf, nf); + assert_eq!(tx.shielded_spends[0].account, account); assert_eq!(new_witnesses.len(), 0); } }