From 7acd47eae84421c92da5d2a43b908bc697041c3a Mon Sep 17 00:00:00 2001
From: Kris Nuttycombe
Date: Thu, 18 Aug 2022 16:40:27 -0600
Subject: [PATCH 1/3] Copy parallel batch decryption data types from
 zcash/zcash

This is an unmodified copy of
https://github.com/zcash/zcash/blob/7d1e14ac3d876f2f019e22c1adbf15c5949e9f84/src/rust/src/wallet_scanner.rs

Co-authored-by: str4d
---
 zcash_client_backend/src/scan.rs | 526 +++++++++++++++++++++++++++++++
 1 file changed, 526 insertions(+)
 create mode 100644 zcash_client_backend/src/scan.rs

diff --git a/zcash_client_backend/src/scan.rs b/zcash_client_backend/src/scan.rs
new file mode 100644
index 000000000..390c34614
--- /dev/null
+++ b/zcash_client_backend/src/scan.rs
@@ -0,0 +1,526 @@
+use core::fmt;
+use std::collections::HashMap;
+use std::io;
+use std::mem;
+
+use crossbeam_channel as channel;
+use group::GroupEncoding;
+use zcash_note_encryption::{batch, BatchDomain, Domain, ShieldedOutput, ENC_CIPHERTEXT_SIZE};
+use zcash_primitives::{
+    block::BlockHash,
+    consensus, constants,
+    sapling::{self, note_encryption::SaplingDomain},
+    transaction::{
+        components::{sapling::GrothProofBytes, OutputDescription},
+        Transaction, TxId,
+    },
+};
+
+#[cxx::bridge]
+mod ffi {
+    #[namespace = "wallet"]
+    struct SaplingDecryptionResult {
+        txid: [u8; 32],
+        output: u32,
+        ivk: [u8; 32],
+        diversifier: [u8; 11],
+        pk_d: [u8; 32],
+    }
+
+    #[namespace = "wallet"]
+    extern "Rust" {
+        type Network;
+        type BatchScanner;
+        type BatchResult;
+
+        fn network(
+            network: &str,
+            overwinter: i32,
+            sapling: i32,
+            blossom: i32,
+            heartwood: i32,
+            canopy: i32,
+            nu5: i32,
+        ) -> Result<Box<Network>>;
+
+        fn init_batch_scanner(
+            network: &Network,
+            sapling_ivks: &[[u8; 32]],
+        ) -> Result<Box<BatchScanner>>;
+        fn add_transaction(
+            self: &mut BatchScanner,
+            block_tag: [u8; 32],
+            tx_bytes: &[u8],
+            height: u32,
+        ) -> Result<()>;
+        fn flush(self: &mut BatchScanner);
+        fn collect_results(
+            self: &mut BatchScanner,
+            block_tag: [u8; 32],
+            txid: [u8; 32],
+        ) -> Box<BatchResult>;
+
+        fn get_sapling(self: &BatchResult) -> Vec<SaplingDecryptionResult>;
+    }
+}
+
+/// The minimum number of outputs to trial decrypt in a batch.
+///
+/// TODO: Tune this.
+const BATCH_SIZE_THRESHOLD: usize = 20;
+
+/// Chain parameters for the networks supported by `zcashd`.
+#[derive(Clone, Copy)]
+pub enum Network {
+    Consensus(consensus::Network),
+    RegTest {
+        overwinter: Option<consensus::BlockHeight>,
+        sapling: Option<consensus::BlockHeight>,
+        blossom: Option<consensus::BlockHeight>,
+        heartwood: Option<consensus::BlockHeight>,
+        canopy: Option<consensus::BlockHeight>,
+        nu5: Option<consensus::BlockHeight>,
+    },
+}
+
+/// Constructs a `Network` from the given network string.
+///
+/// The heights are only for constructing a regtest network, and are ignored otherwise.
+fn network(
+    network: &str,
+    overwinter: i32,
+    sapling: i32,
+    blossom: i32,
+    heartwood: i32,
+    canopy: i32,
+    nu5: i32,
+) -> Result<Box<Network>, &'static str> {
+    let i32_to_optional_height = |n: i32| {
+        if n.is_negative() {
+            None
+        } else {
+            Some(consensus::BlockHeight::from_u32(n.unsigned_abs()))
+        }
+    };
+
+    let params = match network {
+        "main" => Network::Consensus(consensus::Network::MainNetwork),
+        "test" => Network::Consensus(consensus::Network::TestNetwork),
+        "regtest" => Network::RegTest {
+            overwinter: i32_to_optional_height(overwinter),
+            sapling: i32_to_optional_height(sapling),
+            blossom: i32_to_optional_height(blossom),
+            heartwood: i32_to_optional_height(heartwood),
+            canopy: i32_to_optional_height(canopy),
+            nu5: i32_to_optional_height(nu5),
+        },
+        _ => return Err("Unsupported network kind"),
+    };
+
+    Ok(Box::new(params))
+}
+
+impl consensus::Parameters for Network {
+    fn activation_height(&self, nu: consensus::NetworkUpgrade) -> Option<consensus::BlockHeight> {
+        match self {
+            Self::Consensus(params) => params.activation_height(nu),
+            Self::RegTest {
+                overwinter,
+                sapling,
+                blossom,
+                heartwood,
+                canopy,
+                nu5,
+            } => match nu {
+                consensus::NetworkUpgrade::Overwinter => *overwinter,
+                consensus::NetworkUpgrade::Sapling => *sapling,
+                consensus::NetworkUpgrade::Blossom => *blossom,
+                consensus::NetworkUpgrade::Heartwood => *heartwood,
+                consensus::NetworkUpgrade::Canopy => *canopy,
+                consensus::NetworkUpgrade::Nu5 => *nu5,
+            },
+        }
+    }
+
+    fn coin_type(&self) -> u32 {
+        match self {
+            Self::Consensus(params) => params.coin_type(),
+            Self::RegTest { .. } => constants::regtest::COIN_TYPE,
+        }
+    }
+
+    fn hrp_sapling_extended_spending_key(&self) -> &str {
+        match self {
+            Self::Consensus(params) => params.hrp_sapling_extended_spending_key(),
+            Self::RegTest { .. } => constants::regtest::HRP_SAPLING_EXTENDED_SPENDING_KEY,
+        }
+    }
+
+    fn hrp_sapling_extended_full_viewing_key(&self) -> &str {
+        match self {
+            Self::Consensus(params) => params.hrp_sapling_extended_full_viewing_key(),
+            Self::RegTest { .. } => constants::regtest::HRP_SAPLING_EXTENDED_FULL_VIEWING_KEY,
+        }
+    }
+
+    fn hrp_sapling_payment_address(&self) -> &str {
+        match self {
+            Self::Consensus(params) => params.hrp_sapling_payment_address(),
+            Self::RegTest { .. } => constants::regtest::HRP_SAPLING_PAYMENT_ADDRESS,
+        }
+    }
+
+    fn b58_pubkey_address_prefix(&self) -> [u8; 2] {
+        match self {
+            Self::Consensus(params) => params.b58_pubkey_address_prefix(),
+            Self::RegTest { .. } => constants::regtest::B58_PUBKEY_ADDRESS_PREFIX,
+        }
+    }
+
+    fn b58_script_address_prefix(&self) -> [u8; 2] {
+        match self {
+            Self::Consensus(params) => params.b58_script_address_prefix(),
+            Self::RegTest { .. } => constants::regtest::B58_SCRIPT_ADDRESS_PREFIX,
+        }
+    }
+}
+
+/// A decrypted note.
+struct DecryptedNote<D: Domain> {
+    /// The incoming viewing key used to decrypt the note.
+    ivk: D::IncomingViewingKey,
+    /// The recipient of the note.
+    recipient: D::Recipient,
+    /// The note!
+    note: D::Note,
+    /// The memo sent with the note.
+    memo: D::Memo,
+}
+
+impl<D: Domain> fmt::Debug for DecryptedNote<D>
+where
+    D::IncomingViewingKey: fmt::Debug,
+    D::Recipient: fmt::Debug,
+    D::Note: fmt::Debug,
+    D::Memo: fmt::Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("DecryptedNote")
+            .field("ivk", &self.ivk)
+            .field("recipient", &self.recipient)
+            .field("note", &self.note)
+            .field("memo", &self.memo)
+            .finish()
+    }
+}
+
+/// A value correlated with an output index.
+struct OutputIndex<V> {
+    /// The index of the output within the corresponding shielded bundle.
+    output_index: usize,
+    /// The value for the output index.
+    value: V,
+}
+
+type OutputReplier<D> = OutputIndex<channel::Sender<OutputIndex<Option<DecryptedNote<D>>>>>;
+
+/// A batch of outputs to trial decrypt.
+struct Batch<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> {
+    ivks: Vec<D::IncomingViewingKey>,
+    outputs: Vec<(D, Output)>,
+    repliers: Vec<OutputReplier<D>>,
+}
+
+impl<D, Output> Batch<D, Output>
+where
+    D: BatchDomain,
+    Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>,
+    D::IncomingViewingKey: Clone,
+{
+    /// Constructs a new batch.
+    fn new(ivks: Vec<D::IncomingViewingKey>) -> Self {
+        Self {
+            ivks,
+            outputs: vec![],
+            repliers: vec![],
+        }
+    }
+
+    /// Returns `true` if the batch is currently empty.
+    fn is_empty(&self) -> bool {
+        self.outputs.is_empty()
+    }
+
+    /// Runs the batch of trial decryptions, and reports the results.
+    fn run(self) {
+        assert_eq!(self.outputs.len(), self.repliers.len());
+        let decrypted = batch::try_note_decryption(&self.ivks, &self.outputs);
+        for (decrypted_note, (ivk, replier)) in decrypted.into_iter().zip(
+            // The output of `batch::try_note_decryption` corresponds to the stream of
+            // trial decryptions:
+            //     (ivk0, out0), (ivk0, out1), ..., (ivk0, outN), (ivk1, out0), ...
+            // So we can use the position in the stream to figure out which output was
+            // decrypted and which ivk decrypted it.
+            self.ivks
+                .iter()
+                .flat_map(|ivk| self.repliers.iter().map(move |tx| (ivk, tx))),
+        ) {
+            let value = decrypted_note.map(|(note, recipient, memo)| DecryptedNote {
+                ivk: ivk.clone(),
+                memo,
+                note,
+                recipient,
+            });
+
+            let output_index = replier.output_index;
+            let tx = &replier.value;
+            if tx
+                .send(OutputIndex {
+                    output_index,
+                    value,
+                })
+                .is_err()
+            {
+                tracing::debug!("BatchRunner was dropped before batch finished");
+                return;
+            }
+        }
+    }
+}
+
+impl<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE> + Clone> Batch<D, Output> {
+    /// Adds the given outputs to this batch.
+    ///
+    /// `replier` will be called with the result of every output.
+    fn add_outputs(
+        &mut self,
+        domain: impl Fn() -> D,
+        outputs: &[Output],
+        replier: channel::Sender<OutputIndex<Option<DecryptedNote<D>>>>,
+    ) {
+        self.outputs
+            .extend(outputs.iter().cloned().map(|output| (domain(), output)));
+        self.repliers
+            .extend((0..outputs.len()).map(|output_index| OutputIndex {
+                output_index,
+                value: replier.clone(),
+            }));
+    }
+}
+
+type ResultKey = (BlockHash, TxId);
+
+/// Logic to run batches of trial decryptions on the global threadpool.
+struct BatchRunner<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> {
+    acc: Batch<D, Output>,
+    pending_results: HashMap<ResultKey, channel::Receiver<OutputIndex<Option<DecryptedNote<D>>>>>,
+}
+
+impl<D, Output> BatchRunner<D, Output>
+where
+    D: BatchDomain,
+    Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>,
+    D::IncomingViewingKey: Clone,
+{
+    /// Constructs a new batch runner for the given incoming viewing keys.
+    fn new(ivks: Vec<D::IncomingViewingKey>) -> Self {
+        Self {
+            acc: Batch::new(ivks),
+            pending_results: HashMap::default(),
+        }
+    }
+}
+
+impl<D, Output> BatchRunner<D, Output>
+where
+    D: BatchDomain + Send + 'static,
+    D::IncomingViewingKey: Clone + Send,
+    D::Memo: Send,
+    D::Note: Send,
+    D::Recipient: Send,
+    Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE> + Clone + Send + 'static,
+{
+    /// Batches the given outputs for trial decryption.
+    ///
+    /// `block_tag` is the hash of the block that triggered this txid being added to the
+    /// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
+    /// mempool change).
+    ///
+    /// If after adding the given outputs, the accumulated batch size is at least
+    /// `BATCH_SIZE_THRESHOLD`, `Self::flush` is called. Subsequent calls to
+    /// `Self::add_outputs` will be accumulated into a new batch.
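A minimal sketch of how the types above are meant to be driven for the Sapling domain, assuming a caller that already has the wallet's Sapling IVKs, a parsed `Transaction`, and its height; the `trial_decrypt_tx` helper and the main-network placeholder parameters are hypothetical, and simply mirror what `BatchScanner::add_transaction` and `BatchScanner::collect_results` do later in this file.

```rust
// Hypothetical driver for the BatchRunner defined above (illustration only).
use zcash_primitives::{
    block::BlockHash,
    consensus::{self, BlockHeight},
    sapling::{note_encryption::SaplingDomain, SaplingIvk},
    transaction::Transaction,
};

fn trial_decrypt_tx(ivks: Vec<SaplingIvk>, tx: &Transaction, height: BlockHeight) {
    let mut runner = BatchRunner::new(ivks);

    if let Some(bundle) = tx.sapling_bundle() {
        // An all-zeros block tag means "not triggered by a block", e.g. a mempool tx.
        runner.add_outputs(
            BlockHash([0; 32]),
            tx.txid(),
            || SaplingDomain::for_height(consensus::Network::MainNetwork, height),
            &bundle.shielded_outputs,
        );
    }

    // Hand any partially filled batch to the rayon threadpool...
    runner.flush();

    // ...then drain the result channel for this transaction.
    for ((txid, output_index), note) in runner.collect_results(BlockHash([0; 32]), tx.txid()) {
        println!(
            "{:?} output {} decrypts to a note of value {}",
            txid, output_index, note.note.value
        );
    }
}
```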
+ fn add_outputs( + &mut self, + block_tag: BlockHash, + txid: TxId, + domain: impl Fn() -> D, + outputs: &[Output], + ) { + let (tx, rx) = channel::unbounded(); + self.acc.add_outputs(domain, outputs, tx); + self.pending_results.insert((block_tag, txid), rx); + + if self.acc.outputs.len() >= BATCH_SIZE_THRESHOLD { + self.flush(); + } + } + + /// Runs the currently accumulated batch on the global threadpool. + /// + /// Subsequent calls to `Self::add_outputs` will be accumulated into a new batch. + fn flush(&mut self) { + if !self.acc.is_empty() { + let mut batch = Batch::new(self.acc.ivks.clone()); + mem::swap(&mut batch, &mut self.acc); + rayon::spawn_fifo(|| batch.run()); + } + } + + /// Collects the pending decryption results for the given transaction. + /// + /// `block_tag` is the hash of the block that triggered this txid being added to the + /// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a + /// mempool change). + fn collect_results( + &mut self, + block_tag: BlockHash, + txid: TxId, + ) -> HashMap<(TxId, usize), DecryptedNote> { + self.pending_results + .remove(&(block_tag, txid)) + // We won't have a pending result if the transaction didn't have outputs of + // this runner's kind. + .map(|rx| { + rx.into_iter() + .filter_map( + |OutputIndex { + output_index, + value, + }| { + value.map(|decrypted_note| ((txid, output_index), decrypted_note)) + }, + ) + .collect() + }) + .unwrap_or_default() + } +} + +/// A batch scanner for the `zcashd` wallet. +struct BatchScanner { + params: Network, + sapling_runner: Option, OutputDescription>>, +} + +fn init_batch_scanner( + network: &Network, + sapling_ivks: &[[u8; 32]], +) -> Result, &'static str> { + let sapling_runner = if sapling_ivks.is_empty() { + None + } else { + let ivks = sapling_ivks + .iter() + .map(|ivk| { + let ivk: Option = + jubjub::Fr::from_bytes(ivk).map(sapling::SaplingIvk).into(); + ivk.ok_or("Invalid Sapling ivk passed to wallet::init_batch_scanner()") + }) + .collect::>()?; + Some(BatchRunner::new(ivks)) + }; + + Ok(Box::new(BatchScanner { + params: *network, + sapling_runner, + })) +} + +impl BatchScanner { + /// Adds the given transaction's shielded outputs to the various batch runners. + /// + /// `block_tag` is the hash of the block that triggered this txid being added to the + /// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a + /// mempool change). + /// + /// After adding the outputs, any accumulated batch of sufficient size is run on the + /// global threadpool. Subsequent calls to `Self::add_transaction` will accumulate + /// those output kinds into new batches. + fn add_transaction( + &mut self, + block_tag: [u8; 32], + tx_bytes: &[u8], + height: u32, + ) -> Result<(), io::Error> { + let block_tag = BlockHash(block_tag); + // The consensusBranchId parameter is ignored; it is not used in trial decryption + // and does not affect transaction parsing. + let tx = Transaction::read(tx_bytes, consensus::BranchId::Sprout)?; + let txid = tx.txid(); + let height = consensus::BlockHeight::from_u32(height); + + // If we have any Sapling IVKs, and the transaction has any Sapling outputs, queue + // the outputs for trial decryption. 
+ if let Some((runner, bundle)) = self.sapling_runner.as_mut().zip(tx.sapling_bundle()) { + let params = self.params; + runner.add_outputs( + block_tag, + txid, + || SaplingDomain::for_height(params, height), + &bundle.shielded_outputs, + ); + } + + Ok(()) + } + + /// Runs the currently accumulated batches on the global threadpool. + /// + /// Subsequent calls to `Self::add_transaction` will be accumulated into new batches. + fn flush(&mut self) { + if let Some(runner) = &mut self.sapling_runner { + runner.flush(); + } + } + + /// Collects the pending decryption results for the given transaction. + /// + /// `block_tag` is the hash of the block that triggered this txid being added to the + /// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a + /// mempool change). + /// + /// TODO: Return the `HashMap`s directly once `cxx` supports it. + fn collect_results(&mut self, block_tag: [u8; 32], txid: [u8; 32]) -> Box { + let block_tag = BlockHash(block_tag); + let txid = TxId::from_bytes(txid); + + let sapling = self + .sapling_runner + .as_mut() + .map(|runner| runner.collect_results(block_tag, txid)) + .unwrap_or_default(); + + Box::new(BatchResult { sapling }) + } +} + +struct BatchResult { + sapling: HashMap<(TxId, usize), DecryptedNote>>, +} + +impl BatchResult { + fn get_sapling(&self) -> Vec { + self.sapling + .iter() + .map( + |((txid, output), decrypted_note)| ffi::SaplingDecryptionResult { + txid: *txid.as_ref(), + output: *output as u32, + ivk: decrypted_note.ivk.to_repr(), + diversifier: decrypted_note.recipient.diversifier().0, + pk_d: decrypted_note.recipient.pk_d().to_bytes(), + }, + ) + .collect() + } +} From 6156215d4c2bb3b95f42bbe523a2d4a52e0fefc2 Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Thu, 18 Aug 2022 16:45:45 -0600 Subject: [PATCH 2/3] Add parallelized batched trial decryption to wallet scanning. Co-authored-by: str4d --- zcash_client_backend/Cargo.toml | 5 +- zcash_client_backend/src/data_api/chain.rs | 48 ++- zcash_client_backend/src/lib.rs | 1 + zcash_client_backend/src/scan.rs | 387 ++---------------- zcash_client_backend/src/welding_rig.rs | 106 +++-- .../src/transaction/components/sapling.rs | 1 + 6 files changed, 174 insertions(+), 374 deletions(-) diff --git a/zcash_client_backend/Cargo.toml b/zcash_client_backend/Cargo.toml index 6d18d3766..dd66b62fa 100644 --- a/zcash_client_backend/Cargo.toml +++ b/zcash_client_backend/Cargo.toml @@ -13,10 +13,11 @@ license = "MIT OR Apache-2.0" edition = "2018" [dependencies] +base64 = "0.13" bech32 = "0.8" bls12_381 = "0.7" bs58 = { version = "0.4", features = ["check"] } -base64 = "0.13" +crossbeam-channel = "0.5" ff = "0.12" group = "0.12" hex = "0.4" @@ -29,11 +30,13 @@ percent-encoding = "2.1.0" proptest = { version = "1.0.0", optional = true } protobuf = "~2.27.1" # MSRV 1.52.1 rand_core = "0.6" +rayon = "1.5" ripemd = { version = "0.1", optional = true } secp256k1 = { version = "0.21", optional = true } sha2 = { version = "0.10.1", optional = true } subtle = "2.2.3" time = "0.2" +tracing = "0.1" zcash_address = { version = "0.1", path = "../components/zcash_address" } zcash_note_encryption = { version = "0.1", path = "../components/zcash_note_encryption" } zcash_primitives = { version = "0.7", path = "../zcash_primitives" } diff --git a/zcash_client_backend/src/data_api/chain.rs b/zcash_client_backend/src/data_api/chain.rs index 7a7006c5f..af4fb1602 100644 --- a/zcash_client_backend/src/data_api/chain.rs +++ b/zcash_client_backend/src/data_api/chain.rs @@ -77,13 +77,15 @@ //! 
# } //! ``` +use std::convert::TryFrom; use std::fmt::Debug; use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight, NetworkUpgrade}, merkle_tree::CommitmentTree, - sapling::Nullifier, + sapling::{keys::Scope, note_encryption::SaplingDomain, Nullifier}, + transaction::components::sapling::CompactOutputDescription, }; use crate::{ @@ -92,8 +94,9 @@ use crate::{ BlockSource, PrunedBlock, WalletWrite, }, proto::compact_formats::CompactBlock, + scan::BatchRunner, wallet::WalletTx, - welding_rig::scan_block, + welding_rig::scan_block_with_runner, }; /// Checks that the scanned blocks in the data database, when combined with the recent @@ -192,7 +195,7 @@ pub fn scan_cached_blocks( limit: Option, ) -> Result<(), E> where - P: consensus::Parameters, + P: consensus::Parameters + Send + 'static, C: BlockSource, D: WalletWrite, N: Copy + Debug, @@ -229,6 +232,42 @@ where // Get the nullifiers for the notes we are tracking let mut nullifiers = data.get_nullifiers()?; + let mut batch_runner = BatchRunner::new( + 100, + dfvks + .iter() + .flat_map(|(_, dfvk)| [dfvk.to_ivk(Scope::External), dfvk.to_ivk(Scope::Internal)]) + .collect(), + ); + + cache.with_blocks(last_height, limit, |block: CompactBlock| { + let block_hash = block.hash(); + let block_height = block.height(); + + for tx in block.vtx.into_iter() { + let txid = tx.txid(); + let outputs = tx + .outputs + .into_iter() + .map(|output| { + CompactOutputDescription::try_from(output) + .expect("Invalid output found in compact block decoding.") + }) + .collect::>(); + + batch_runner.add_outputs( + block_hash, + txid, + || SaplingDomain::for_height(params.clone(), block_height), + &outputs, + ) + } + + Ok(()) + })?; + + batch_runner.flush(); + cache.with_blocks(last_height, limit, |block: CompactBlock| { let current_height = block.height(); @@ -245,13 +284,14 @@ where let txs: Vec> = { let mut witness_refs: Vec<_> = witnesses.iter_mut().map(|w| &mut w.1).collect(); - scan_block( + scan_block_with_runner( params, block, &dfvks, &nullifiers, &mut tree, &mut witness_refs[..], + Some(&mut batch_runner), ) }; diff --git a/zcash_client_backend/src/lib.rs b/zcash_client_backend/src/lib.rs index 852906156..2bcf755e9 100644 --- a/zcash_client_backend/src/lib.rs +++ b/zcash_client_backend/src/lib.rs @@ -14,6 +14,7 @@ mod decrypt; pub mod encoding; pub mod keys; pub mod proto; +pub mod scan; pub mod wallet; pub mod welding_rig; pub mod zip321; diff --git a/zcash_client_backend/src/scan.rs b/zcash_client_backend/src/scan.rs index 390c34614..ecebb3358 100644 --- a/zcash_client_backend/src/scan.rs +++ b/zcash_client_backend/src/scan.rs @@ -1,200 +1,19 @@ -use core::fmt; +use crossbeam_channel as channel; use std::collections::HashMap; -use std::io; +use std::fmt; use std::mem; -use crossbeam_channel as channel; -use group::GroupEncoding; -use zcash_note_encryption::{batch, BatchDomain, Domain, ShieldedOutput, ENC_CIPHERTEXT_SIZE}; -use zcash_primitives::{ - block::BlockHash, - consensus, constants, - sapling::{self, note_encryption::SaplingDomain}, - transaction::{ - components::{sapling::GrothProofBytes, OutputDescription}, - Transaction, TxId, - }, -}; - -#[cxx::bridge] -mod ffi { - #[namespace = "wallet"] - struct SaplingDecryptionResult { - txid: [u8; 32], - output: u32, - ivk: [u8; 32], - diversifier: [u8; 11], - pk_d: [u8; 32], - } - - #[namespace = "wallet"] - extern "Rust" { - type Network; - type BatchScanner; - type BatchResult; - - fn network( - network: &str, - overwinter: i32, - sapling: i32, - blossom: i32, - heartwood: i32, - 
canopy: i32, - nu5: i32, - ) -> Result>; - - fn init_batch_scanner( - network: &Network, - sapling_ivks: &[[u8; 32]], - ) -> Result>; - fn add_transaction( - self: &mut BatchScanner, - block_tag: [u8; 32], - tx_bytes: &[u8], - height: u32, - ) -> Result<()>; - fn flush(self: &mut BatchScanner); - fn collect_results( - self: &mut BatchScanner, - block_tag: [u8; 32], - txid: [u8; 32], - ) -> Box; - - fn get_sapling(self: &BatchResult) -> Vec; - } -} - -/// The minimum number of outputs to trial decrypt in a batch. -/// -/// TODO: Tune this. -const BATCH_SIZE_THRESHOLD: usize = 20; - -/// Chain parameters for the networks supported by `zcashd`. -#[derive(Clone, Copy)] -pub enum Network { - Consensus(consensus::Network), - RegTest { - overwinter: Option, - sapling: Option, - blossom: Option, - heartwood: Option, - canopy: Option, - nu5: Option, - }, -} - -/// Constructs a `Network` from the given network string. -/// -/// The heights are only for constructing a regtest network, and are ignored otherwise. -fn network( - network: &str, - overwinter: i32, - sapling: i32, - blossom: i32, - heartwood: i32, - canopy: i32, - nu5: i32, -) -> Result, &'static str> { - let i32_to_optional_height = |n: i32| { - if n.is_negative() { - None - } else { - Some(consensus::BlockHeight::from_u32(n.unsigned_abs())) - } - }; - - let params = match network { - "main" => Network::Consensus(consensus::Network::MainNetwork), - "test" => Network::Consensus(consensus::Network::TestNetwork), - "regtest" => Network::RegTest { - overwinter: i32_to_optional_height(overwinter), - sapling: i32_to_optional_height(sapling), - blossom: i32_to_optional_height(blossom), - heartwood: i32_to_optional_height(heartwood), - canopy: i32_to_optional_height(canopy), - nu5: i32_to_optional_height(nu5), - }, - _ => return Err("Unsupported network kind"), - }; - - Ok(Box::new(params)) -} - -impl consensus::Parameters for Network { - fn activation_height(&self, nu: consensus::NetworkUpgrade) -> Option { - match self { - Self::Consensus(params) => params.activation_height(nu), - Self::RegTest { - overwinter, - sapling, - blossom, - heartwood, - canopy, - nu5, - } => match nu { - consensus::NetworkUpgrade::Overwinter => *overwinter, - consensus::NetworkUpgrade::Sapling => *sapling, - consensus::NetworkUpgrade::Blossom => *blossom, - consensus::NetworkUpgrade::Heartwood => *heartwood, - consensus::NetworkUpgrade::Canopy => *canopy, - consensus::NetworkUpgrade::Nu5 => *nu5, - }, - } - } - - fn coin_type(&self) -> u32 { - match self { - Self::Consensus(params) => params.coin_type(), - Self::RegTest { .. } => constants::regtest::COIN_TYPE, - } - } - - fn hrp_sapling_extended_spending_key(&self) -> &str { - match self { - Self::Consensus(params) => params.hrp_sapling_extended_spending_key(), - Self::RegTest { .. } => constants::regtest::HRP_SAPLING_EXTENDED_SPENDING_KEY, - } - } - - fn hrp_sapling_extended_full_viewing_key(&self) -> &str { - match self { - Self::Consensus(params) => params.hrp_sapling_extended_full_viewing_key(), - Self::RegTest { .. } => constants::regtest::HRP_SAPLING_EXTENDED_FULL_VIEWING_KEY, - } - } - - fn hrp_sapling_payment_address(&self) -> &str { - match self { - Self::Consensus(params) => params.hrp_sapling_payment_address(), - Self::RegTest { .. } => constants::regtest::HRP_SAPLING_PAYMENT_ADDRESS, - } - } - - fn b58_pubkey_address_prefix(&self) -> [u8; 2] { - match self { - Self::Consensus(params) => params.b58_pubkey_address_prefix(), - Self::RegTest { .. 
} => constants::regtest::B58_PUBKEY_ADDRESS_PREFIX, - } - } - - fn b58_script_address_prefix(&self) -> [u8; 2] { - match self { - Self::Consensus(params) => params.b58_script_address_prefix(), - Self::RegTest { .. } => constants::regtest::B58_SCRIPT_ADDRESS_PREFIX, - } - } -} +use zcash_note_encryption::{batch, BatchDomain, Domain, ShieldedOutput, COMPACT_NOTE_SIZE}; +use zcash_primitives::{block::BlockHash, transaction::TxId}; /// A decrypted note. -struct DecryptedNote { +pub(crate) struct DecryptedNote { /// The incoming viewing key used to decrypt the note. - ivk: D::IncomingViewingKey, + pub(crate) ivk: D::IncomingViewingKey, /// The recipient of the note. - recipient: D::Recipient, + pub(crate) recipient: D::Recipient, /// The note! - note: D::Note, - /// The memo sent with the note. - memo: D::Memo, + pub(crate) note: D::Note, } impl fmt::Debug for DecryptedNote @@ -209,7 +28,6 @@ where .field("ivk", &self.ivk) .field("recipient", &self.recipient) .field("note", &self.note) - .field("memo", &self.memo) .finish() } } @@ -225,8 +43,15 @@ struct OutputIndex { type OutputReplier = OutputIndex>>>>; /// A batch of outputs to trial decrypt. -struct Batch> { +struct Batch> { ivks: Vec, + /// We currently store outputs and repliers as parallel vectors, because + /// [`batch::try_note_decryption`] accepts a slice of domain/output pairs + /// rather than a value that implements `IntoIterator`, and therefore we + /// can't just use `map` to select the parts we need in order to perform + /// batch decryption. Ideally the domain, output, and output replier would + /// all be part of the same struct, which would also track the output index + /// (that is captured in the outer `OutputIndex` of each `OutputReplier`). outputs: Vec<(D, Output)>, repliers: Vec>, } @@ -234,7 +59,7 @@ struct Batch> { impl Batch where D: BatchDomain, - Output: ShieldedOutput, + Output: ShieldedOutput, D::IncomingViewingKey: Clone, { /// Constructs a new batch. @@ -254,33 +79,20 @@ where /// Runs the batch of trial decryptions, and reports the results. fn run(self) { assert_eq!(self.outputs.len(), self.repliers.len()); - let decrypted = batch::try_note_decryption(&self.ivks, &self.outputs); - for (decrypted_note, (ivk, replier)) in decrypted.into_iter().zip( - // The output of `batch::try_note_decryption` corresponds to the stream of - // trial decryptions: - // (ivk0, out0), (ivk0, out1), ..., (ivk0, outN), (ivk1, out0), ... - // So we can use the position in the stream to figure out which output was - // decrypted and which ivk decrypted it. 
- self.ivks - .iter() - .flat_map(|ivk| self.repliers.iter().map(move |tx| (ivk, tx))), - ) { - let value = decrypted_note.map(|(note, recipient, memo)| DecryptedNote { - ivk: ivk.clone(), - memo, - note, - recipient, - }); - let output_index = replier.output_index; - let tx = &replier.value; - if tx - .send(OutputIndex { - output_index, - value, - }) - .is_err() - { + let decryption_results = batch::try_compact_note_decryption(&self.ivks, &self.outputs); + for (decryption_result, replier) in decryption_results.into_iter().zip(self.repliers.iter()) + { + let result = OutputIndex { + output_index: replier.output_index, + value: decryption_result.map(|((note, recipient), ivk_idx)| DecryptedNote { + ivk: self.ivks[ivk_idx].clone(), + recipient, + note, + }), + }; + + if replier.value.send(result).is_err() { tracing::debug!("BatchRunner was dropped before batch finished"); return; } @@ -288,7 +100,7 @@ where } } -impl + Clone> Batch { +impl + Clone> Batch { /// Adds the given outputs to this batch. /// /// `replier` will be called with the result of every output. @@ -311,7 +123,8 @@ impl + Clone> Bat type ResultKey = (BlockHash, TxId); /// Logic to run batches of trial decryptions on the global threadpool. -struct BatchRunner> { +pub(crate) struct BatchRunner> { + batch_size_threshold: usize, acc: Batch, pending_results: HashMap>>>>, } @@ -319,12 +132,13 @@ struct BatchRunner BatchRunner where D: BatchDomain, - Output: ShieldedOutput, + Output: ShieldedOutput, D::IncomingViewingKey: Clone, { /// Constructs a new batch runner for the given incoming viewing keys. - fn new(ivks: Vec) -> Self { + pub(crate) fn new(batch_size_threshold: usize, ivks: Vec) -> Self { Self { + batch_size_threshold, acc: Batch::new(ivks), pending_results: HashMap::default(), } @@ -338,7 +152,7 @@ where D::Memo: Send, D::Note: Send, D::Recipient: Send, - Output: ShieldedOutput + Clone + Send + 'static, + Output: ShieldedOutput + Clone + Send + 'static, { /// Batches the given outputs for trial decryption. /// @@ -349,7 +163,7 @@ where /// If after adding the given outputs, the accumulated batch size is at least /// `BATCH_SIZE_THRESHOLD`, `Self::flush` is called. Subsequent calls to /// `Self::add_outputs` will be accumulated into a new batch. - fn add_outputs( + pub(crate) fn add_outputs( &mut self, block_tag: BlockHash, txid: TxId, @@ -360,7 +174,7 @@ where self.acc.add_outputs(domain, outputs, tx); self.pending_results.insert((block_tag, txid), rx); - if self.acc.outputs.len() >= BATCH_SIZE_THRESHOLD { + if self.acc.outputs.len() >= self.batch_size_threshold { self.flush(); } } @@ -368,7 +182,7 @@ where /// Runs the currently accumulated batch on the global threadpool. /// /// Subsequent calls to `Self::add_outputs` will be accumulated into a new batch. - fn flush(&mut self) { + pub(crate) fn flush(&mut self) { if !self.acc.is_empty() { let mut batch = Batch::new(self.acc.ivks.clone()); mem::swap(&mut batch, &mut self.acc); @@ -381,7 +195,7 @@ where /// `block_tag` is the hash of the block that triggered this txid being added to the /// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a /// mempool change). - fn collect_results( + pub(crate) fn collect_results( &mut self, block_tag: BlockHash, txid: TxId, @@ -405,122 +219,3 @@ where .unwrap_or_default() } } - -/// A batch scanner for the `zcashd` wallet. 
-struct BatchScanner { - params: Network, - sapling_runner: Option, OutputDescription>>, -} - -fn init_batch_scanner( - network: &Network, - sapling_ivks: &[[u8; 32]], -) -> Result, &'static str> { - let sapling_runner = if sapling_ivks.is_empty() { - None - } else { - let ivks = sapling_ivks - .iter() - .map(|ivk| { - let ivk: Option = - jubjub::Fr::from_bytes(ivk).map(sapling::SaplingIvk).into(); - ivk.ok_or("Invalid Sapling ivk passed to wallet::init_batch_scanner()") - }) - .collect::>()?; - Some(BatchRunner::new(ivks)) - }; - - Ok(Box::new(BatchScanner { - params: *network, - sapling_runner, - })) -} - -impl BatchScanner { - /// Adds the given transaction's shielded outputs to the various batch runners. - /// - /// `block_tag` is the hash of the block that triggered this txid being added to the - /// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a - /// mempool change). - /// - /// After adding the outputs, any accumulated batch of sufficient size is run on the - /// global threadpool. Subsequent calls to `Self::add_transaction` will accumulate - /// those output kinds into new batches. - fn add_transaction( - &mut self, - block_tag: [u8; 32], - tx_bytes: &[u8], - height: u32, - ) -> Result<(), io::Error> { - let block_tag = BlockHash(block_tag); - // The consensusBranchId parameter is ignored; it is not used in trial decryption - // and does not affect transaction parsing. - let tx = Transaction::read(tx_bytes, consensus::BranchId::Sprout)?; - let txid = tx.txid(); - let height = consensus::BlockHeight::from_u32(height); - - // If we have any Sapling IVKs, and the transaction has any Sapling outputs, queue - // the outputs for trial decryption. - if let Some((runner, bundle)) = self.sapling_runner.as_mut().zip(tx.sapling_bundle()) { - let params = self.params; - runner.add_outputs( - block_tag, - txid, - || SaplingDomain::for_height(params, height), - &bundle.shielded_outputs, - ); - } - - Ok(()) - } - - /// Runs the currently accumulated batches on the global threadpool. - /// - /// Subsequent calls to `Self::add_transaction` will be accumulated into new batches. - fn flush(&mut self) { - if let Some(runner) = &mut self.sapling_runner { - runner.flush(); - } - } - - /// Collects the pending decryption results for the given transaction. - /// - /// `block_tag` is the hash of the block that triggered this txid being added to the - /// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a - /// mempool change). - /// - /// TODO: Return the `HashMap`s directly once `cxx` supports it. 
- fn collect_results(&mut self, block_tag: [u8; 32], txid: [u8; 32]) -> Box { - let block_tag = BlockHash(block_tag); - let txid = TxId::from_bytes(txid); - - let sapling = self - .sapling_runner - .as_mut() - .map(|runner| runner.collect_results(block_tag, txid)) - .unwrap_or_default(); - - Box::new(BatchResult { sapling }) - } -} - -struct BatchResult { - sapling: HashMap<(TxId, usize), DecryptedNote>>, -} - -impl BatchResult { - fn get_sapling(&self) -> Vec { - self.sapling - .iter() - .map( - |((txid, output), decrypted_note)| ffi::SaplingDecryptionResult { - txid: *txid.as_ref(), - output: *output as u32, - ivk: decrypted_note.ivk.to_repr(), - diversifier: decrypted_note.recipient.diversifier().0, - pk_d: decrypted_note.recipient.pk_d().to_bytes(), - }, - ) - .collect() - } -} diff --git a/zcash_client_backend/src/welding_rig.rs b/zcash_client_backend/src/welding_rig.rs index 655ba92b6..0f15241e4 100644 --- a/zcash_client_backend/src/welding_rig.rs +++ b/zcash_client_backend/src/welding_rig.rs @@ -1,7 +1,7 @@ //! Tools for scanning a compact representation of the Zcash block chain. use ff::PrimeField; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; use subtle::{ConditionallySelectable, ConstantTimeEq, CtOption}; use zcash_note_encryption::batch; @@ -18,8 +18,11 @@ use zcash_primitives::{ zip32::{AccountId, ExtendedFullViewingKey}, }; -use crate::proto::compact_formats::CompactBlock; -use crate::wallet::{WalletShieldedOutput, WalletShieldedSpend, WalletTx}; +use crate::{ + proto::compact_formats::CompactBlock, + scan::BatchRunner, + wallet::{WalletShieldedOutput, WalletShieldedSpend, WalletTx}, +}; /// A key that can be used to perform trial decryption and nullifier /// computation for a Sapling [`CompactSaplingOutput`] @@ -36,7 +39,7 @@ use crate::wallet::{WalletShieldedOutput, WalletShieldedSpend, WalletTx}; /// [`scan_block`]: crate::welding_rig::scan_block pub trait ScanningKey { /// The type of key that is used to decrypt Sapling outputs; - type SaplingNk; + type SaplingNk: Clone; type SaplingKeys: IntoIterator; @@ -141,16 +144,37 @@ impl ScanningKey for SaplingIvk { /// [`IncrementalWitness`]: zcash_primitives::merkle_tree::IncrementalWitness /// [`WalletShieldedOutput`]: crate::wallet::WalletShieldedOutput /// [`WalletTx`]: crate::wallet::WalletTx -pub fn scan_block( +pub fn scan_block( params: &P, block: CompactBlock, vks: &[(&AccountId, &K)], nullifiers: &[(AccountId, Nullifier)], tree: &mut CommitmentTree, existing_witnesses: &mut [&mut IncrementalWitness], +) -> Vec> { + scan_block_with_runner( + params, + block, + vks, + nullifiers, + tree, + existing_witnesses, + None, + ) +} + +pub(crate) fn scan_block_with_runner( + params: &P, + block: CompactBlock, + vks: &[(&AccountId, &K)], + nullifiers: &[(AccountId, Nullifier)], + tree: &mut CommitmentTree, + existing_witnesses: &mut [&mut IncrementalWitness], + mut batch_runner: Option<&mut BatchRunner, CompactOutputDescription>>, ) -> Vec> { let mut wtxs: Vec> = vec![]; let block_height = block.height(); + let block_hash = block.hash(); for tx in block.vtx.into_iter() { let txid = tx.txid(); @@ -218,21 +242,53 @@ pub fn scan_block( }) .collect::>(); - let vks = vks - .iter() - .flat_map(|(a, k)| { - k.to_sapling_keys() - .into_iter() - .map(move |(ivk, nk)| (**a, ivk, nk)) - }) - .collect::>(); + let decrypted: Vec<_> = if let Some(runner) = batch_runner.as_mut() { + let vks = vks + .iter() + .flat_map(|(a, k)| { + k.to_sapling_keys() + .into_iter() + .map(move |(ivk, nk)| 
(ivk.to_repr(), (**a, nk))) + }) + .collect::>(); - let ivks = vks - .iter() - .map(|(_, ivk, _)| (*ivk).clone()) - .collect::>(); + let mut decrypted = runner.collect_results(block_hash, txid); + (0..decoded.len()) + .map(|i| { + decrypted.remove(&(txid, i)).map(|d_note| { + let (a, nk) = vks.get(&d_note.ivk.to_repr()).expect( + "The batch runner and scan_block must use the same set of IVKs.", + ); - let decrypted = batch::try_compact_note_decryption(&ivks, decoded); + ((d_note.note, d_note.recipient), *a, (*nk).clone()) + }) + }) + .collect() + } else { + let vks = vks + .iter() + .flat_map(|(a, k)| { + k.to_sapling_keys() + .into_iter() + .map(move |(ivk, nk)| (**a, ivk, nk)) + }) + .collect::>(); + + let ivks = vks + .iter() + .map(|(_, ivk, _)| (*ivk).clone()) + .collect::>(); + + batch::try_compact_note_decryption(&ivks, decoded) + .into_iter() + .map(|v| { + v.map(|(note_data, ivk_idx)| { + let (account, _, nk) = &vks[ivk_idx]; + (note_data, *account, (*nk).clone()) + }) + }) + .collect() + }; for (index, ((_, output), dec_output)) in decoded.iter().zip(decrypted).enumerate() { // Grab mutable references to new witnesses from previous outputs @@ -256,23 +312,22 @@ pub fn scan_block( } tree.append(node).unwrap(); - if let Some(((note, to), ivk_idx)) = dec_output { + if let Some(((note, to), account, nk)) = dec_output { // A note is marked as "change" if the account that received it // also spent notes in the same transaction. This will catch, // for instance: // - Change created by spending fractions of notes. // - Notes created by consolidation transactions. // - Notes sent from one account to itself. - let (account, _, nk) = &vks[ivk_idx]; - let is_change = spent_from_accounts.contains(account); + let is_change = spent_from_accounts.contains(&account); let witness = IncrementalWitness::from_tree(tree); - let nf = K::sapling_nf(nk, ¬e, &witness); + let nf = K::sapling_nf(&nk, ¬e, &witness); shielded_outputs.push(WalletShieldedOutput { index, cmu: output.cmu, ephemeral_key: output.ephemeral_key.clone(), - account: *account, + account, note, to, is_change, @@ -387,6 +442,11 @@ mod tests { // Create a fake CompactBlock containing the note let mut cb = CompactBlock::new(); + cb.set_hash({ + let mut hash = vec![0; 32]; + rng.fill_bytes(&mut hash); + hash + }); cb.set_height(height.into()); // Add a random Sapling tx before ours diff --git a/zcash_primitives/src/transaction/components/sapling.rs b/zcash_primitives/src/transaction/components/sapling.rs index 8413b897b..9d22ce5e4 100644 --- a/zcash_primitives/src/transaction/components/sapling.rs +++ b/zcash_primitives/src/transaction/components/sapling.rs @@ -389,6 +389,7 @@ impl OutputDescriptionV5 { } } +#[derive(Clone)] pub struct CompactOutputDescription { pub ephemeral_key: EphemeralKeyBytes, pub cmu: bls12_381::Scalar, From 397c76ca8d3972855392b7a3a4ec975a6958f674 Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Thu, 18 Aug 2022 15:03:32 -0600 Subject: [PATCH 3/3] Add tests for batched note decryption. --- zcash_client_backend/src/data_api/chain.rs | 29 +-- zcash_client_backend/src/welding_rig.rs | 210 ++++++++++++++------- 2 files changed, 148 insertions(+), 91 deletions(-) diff --git a/zcash_client_backend/src/data_api/chain.rs b/zcash_client_backend/src/data_api/chain.rs index af4fb1602..b5d9567a2 100644 --- a/zcash_client_backend/src/data_api/chain.rs +++ b/zcash_client_backend/src/data_api/chain.rs @@ -77,15 +77,13 @@ //! # } //! 
``` -use std::convert::TryFrom; use std::fmt::Debug; use zcash_primitives::{ block::BlockHash, consensus::{self, BlockHeight, NetworkUpgrade}, merkle_tree::CommitmentTree, - sapling::{keys::Scope, note_encryption::SaplingDomain, Nullifier}, - transaction::components::sapling::CompactOutputDescription, + sapling::{keys::Scope, Nullifier}, }; use crate::{ @@ -96,7 +94,7 @@ use crate::{ proto::compact_formats::CompactBlock, scan::BatchRunner, wallet::WalletTx, - welding_rig::scan_block_with_runner, + welding_rig::{add_block_to_runner, scan_block_with_runner}, }; /// Checks that the scanned blocks in the data database, when combined with the recent @@ -241,28 +239,7 @@ where ); cache.with_blocks(last_height, limit, |block: CompactBlock| { - let block_hash = block.hash(); - let block_height = block.height(); - - for tx in block.vtx.into_iter() { - let txid = tx.txid(); - let outputs = tx - .outputs - .into_iter() - .map(|output| { - CompactOutputDescription::try_from(output) - .expect("Invalid output found in compact block decoding.") - }) - .collect::>(); - - batch_runner.add_outputs( - block_hash, - txid, - || SaplingDomain::for_height(params.clone(), block_height), - &outputs, - ) - } - + add_block_to_runner(params, block, &mut batch_runner); Ok(()) })?; diff --git a/zcash_client_backend/src/welding_rig.rs b/zcash_client_backend/src/welding_rig.rs index 0f15241e4..764169040 100644 --- a/zcash_client_backend/src/welding_rig.rs +++ b/zcash_client_backend/src/welding_rig.rs @@ -163,6 +163,34 @@ pub fn scan_block( ) } +pub(crate) fn add_block_to_runner( + params: &P, + block: CompactBlock, + batch_runner: &mut BatchRunner, CompactOutputDescription>, +) { + let block_hash = block.hash(); + let block_height = block.height(); + + for tx in block.vtx.into_iter() { + let txid = tx.txid(); + let outputs = tx + .outputs + .into_iter() + .map(|output| { + CompactOutputDescription::try_from(output) + .expect("Invalid output found in compact block decoding.") + }) + .collect::>(); + + batch_runner.add_outputs( + block_hash, + txid, + || SaplingDomain::for_height(params.clone(), block_height), + &outputs, + ) + } +} + pub(crate) fn scan_block_with_runner( params: &P, block: CompactBlock, @@ -371,11 +399,15 @@ mod tests { zip32::{AccountId, ExtendedFullViewingKey, ExtendedSpendingKey}, }; - use super::scan_block; - use crate::proto::compact_formats::{ - CompactBlock, CompactSaplingOutput, CompactSaplingSpend, CompactTx, + use crate::{ + proto::compact_formats::{ + CompactBlock, CompactSaplingOutput, CompactSaplingSpend, CompactTx, + }, + scan::BatchRunner, }; + use super::{add_block_to_runner, scan_block, scan_block_with_runner, ScanningKey}; + fn random_compact_tx(mut rng: impl RngCore) -> CompactTx { let fake_nf = { let mut nf = vec![0; 32]; @@ -483,80 +515,128 @@ mod tests { #[test] fn scan_block_with_my_tx() { - let extsk = ExtendedSpendingKey::master(&[]); - let extfvk = ExtendedFullViewingKey::from(&extsk); + fn go(scan_multithreaded: bool) { + let extsk = ExtendedSpendingKey::master(&[]); + let extfvk = ExtendedFullViewingKey::from(&extsk); - let cb = fake_compact_block( - 1u32.into(), - Nullifier([0; 32]), - extfvk.clone(), - Amount::from_u64(5).unwrap(), - false, - ); - assert_eq!(cb.vtx.len(), 2); + let cb = fake_compact_block( + 1u32.into(), + Nullifier([0; 32]), + extfvk.clone(), + Amount::from_u64(5).unwrap(), + false, + ); + assert_eq!(cb.vtx.len(), 2); - let mut tree = CommitmentTree::empty(); - let txs = scan_block( - &Network::TestNetwork, - cb, - &[(&AccountId::from(0), &extfvk)], - &[], - 
&mut tree, - &mut [], - ); - assert_eq!(txs.len(), 1); + let mut tree = CommitmentTree::empty(); + let mut batch_runner = if scan_multithreaded { + let mut runner = BatchRunner::new( + 10, + extfvk + .to_sapling_keys() + .iter() + .map(|(k, _)| k.clone()) + .collect(), + ); - let tx = &txs[0]; - assert_eq!(tx.index, 1); - assert_eq!(tx.num_spends, 1); - assert_eq!(tx.num_outputs, 1); - assert_eq!(tx.shielded_spends.len(), 0); - assert_eq!(tx.shielded_outputs.len(), 1); - assert_eq!(tx.shielded_outputs[0].index, 0); - assert_eq!(tx.shielded_outputs[0].account, AccountId::from(0)); - assert_eq!(tx.shielded_outputs[0].note.value, 5); + add_block_to_runner(&Network::TestNetwork, cb.clone(), &mut runner); + runner.flush(); - // Check that the witness root matches - assert_eq!(tx.shielded_outputs[0].witness.root(), tree.root()); + Some(runner) + } else { + None + }; + + let txs = scan_block_with_runner( + &Network::TestNetwork, + cb, + &[(&AccountId::from(0), &extfvk)], + &[], + &mut tree, + &mut [], + batch_runner.as_mut(), + ); + assert_eq!(txs.len(), 1); + + let tx = &txs[0]; + assert_eq!(tx.index, 1); + assert_eq!(tx.num_spends, 1); + assert_eq!(tx.num_outputs, 1); + assert_eq!(tx.shielded_spends.len(), 0); + assert_eq!(tx.shielded_outputs.len(), 1); + assert_eq!(tx.shielded_outputs[0].index, 0); + assert_eq!(tx.shielded_outputs[0].account, AccountId::from(0)); + assert_eq!(tx.shielded_outputs[0].note.value, 5); + + // Check that the witness root matches + assert_eq!(tx.shielded_outputs[0].witness.root(), tree.root()); + } + + go(false); + go(true); } #[test] fn scan_block_with_txs_after_my_tx() { - let extsk = ExtendedSpendingKey::master(&[]); - let extfvk = ExtendedFullViewingKey::from(&extsk); + fn go(scan_multithreaded: bool) { + let extsk = ExtendedSpendingKey::master(&[]); + let extfvk = ExtendedFullViewingKey::from(&extsk); - let cb = fake_compact_block( - 1u32.into(), - Nullifier([0; 32]), - extfvk.clone(), - Amount::from_u64(5).unwrap(), - true, - ); - assert_eq!(cb.vtx.len(), 3); + let cb = fake_compact_block( + 1u32.into(), + Nullifier([0; 32]), + extfvk.clone(), + Amount::from_u64(5).unwrap(), + true, + ); + assert_eq!(cb.vtx.len(), 3); - let mut tree = CommitmentTree::empty(); - let txs = scan_block( - &Network::TestNetwork, - cb, - &[(&AccountId::from(0), &extfvk)], - &[], - &mut tree, - &mut [], - ); - assert_eq!(txs.len(), 1); + let mut tree = CommitmentTree::empty(); + let mut batch_runner = if scan_multithreaded { + let mut runner = BatchRunner::new( + 10, + extfvk + .to_sapling_keys() + .iter() + .map(|(k, _)| k.clone()) + .collect(), + ); - let tx = &txs[0]; - assert_eq!(tx.index, 1); - assert_eq!(tx.num_spends, 1); - assert_eq!(tx.num_outputs, 1); - assert_eq!(tx.shielded_spends.len(), 0); - assert_eq!(tx.shielded_outputs.len(), 1); - assert_eq!(tx.shielded_outputs[0].index, 0); - assert_eq!(tx.shielded_outputs[0].account, AccountId::from(0)); - assert_eq!(tx.shielded_outputs[0].note.value, 5); + add_block_to_runner(&Network::TestNetwork, cb.clone(), &mut runner); + runner.flush(); - // Check that the witness root matches - assert_eq!(tx.shielded_outputs[0].witness.root(), tree.root()); + Some(runner) + } else { + None + }; + + let txs = scan_block_with_runner( + &Network::TestNetwork, + cb, + &[(&AccountId::from(0), &extfvk)], + &[], + &mut tree, + &mut [], + batch_runner.as_mut(), + ); + assert_eq!(txs.len(), 1); + + let tx = &txs[0]; + assert_eq!(tx.index, 1); + assert_eq!(tx.num_spends, 1); + assert_eq!(tx.num_outputs, 1); + assert_eq!(tx.shielded_spends.len(), 0); + 
assert_eq!(tx.shielded_outputs.len(), 1); + assert_eq!(tx.shielded_outputs[0].index, 0); + assert_eq!(tx.shielded_outputs[0].account, AccountId::from(0)); + assert_eq!(tx.shielded_outputs[0].note.value, 5); + + // Check that the witness root matches + assert_eq!(tx.shielded_outputs[0].witness.root(), tree.root()); + } + + go(false); + go(true); } #[test]