Merge pull request #6088 from str4d/wallet-batch-scanner

Use multithreaded batched trial decryption for Sapling outputs
This commit is contained in:
Kris Nuttycombe 2022-07-22 14:09:50 -06:00 committed by GitHub
commit e3e5465438
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1082 additions and 50 deletions

1
Cargo.lock generated
View File

@ -956,6 +956,7 @@ dependencies = [
"bls12_381",
"byteorder",
"clearscreen",
"crossbeam-channel",
"cxx",
"ed25519-zebra",
"group",

View File

@ -38,6 +38,7 @@ blake2b_simd = "1"
blake2s_simd = "1"
bls12_381 = "0.7"
byteorder = "1"
crossbeam-channel = "0.5"
group = "0.12"
incrementalmerkletree = "0.3"
libc = "0.2"

View File

@ -52,19 +52,22 @@ CXXBRIDGE_RS = \
rust/src/bundlecache.rs \
rust/src/equihash.rs \
rust/src/orchard_bundle.rs \
rust/src/sapling.rs
rust/src/sapling.rs \
rust/src/wallet_scanner.rs
CXXBRIDGE_H = \
rust/gen/include/rust/blake2b.h \
rust/gen/include/rust/bundlecache.h \
rust/gen/include/rust/equihash.h \
rust/gen/include/rust/orchard_bundle.h \
rust/gen/include/rust/sapling.h
rust/gen/include/rust/sapling.h \
rust/gen/include/rust/wallet_scanner.h
CXXBRIDGE_CPP = \
rust/gen/src/blake2b.cpp \
rust/gen/src/bundlecache.cpp \
rust/gen/src/equihash.cpp \
rust/gen/src/orchard_bundle.cpp \
rust/gen/src/sapling.cpp
rust/gen/src/sapling.cpp \
rust/gen/src/wallet_scanner.cpp
# We add a rust/cxx.h include to indicate that we provide this (via the rustcxx depends
# package), so that cxxbridge doesn't include it within the generated headers and code.

View File

@ -81,6 +81,7 @@ mod sapling;
mod transaction_ffi;
mod unified_keys_ffi;
mod wallet;
mod wallet_scanner;
mod zip339_ffi;
mod test_harness_ffi;

View File

@ -0,0 +1,526 @@
use core::fmt;
use std::collections::HashMap;
use std::io;
use std::mem;
use crossbeam_channel as channel;
use group::GroupEncoding;
use zcash_note_encryption::{batch, BatchDomain, Domain, ShieldedOutput, ENC_CIPHERTEXT_SIZE};
use zcash_primitives::{
block::BlockHash,
consensus, constants,
sapling::{self, note_encryption::SaplingDomain},
transaction::{
components::{sapling::GrothProofBytes, OutputDescription},
Transaction, TxId,
},
};
#[cxx::bridge]
mod ffi {
#[namespace = "wallet"]
struct SaplingDecryptionResult {
txid: [u8; 32],
output: u32,
ivk: [u8; 32],
diversifier: [u8; 11],
pk_d: [u8; 32],
}
#[namespace = "wallet"]
extern "Rust" {
type Network;
type BatchScanner;
type BatchResult;
fn network(
network: &str,
overwinter: i32,
sapling: i32,
blossom: i32,
heartwood: i32,
canopy: i32,
nu5: i32,
) -> Result<Box<Network>>;
fn init_batch_scanner(
network: &Network,
sapling_ivks: &[[u8; 32]],
) -> Result<Box<BatchScanner>>;
fn add_transaction(
self: &mut BatchScanner,
block_tag: [u8; 32],
tx_bytes: &[u8],
height: u32,
) -> Result<()>;
fn flush(self: &mut BatchScanner);
fn collect_results(
self: &mut BatchScanner,
block_tag: [u8; 32],
txid: [u8; 32],
) -> Box<BatchResult>;
fn get_sapling(self: &BatchResult) -> Vec<SaplingDecryptionResult>;
}
}
/// The minimum number of outputs to trial decrypt in a batch.
///
/// TODO: Tune this.
const BATCH_SIZE_THRESHOLD: usize = 20;
/// Chain parameters for the networks supported by `zcashd`.
#[derive(Clone, Copy)]
pub enum Network {
Consensus(consensus::Network),
RegTest {
overwinter: Option<consensus::BlockHeight>,
sapling: Option<consensus::BlockHeight>,
blossom: Option<consensus::BlockHeight>,
heartwood: Option<consensus::BlockHeight>,
canopy: Option<consensus::BlockHeight>,
nu5: Option<consensus::BlockHeight>,
},
}
/// Constructs a `Network` from the given network string.
///
/// The heights are only for constructing a regtest network, and are ignored otherwise.
fn network(
network: &str,
overwinter: i32,
sapling: i32,
blossom: i32,
heartwood: i32,
canopy: i32,
nu5: i32,
) -> Result<Box<Network>, &'static str> {
let i32_to_optional_height = |n: i32| {
if n.is_negative() {
None
} else {
Some(consensus::BlockHeight::from_u32(n.unsigned_abs()))
}
};
let params = match network {
"main" => Network::Consensus(consensus::Network::MainNetwork),
"test" => Network::Consensus(consensus::Network::TestNetwork),
"regtest" => Network::RegTest {
overwinter: i32_to_optional_height(overwinter),
sapling: i32_to_optional_height(sapling),
blossom: i32_to_optional_height(blossom),
heartwood: i32_to_optional_height(heartwood),
canopy: i32_to_optional_height(canopy),
nu5: i32_to_optional_height(nu5),
},
_ => return Err("Unsupported network kind"),
};
Ok(Box::new(params))
}
impl consensus::Parameters for Network {
fn activation_height(&self, nu: consensus::NetworkUpgrade) -> Option<consensus::BlockHeight> {
match self {
Self::Consensus(params) => params.activation_height(nu),
Self::RegTest {
overwinter,
sapling,
blossom,
heartwood,
canopy,
nu5,
} => match nu {
consensus::NetworkUpgrade::Overwinter => *overwinter,
consensus::NetworkUpgrade::Sapling => *sapling,
consensus::NetworkUpgrade::Blossom => *blossom,
consensus::NetworkUpgrade::Heartwood => *heartwood,
consensus::NetworkUpgrade::Canopy => *canopy,
consensus::NetworkUpgrade::Nu5 => *nu5,
},
}
}
fn coin_type(&self) -> u32 {
match self {
Self::Consensus(params) => params.coin_type(),
Self::RegTest { .. } => constants::regtest::COIN_TYPE,
}
}
fn hrp_sapling_extended_spending_key(&self) -> &str {
match self {
Self::Consensus(params) => params.hrp_sapling_extended_spending_key(),
Self::RegTest { .. } => constants::regtest::HRP_SAPLING_EXTENDED_SPENDING_KEY,
}
}
fn hrp_sapling_extended_full_viewing_key(&self) -> &str {
match self {
Self::Consensus(params) => params.hrp_sapling_extended_full_viewing_key(),
Self::RegTest { .. } => constants::regtest::HRP_SAPLING_EXTENDED_FULL_VIEWING_KEY,
}
}
fn hrp_sapling_payment_address(&self) -> &str {
match self {
Self::Consensus(params) => params.hrp_sapling_payment_address(),
Self::RegTest { .. } => constants::regtest::HRP_SAPLING_PAYMENT_ADDRESS,
}
}
fn b58_pubkey_address_prefix(&self) -> [u8; 2] {
match self {
Self::Consensus(params) => params.b58_pubkey_address_prefix(),
Self::RegTest { .. } => constants::regtest::B58_PUBKEY_ADDRESS_PREFIX,
}
}
fn b58_script_address_prefix(&self) -> [u8; 2] {
match self {
Self::Consensus(params) => params.b58_script_address_prefix(),
Self::RegTest { .. } => constants::regtest::B58_SCRIPT_ADDRESS_PREFIX,
}
}
}
/// A decrypted note.
struct DecryptedNote<D: Domain> {
/// The incoming viewing key used to decrypt the note.
ivk: D::IncomingViewingKey,
/// The recipient of the note.
recipient: D::Recipient,
/// The note!
note: D::Note,
/// The memo sent with the note.
memo: D::Memo,
}
impl<D: Domain> fmt::Debug for DecryptedNote<D>
where
D::IncomingViewingKey: fmt::Debug,
D::Recipient: fmt::Debug,
D::Note: fmt::Debug,
D::Memo: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("DecryptedNote")
.field("ivk", &self.ivk)
.field("recipient", &self.recipient)
.field("note", &self.note)
.field("memo", &self.memo)
.finish()
}
}
/// A value correlated with an output index.
struct OutputIndex<V> {
/// The index of the output within the corresponding shielded bundle.
output_index: usize,
/// The value for the output index.
value: V,
}
type OutputReplier<D> = OutputIndex<channel::Sender<OutputIndex<Option<DecryptedNote<D>>>>>;
/// A batch of outputs to trial decrypt.
struct Batch<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> {
ivks: Vec<D::IncomingViewingKey>,
outputs: Vec<(D, Output)>,
repliers: Vec<OutputReplier<D>>,
}
impl<D, Output> Batch<D, Output>
where
D: BatchDomain,
Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>,
D::IncomingViewingKey: Clone,
{
/// Constructs a new batch.
fn new(ivks: Vec<D::IncomingViewingKey>) -> Self {
Self {
ivks,
outputs: vec![],
repliers: vec![],
}
}
/// Returns `true` if the batch is currently empty.
fn is_empty(&self) -> bool {
self.outputs.is_empty()
}
/// Runs the batch of trial decryptions, and reports the results.
fn run(self) {
assert_eq!(self.outputs.len(), self.repliers.len());
let decrypted = batch::try_note_decryption(&self.ivks, &self.outputs);
for (decrypted_note, (ivk, replier)) in decrypted.into_iter().zip(
// The output of `batch::try_note_decryption` corresponds to the stream of
// trial decryptions:
// (ivk0, out0), (ivk0, out1), ..., (ivk0, outN), (ivk1, out0), ...
// So we can use the position in the stream to figure out which output was
// decrypted and which ivk decrypted it.
self.ivks
.iter()
.flat_map(|ivk| self.repliers.iter().map(move |tx| (ivk, tx))),
) {
let value = decrypted_note.map(|(note, recipient, memo)| DecryptedNote {
ivk: ivk.clone(),
memo,
note,
recipient,
});
let output_index = replier.output_index;
let tx = &replier.value;
if tx
.send(OutputIndex {
output_index,
value,
})
.is_err()
{
tracing::debug!("BatchRunner was dropped before batch finished");
return;
}
}
}
}
impl<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE> + Clone> Batch<D, Output> {
/// Adds the given outputs to this batch.
///
/// `replier` will be called with the result of every output.
fn add_outputs(
&mut self,
domain: impl Fn() -> D,
outputs: &[Output],
replier: channel::Sender<OutputIndex<Option<DecryptedNote<D>>>>,
) {
self.outputs
.extend(outputs.iter().cloned().map(|output| (domain(), output)));
self.repliers
.extend((0..outputs.len()).map(|output_index| OutputIndex {
output_index,
value: replier.clone(),
}));
}
}
type ResultKey = (BlockHash, TxId);
/// Logic to run batches of trial decryptions on the global threadpool.
struct BatchRunner<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> {
acc: Batch<D, Output>,
pending_results: HashMap<ResultKey, channel::Receiver<OutputIndex<Option<DecryptedNote<D>>>>>,
}
impl<D, Output> BatchRunner<D, Output>
where
D: BatchDomain,
Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>,
D::IncomingViewingKey: Clone,
{
/// Constructs a new batch runner for the given incoming viewing keys.
fn new(ivks: Vec<D::IncomingViewingKey>) -> Self {
Self {
acc: Batch::new(ivks),
pending_results: HashMap::default(),
}
}
}
impl<D, Output> BatchRunner<D, Output>
where
D: BatchDomain + Send + 'static,
D::IncomingViewingKey: Clone + Send,
D::Memo: Send,
D::Note: Send,
D::Recipient: Send,
Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE> + Clone + Send + 'static,
{
/// Batches the given outputs for trial decryption.
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
///
/// If after adding the given outputs, the accumulated batch size is at least
/// `BATCH_SIZE_THRESHOLD`, `Self::flush` is called. Subsequent calls to
/// `Self::add_outputs` will be accumulated into a new batch.
fn add_outputs(
&mut self,
block_tag: BlockHash,
txid: TxId,
domain: impl Fn() -> D,
outputs: &[Output],
) {
let (tx, rx) = channel::unbounded();
self.acc.add_outputs(domain, outputs, tx);
self.pending_results.insert((block_tag, txid), rx);
if self.acc.outputs.len() >= BATCH_SIZE_THRESHOLD {
self.flush();
}
}
/// Runs the currently accumulated batch on the global threadpool.
///
/// Subsequent calls to `Self::add_outputs` will be accumulated into a new batch.
fn flush(&mut self) {
if !self.acc.is_empty() {
let mut batch = Batch::new(self.acc.ivks.clone());
mem::swap(&mut batch, &mut self.acc);
rayon::spawn_fifo(|| batch.run());
}
}
/// Collects the pending decryption results for the given transaction.
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
fn collect_results(
&mut self,
block_tag: BlockHash,
txid: TxId,
) -> HashMap<(TxId, usize), DecryptedNote<D>> {
self.pending_results
.remove(&(block_tag, txid))
// We won't have a pending result if the transaction didn't have outputs of
// this runner's kind.
.map(|rx| {
rx.into_iter()
.filter_map(
|OutputIndex {
output_index,
value,
}| {
value.map(|decrypted_note| ((txid, output_index), decrypted_note))
},
)
.collect()
})
.unwrap_or_default()
}
}
/// A batch scanner for the `zcashd` wallet.
struct BatchScanner {
params: Network,
sapling_runner: Option<BatchRunner<SaplingDomain<Network>, OutputDescription<GrothProofBytes>>>,
}
fn init_batch_scanner(
network: &Network,
sapling_ivks: &[[u8; 32]],
) -> Result<Box<BatchScanner>, &'static str> {
let sapling_runner = if sapling_ivks.is_empty() {
None
} else {
let ivks = sapling_ivks
.iter()
.map(|ivk| {
let ivk: Option<sapling::SaplingIvk> =
jubjub::Fr::from_bytes(ivk).map(sapling::SaplingIvk).into();
ivk.ok_or("Invalid Sapling ivk passed to wallet::init_batch_scanner()")
})
.collect::<Result<_, _>>()?;
Some(BatchRunner::new(ivks))
};
Ok(Box::new(BatchScanner {
params: *network,
sapling_runner,
}))
}
impl BatchScanner {
/// Adds the given transaction's shielded outputs to the various batch runners.
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
///
/// After adding the outputs, any accumulated batch of sufficient size is run on the
/// global threadpool. Subsequent calls to `Self::add_transaction` will accumulate
/// those output kinds into new batches.
fn add_transaction(
&mut self,
block_tag: [u8; 32],
tx_bytes: &[u8],
height: u32,
) -> Result<(), io::Error> {
let block_tag = BlockHash(block_tag);
// The consensusBranchId parameter is ignored; it is not used in trial decryption
// and does not affect transaction parsing.
let tx = Transaction::read(tx_bytes, consensus::BranchId::Sprout)?;
let txid = tx.txid();
let height = consensus::BlockHeight::from_u32(height);
// If we have any Sapling IVKs, and the transaction has any Sapling outputs, queue
// the outputs for trial decryption.
if let Some((runner, bundle)) = self.sapling_runner.as_mut().zip(tx.sapling_bundle()) {
let params = self.params;
runner.add_outputs(
block_tag,
txid,
|| SaplingDomain::for_height(params, height),
&bundle.shielded_outputs,
);
}
Ok(())
}
/// Runs the currently accumulated batches on the global threadpool.
///
/// Subsequent calls to `Self::add_transaction` will be accumulated into new batches.
fn flush(&mut self) {
if let Some(runner) = &mut self.sapling_runner {
runner.flush();
}
}
/// Collects the pending decryption results for the given transaction.
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
///
/// TODO: Return the `HashMap`s directly once `cxx` supports it.
fn collect_results(&mut self, block_tag: [u8; 32], txid: [u8; 32]) -> Box<BatchResult> {
let block_tag = BlockHash(block_tag);
let txid = TxId::from_bytes(txid);
let sapling = self
.sapling_runner
.as_mut()
.map(|runner| runner.collect_results(block_tag, txid))
.unwrap_or_default();
Box::new(BatchResult { sapling })
}
}
struct BatchResult {
sapling: HashMap<(TxId, usize), DecryptedNote<SaplingDomain<Network>>>,
}
impl BatchResult {
fn get_sapling(&self) -> Vec<ffi::SaplingDecryptionResult> {
self.sapling
.iter()
.map(
|((txid, output), decrypted_note)| ffi::SaplingDecryptionResult {
txid: *txid.as_ref(),
output: *output as u32,
ivk: decrypted_note.ivk.to_repr(),
diversifier: decrypted_note.recipient.diversifier().0,
pk_d: decrypted_note.recipient.pk_d().to_bytes(),
},
)
.collect()
}
}

View File

@ -7,6 +7,7 @@
#ifndef BITCOIN_UINT256_H
#define BITCOIN_UINT256_H
#include <array>
#include <assert.h>
#include <cstring>
#include <stdexcept>
@ -130,6 +131,13 @@ public:
uint256() {}
explicit uint256(const std::vector<unsigned char>& vch) : base_blob<256>(vch) {}
static uint256 FromRawBytes(std::array<uint8_t, 32> bytes)
{
uint256 buf;
std::memcpy(buf.begin(), bytes.data(), 32);
return buf;
}
/** A cheap hash function that just returns 64 bits from the result, it can be
* used when the contents are considered uniformly random. It is not appropriate
* when the value can easily be influenced from outside as e.g. a network adversary could

View File

@ -28,6 +28,7 @@ CMainSignals& GetMainSignals()
void RegisterValidationInterface(CValidationInterface* pwalletIn) {
g_signals.UpdatedBlockTip.connect(boost::bind(&CValidationInterface::UpdatedBlockTip, pwalletIn, _1));
g_signals.GetBatchScanner.connect(boost::bind(&CValidationInterface::GetBatchScanner, pwalletIn));
g_signals.SyncTransaction.connect(boost::bind(&CValidationInterface::SyncTransaction, pwalletIn, _1, _2, _3));
g_signals.EraseTransaction.connect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1));
g_signals.UpdatedTransaction.connect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1));
@ -49,6 +50,7 @@ void UnregisterValidationInterface(CValidationInterface* pwalletIn) {
g_signals.UpdatedTransaction.disconnect(boost::bind(&CValidationInterface::UpdatedTransaction, pwalletIn, _1));
g_signals.EraseTransaction.disconnect(boost::bind(&CValidationInterface::EraseFromWallet, pwalletIn, _1));
g_signals.SyncTransaction.disconnect(boost::bind(&CValidationInterface::SyncTransaction, pwalletIn, _1, _2, _3));
g_signals.GetBatchScanner.disconnect(boost::bind(&CValidationInterface::GetBatchScanner, pwalletIn));
g_signals.UpdatedBlockTip.disconnect(boost::bind(&CValidationInterface::UpdatedBlockTip, pwalletIn, _1));
}
@ -62,10 +64,39 @@ void UnregisterAllValidationInterfaces() {
g_signals.UpdatedTransaction.disconnect_all_slots();
g_signals.EraseTransaction.disconnect_all_slots();
g_signals.SyncTransaction.disconnect_all_slots();
g_signals.GetBatchScanner.disconnect_all_slots();
g_signals.UpdatedBlockTip.disconnect_all_slots();
}
void SyncWithWallets(const CTransaction &tx, const CBlock *pblock, const int nHeight) {
void AddTxToBatches(
std::vector<BatchScanner*> &batchScanners,
const CTransaction &tx,
const uint256 &blockTag,
const int nHeight)
{
CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION);
ssTx << tx;
std::vector<unsigned char> txBytes(ssTx.begin(), ssTx.end());
for (auto& batchScanner : batchScanners) {
batchScanner->AddTransaction(tx, txBytes, blockTag, nHeight);
}
}
void FlushBatches(std::vector<BatchScanner*> &batchScanners) {
for (auto& batchScanner : batchScanners) {
batchScanner->Flush();
}
}
void SyncWithWallets(
std::vector<BatchScanner*> &batchScanners,
const CTransaction &tx,
const CBlock *pblock,
const int nHeight)
{
for (auto& batchScanner : batchScanners) {
batchScanner->SyncTransaction(tx, pblock, nHeight);
}
g_signals.SyncTransaction(tx, pblock, nHeight);
}
@ -185,6 +216,163 @@ void ThreadNotifyWallets(CBlockIndex *pindexLastTip)
// network message processing thread.
//
// The wallet inherited from Bitcoin Core was built around the following
// general workflow for moving from one chain tip to another:
//
// - For each block in the old chain, from its tip to the fork point:
// - For each transaction in the block:
// - 1⃣ Trial-decrypt the transaction's shielded outputs.
// - If the transaction belongs to the wallet:
// - 2⃣ Add or update the transaction, and mark it as dirty.
// - Update the wallet's view of the chain tip.
// - 3⃣ In `zcashd`, this is when we decrement note witnesses.
// - For each block in the new chain, from the fork point to its tip:
// - For each transaction that became conflicted by this block:
// - 4⃣ Trial-decrypt the transaction's shielded outputs.
// - If the transaction belongs to the wallet:
// - 5⃣ Add or update the transaction, and mark it as dirty.
// - For each transaction in the block:
// - 6⃣ Trial-decrypt the transaction's shielded outputs.
// - If the transaction belongs to the wallet:
// - 7⃣ Add or update the transaction, and mark it as dirty.
// - Update the wallet's view of the chain tip.
// - 8⃣ In `zcashd`, this is when we increment note witnesses.
// - For each transaction in the mempool:
// - 9⃣ Trial-decrypt the transaction's shielded outputs.
// - If the transaction belongs to the wallet:
// - 🅰️ Add or update the transaction, and mark it as dirty.
//
// Steps 2⃣, 3⃣, 5⃣, 7⃣, 8⃣, and 🅰️ are where wallet state is updated,
// and the relative order of these updates must be preserved in order to
// avoid breaking any internal assumptions that the wallet makes.
//
// Steps 1⃣, 4⃣, 6⃣, and 9⃣ can be performed at any time, as long as
// their results are available when their respective conditionals are
// evaluated. We therefore refactor the above workflow to enable the
// trial-decryption work to be batched and parallelised:
//
// - For each block in the old chain, from its tip to the fork point:
// - For each transaction in the block:
// - Accumulate its Sprout, Sapling, and Orchard outputs.
// - For each block in the new chain, from the fork point to its tip:
// - For each transaction that became conflicted by this block:
// - Accumulate its Sprout, Sapling, and Orchard outputs.
// - For each transaction in the block:
// - Accumulate its Sprout, Sapling, and Orchard outputs.
//
// - 1⃣4⃣6⃣9⃣ Trial-decrypt the Sprout, Sapling, and Orchard outputs.
// - This can split up and batch the work however is most efficient.
//
// - For each block in the old chain, from its tip to the fork point:
// - For each transaction in the block:
// - If the transaction has decrypted outputs, or transparent inputs
// that belong to the wallet:
// - 2⃣ Add or update the transaction, and mark it as dirty.
// - Update the wallet's view of the chain tip.
// - 3⃣ In `zcashd`, this is when we decrement note witnesses.
// - For each block in the new chain, from the fork point to its tip:
// - For each transaction that became conflicted by this block:
// - If the transaction has decrypted outputs, or transparent inputs
// that belong to the wallet:
// - 5⃣ Add or update the transaction, and mark it as dirty.
// - For each transaction in the block:
// - If the transaction has decrypted outputs, or transparent inputs
// that belong to the wallet:
// - 7⃣ Add or update the transaction, and mark it as dirty.
// - Update the wallet's view of the chain tip.
// - 8⃣ In `zcashd`, this is when we increment note witnesses.
// - For each transaction in the mempool:
// - If the transaction has decrypted outputs, or transparent inputs
// that belong to the wallet:
// - 🅰️ Add or update the transaction, and mark it as dirty.
// Get a new handle to the BatchScanner for each listener in each loop.
// This allows the listeners to alter their scanning logic over time,
// for example to add new incoming viewing keys.
auto batchScanners = GetMainSignals().GetBatchScanner();
if (!batchScanners.empty()) {
// Batch the shielded outputs across all blocks being processed.
// TODO: We can probably not bother trial-decrypting transactions
// in blocks being disconnected, or that are becoming conflicted,
// instead doing a plain "is this tx in the wallet" check. However,
// the logic in AddToWalletIfInvolvingMe would need to be carefully
// checked to ensure its side-effects are correctly preserved, so
// for now we maintain the previous behaviour of trial-decrypting
// everything.
// Batch block disconnects.
auto pindexScan = pindexLastTip;
while (pindexScan && pindexScan != pindexFork) {
// Read block from disk.
CBlock block;
if (!ReadBlockFromDisk(block, pindexScan, chainParams.GetConsensus())) {
LogPrintf(
"*** %s: Failed to read block %s while collecting shielded outputs",
__func__, pindexScan->GetBlockHash().GetHex());
uiInterface.ThreadSafeMessageBox(
_("Error: A fatal internal error occurred, see debug.log for details"),
"", CClientUIInterface::MSG_ERROR);
StartShutdown();
return;
}
// Batch transactions that went from 1-confirmed to 0-confirmed
// or conflicted.
for (const CTransaction &tx : block.vtx) {
AddTxToBatches(batchScanners, tx, block.GetHash(), pindexScan->nHeight);
}
// On to the next block!
pindexScan = pindexScan->pprev;
}
// Batch block connections. Process blockStack in the same order we
// do below, so batched work can be completed in roughly the order
// we need it.
for (auto it = blockStack.rbegin(); it != blockStack.rend(); ++it) {
const auto& blockData = *it;
// Read block from disk.
CBlock block;
if (!ReadBlockFromDisk(block, blockData.pindex, chainParams.GetConsensus())) {
LogPrintf(
"*** %s: Failed to read block %s while collecting shielded outputs from block connects",
__func__, blockData.pindex->GetBlockHash().GetHex());
uiInterface.ThreadSafeMessageBox(
_("Error: A fatal internal error occurred, see debug.log for details"),
"", CClientUIInterface::MSG_ERROR);
StartShutdown();
return;
}
// Batch transactions that went from mempool to conflicted:
for (const CTransaction &tx : blockData.txConflicted) {
AddTxToBatches(
batchScanners,
tx,
blockData.pindex->GetBlockHash(),
blockData.pindex->nHeight + 1);
}
// ... and transactions that got confirmed:
for (const CTransaction &tx : block.vtx) {
AddTxToBatches(
batchScanners,
tx,
blockData.pindex->GetBlockHash(),
blockData.pindex->nHeight);
}
}
// Batch transactions in the mempool.
for (auto tx : recentlyAdded.first) {
AddTxToBatches(batchScanners, tx, uint256(), pindexLastTip->nHeight + 1);
}
}
// Ensure that all pending work has been started.
FlushBatches(batchScanners);
// Notify block disconnects
while (pindexLastTip && pindexLastTip != pindexFork) {
// Read block from disk.
@ -203,7 +391,7 @@ void ThreadNotifyWallets(CBlockIndex *pindexLastTip)
// Let wallets know transactions went from 1-confirmed to
// 0-confirmed or conflicted:
for (const CTransaction &tx : block.vtx) {
SyncWithWallets(tx, NULL, pindexLastTip->nHeight);
SyncWithWallets(batchScanners, tx, NULL, pindexLastTip->nHeight);
}
// Update cached incremental witnesses
// This will take the cs_main lock in order to obtain the CBlockLocator
@ -237,11 +425,11 @@ void ThreadNotifyWallets(CBlockIndex *pindexLastTip)
// Tell wallet about transactions that went from mempool
// to conflicted:
for (const CTransaction &tx : blockData.txConflicted) {
SyncWithWallets(tx, NULL, blockData.pindex->nHeight + 1);
SyncWithWallets(batchScanners, tx, NULL, blockData.pindex->nHeight + 1);
}
// ... and about transactions that got confirmed:
for (const CTransaction &tx : block.vtx) {
SyncWithWallets(tx, &block, blockData.pindex->nHeight);
SyncWithWallets(batchScanners, tx, &block, blockData.pindex->nHeight);
}
// Update cached incremental witnesses
// This will take the cs_main lock in order to obtain the CBlockLocator
@ -262,7 +450,7 @@ void ThreadNotifyWallets(CBlockIndex *pindexLastTip)
// Notify transactions in the mempool
for (auto tx : recentlyAdded.first) {
try {
SyncWithWallets(tx, NULL, pindexLastTip->nHeight + 1);
SyncWithWallets(batchScanners, tx, NULL, pindexLastTip->nHeight + 1);
} catch (const boost::thread_interrupted&) {
throw;
} catch (const std::exception& e) {

View File

@ -24,6 +24,43 @@ class CValidationInterface;
class CValidationState;
class uint256;
class BatchScanner {
public:
/**
* Adds a transaction to the batch scanner.
*
* `block_tag` is the hash of the block that triggered this txid being added
* to the batch, or the all-zeros hash to indicate that no block triggered
* it (i.e. it was a mempool change).
*/
virtual void AddTransaction(
const CTransaction &tx,
const std::vector<unsigned char> &txBytes,
const uint256 &blockTag,
const int nHeight) = 0;
/**
* Flushes any pending batches.
*
* After calling this, every transaction passed to `AddTransaction` should
* have its result available when the matching call to `SyncTransaction` is
* made.
*/
virtual void Flush() = 0;
/**
* Notifies the batch scanner of updated transaction data (transaction, and
* optionally the block it is found in).
*
* This will be called with transactions in the same order as they were
* `AddTransaction`.
*/
virtual void SyncTransaction(
const CTransaction &tx,
const CBlock *pblock,
const int nHeight) = 0;
};
struct MerkleFrontiers {
SproutMerkleTree sprout;
SaplingMerkleTree sapling;
@ -42,6 +79,7 @@ void UnregisterAllValidationInterfaces();
class CValidationInterface {
protected:
virtual void UpdatedBlockTip(const CBlockIndex *pindex) {}
virtual BatchScanner* GetBatchScanner() { return nullptr; }
virtual void SyncTransaction(const CTransaction &tx, const CBlock *pblock, const int nHeight) {}
virtual void EraseFromWallet(const uint256 &hash) {}
virtual void ChainTip(const CBlockIndex *pindex, const CBlock *pblock, std::optional<MerkleFrontiers> added) {}
@ -56,10 +94,57 @@ protected:
friend void ::UnregisterAllValidationInterfaces();
};
// aggregate_non_null_values is a combiner which places any non-nullptr values
// returned from slots into a container.
template<typename Container>
struct aggregate_non_null_values
{
typedef Container result_type;
template<typename InputIterator>
Container operator()(InputIterator first, InputIterator last) const
{
Container values;
while (first != last) {
auto ptr = *first;
if (ptr != nullptr) {
values.push_back(ptr);
}
++first;
}
return values;
}
};
struct CMainSignals {
/** Notifies listeners of updated block chain tip */
boost::signals2::signal<void (const CBlockIndex *)> UpdatedBlockTip;
/** Notifies listeners of updated transaction data (transaction, and optionally the block it is found in. */
/**
* Requests a pointer to the listener's batch scanner for shielded outputs,
* if it has one.
*
* The listener is responsible for managing the memory of the batch scanner.
* In practice each listener will have a single persistent batch scanner.
*
* This signal is called at the start of each notification loop, which runs
* on integer second boundaries. This is an opportunity for the listener to
* perform any updating of the batch scanner's internal state (such as
* updating its set of incoming viewing keys).
*
* Listeners of this signal should not listen to `SyncTransaction` or they
* will be notified about transactions twice.
*/
boost::signals2::signal<
BatchScanner* (),
aggregate_non_null_values<std::vector<BatchScanner*>>> GetBatchScanner;
/**
* Notifies listeners of updated transaction data (the transaction, and
* optionally the block it is found in).
*
* Listeners of this signal should not listen to `GetBatchScanner` or they
* will be notified about transactions twice.
*/
boost::signals2::signal<void (const CTransaction &, const CBlock *, const int nHeight)> SyncTransaction;
/** Notifies listeners of an erased transaction (currently disabled, requires transaction replacement). */
boost::signals2::signal<void (const uint256 &)> EraseTransaction;

View File

@ -574,13 +574,13 @@ TEST(WalletTests, FindMySaplingNotes) {
// No Sapling notes can be found in tx which does not belong to the wallet
CWalletTx wtx {&wallet, tx};
ASSERT_FALSE(wallet.HaveSaplingSpendingKey(extfvk));
auto noteMap = wallet.FindMySaplingNotes(wtx, 1).first;
auto noteMap = wallet.FindMySaplingNotes(consensusParams, wtx, 1).first;
EXPECT_EQ(0, noteMap.size());
// Add spending key to wallet, so Sapling notes can be found
ASSERT_TRUE(wallet.AddSaplingZKey(sk));
ASSERT_TRUE(wallet.HaveSaplingSpendingKey(extfvk));
noteMap = wallet.FindMySaplingNotes(wtx, 1).first;
noteMap = wallet.FindMySaplingNotes(consensusParams, wtx, 1).first;
EXPECT_EQ(2, noteMap.size());
// Revert to default
@ -733,14 +733,14 @@ TEST(WalletTests, GetConflictedSaplingNotes) {
EXPECT_EQ(0, chainActive.Height());
// Simulate SyncTransaction which calls AddToWalletIfInvolvingMe
auto saplingNoteData = wallet.FindMySaplingNotes(wtx, 1).first;
auto saplingNoteData = wallet.FindMySaplingNotes(consensusParams, wtx, 1).first;
ASSERT_TRUE(saplingNoteData.size() > 0);
wtx.SetSaplingNoteData(saplingNoteData);
wtx.SetMerkleBranch(block);
wallet.LoadWalletTx(wtx);
// Simulate receiving new block and ChainTip signal
wallet.IncrementNoteWitnesses(Params().GetConsensus(),&fakeIndex, &block, frontiers, true);
wallet.IncrementNoteWitnesses(consensusParams, &fakeIndex, &block, frontiers, true);
wallet.UpdateSaplingNullifierNoteMapForBlock(&block);
// Retrieve the updated wtx from wallet
@ -871,7 +871,7 @@ TEST(WalletTests, GetConflictedOrchardNotes) {
wallet.LoadWalletTx(wtx);
// Simulate receiving new block and ChainTip signal
wallet.IncrementNoteWitnesses(Params().GetConsensus(),&fakeIndex, &block, frontiers, true);
wallet.IncrementNoteWitnesses(consensusParams, &fakeIndex, &block, frontiers, true);
// Fetch the Orchard note so we can spend it.
std::vector<SproutNoteEntry> sproutEntries;
@ -1123,7 +1123,7 @@ TEST(WalletTests, NavigateFromSaplingNullifierToNote) {
// Simulate SyncTransaction which calls AddToWalletIfInvolvingMe
wtx.SetMerkleBranch(block);
auto saplingNoteData = wallet.FindMySaplingNotes(wtx, chainActive.Height()).first;
auto saplingNoteData = wallet.FindMySaplingNotes(consensusParams, wtx, chainActive.Height()).first;
ASSERT_TRUE(saplingNoteData.size() > 0);
wtx.SetSaplingNoteData(saplingNoteData);
wallet.LoadWalletTx(wtx);
@ -1140,7 +1140,7 @@ TEST(WalletTests, NavigateFromSaplingNullifierToNote) {
}
// Simulate receiving new block and ChainTip signal
wallet.IncrementNoteWitnesses(Params().GetConsensus(), &fakeIndex, &block, frontiers, true);
wallet.IncrementNoteWitnesses(consensusParams, &fakeIndex, &block, frontiers, true);
wallet.UpdateSaplingNullifierNoteMapForBlock(&block);
// Retrieve the updated wtx from wallet
@ -1250,7 +1250,7 @@ TEST(WalletTests, SpentSaplingNoteIsFromMe) {
EXPECT_TRUE(chainActive.Contains(&fakeIndex));
EXPECT_EQ(0, chainActive.Height());
auto saplingNoteData = wallet.FindMySaplingNotes(wtx, 1).first;
auto saplingNoteData = wallet.FindMySaplingNotes(consensusParams, wtx, 1).first;
ASSERT_TRUE(saplingNoteData.size() > 0);
wtx.SetSaplingNoteData(saplingNoteData);
wtx.SetMerkleBranch(block);
@ -1259,7 +1259,7 @@ TEST(WalletTests, SpentSaplingNoteIsFromMe) {
// Simulate receiving new block and ChainTip signal.
// This triggers calculation of nullifiers for notes belonging to this wallet
// in the output descriptions of wtx.
wallet.IncrementNoteWitnesses(Params().GetConsensus(), &fakeIndex, &block, frontiers, true);
wallet.IncrementNoteWitnesses(consensusParams, &fakeIndex, &block, frontiers, true);
wallet.UpdateSaplingNullifierNoteMapForBlock(&block);
// Retrieve the updated wtx from wallet
@ -1327,7 +1327,7 @@ TEST(WalletTests, SpentSaplingNoteIsFromMe) {
EXPECT_TRUE(chainActive.Contains(&fakeIndex2));
EXPECT_EQ(1, chainActive.Height());
auto saplingNoteData2 = wallet.FindMySaplingNotes(wtx2, 2).first;
auto saplingNoteData2 = wallet.FindMySaplingNotes(consensusParams, wtx2, 2).first;
ASSERT_TRUE(saplingNoteData2.size() > 0);
wtx2.SetSaplingNoteData(saplingNoteData2);
wtx2.SetMerkleBranch(block2);
@ -2110,14 +2110,14 @@ TEST(WalletTests, UpdatedSaplingNoteData) {
EXPECT_EQ(0, chainActive.Height());
// Simulate SyncTransaction which calls AddToWalletIfInvolvingMe
auto saplingNoteData = wallet.FindMySaplingNotes(wtx, chainActive.Height()).first;
auto saplingNoteData = wallet.FindMySaplingNotes(consensusParams, wtx, chainActive.Height()).first;
ASSERT_TRUE(saplingNoteData.size() == 1); // wallet only has key for change output
wtx.SetSaplingNoteData(saplingNoteData);
wtx.SetMerkleBranch(block);
wallet.LoadWalletTx(wtx);
// Simulate receiving new block and ChainTip signal
wallet.IncrementNoteWitnesses(Params().GetConsensus(), &fakeIndex, &block, frontiers, true);
wallet.IncrementNoteWitnesses(consensusParams, &fakeIndex, &block, frontiers, true);
wallet.UpdateSaplingNullifierNoteMapForBlock(&block);
// Retrieve the updated wtx from wallet
@ -2128,7 +2128,7 @@ TEST(WalletTests, UpdatedSaplingNoteData) {
ASSERT_TRUE(wallet.AddSaplingZKey(sk2));
ASSERT_TRUE(wallet.HaveSaplingSpendingKey(extfvk2));
CWalletTx wtx2 = wtx;
auto saplingNoteData2 = wallet.FindMySaplingNotes(wtx2, chainActive.Height()).first;
auto saplingNoteData2 = wallet.FindMySaplingNotes(consensusParams, wtx2, chainActive.Height()).first;
ASSERT_TRUE(saplingNoteData2.size() == 2);
wtx2.SetSaplingNoteData(saplingNoteData2);
@ -2259,14 +2259,14 @@ TEST(WalletTests, MarkAffectedSaplingTransactionsDirty) {
EXPECT_EQ(0, chainActive.Height());
// Simulate SyncTransaction which calls AddToWalletIfInvolvingMe
auto saplingNoteData = wallet.FindMySaplingNotes(wtx, chainActive.Height()).first;
auto saplingNoteData = wallet.FindMySaplingNotes(consensusParams, wtx, chainActive.Height()).first;
ASSERT_TRUE(saplingNoteData.size() > 0);
wtx.SetSaplingNoteData(saplingNoteData);
wtx.SetMerkleBranch(block);
wallet.LoadWalletTx(wtx);
// Simulate receiving new block and ChainTip signal
wallet.IncrementNoteWitnesses(Params().GetConsensus(), &fakeIndex, &block, frontiers, true);
wallet.IncrementNoteWitnesses(consensusParams, &fakeIndex, &block, frontiers, true);
wallet.UpdateSaplingNullifierNoteMapForBlock(&block);
// Retrieve the updated wtx from wallet

View File

@ -3375,13 +3375,16 @@ bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx)
auto tmp = wtxIn.mapSproutNoteData;
// Ensure we keep any cached witnesses we may already have
for (const std::pair <JSOutPoint, SproutNoteData> nd : wtx.mapSproutNoteData) {
if (tmp.count(nd.first)) {
if (nd.second.witnesses.size() > 0) {
tmp.at(nd.first).witnesses.assign(
nd.second.witnesses.cbegin(), nd.second.witnesses.cend());
}
tmp.at(nd.first).witnessHeight = nd.second.witnessHeight;
// Require that wtxIn's data is a superset of wtx's data. This holds
// because viewing keys are _never_ deleted from the wallet, so the
// number of detected notes can only increase.
assert(tmp.count(nd.first) == 1);
if (nd.second.witnesses.size() > 0) {
tmp.at(nd.first).witnesses.assign(
nd.second.witnesses.cbegin(), nd.second.witnesses.cend());
}
tmp.at(nd.first).witnessHeight = nd.second.witnessHeight;
}
// Now copy over the updated note data
wtx.mapSproutNoteData = tmp;
@ -3393,13 +3396,16 @@ bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx)
// Ensure we keep any cached witnesses we may already have
for (const std::pair <SaplingOutPoint, SaplingNoteData> nd : wtx.mapSaplingNoteData) {
if (tmp.count(nd.first)) {
if (nd.second.witnesses.size() > 0) {
tmp.at(nd.first).witnesses.assign(
nd.second.witnesses.cbegin(), nd.second.witnesses.cend());
}
tmp.at(nd.first).witnessHeight = nd.second.witnessHeight;
// Require that wtxIn's data is a superset of wtx's data. This holds
// because viewing keys are _never_ deleted from the wallet, so the
// number of detected notes can only increase.
assert(tmp.count(nd.first) == 1);
if (nd.second.witnesses.size() > 0) {
tmp.at(nd.first).witnesses.assign(
nd.second.witnesses.cbegin(), nd.second.witnesses.cend());
}
tmp.at(nd.first).witnessHeight = nd.second.witnessHeight;
}
// Now copy over the updated note data
@ -3414,6 +3420,23 @@ bool CWallet::UpdatedNoteData(const CWalletTx& wtxIn, CWalletTx& wtx)
return !unchangedSproutFlag || !unchangedSaplingFlag || !unchangedOrchardFlag;
}
WalletDecryptedNotes CWallet::TryDecryptShieldedOutputs(const CTransaction& tx)
{
// Sprout
auto sproutNoteData = FindMySproutNotes(tx);
// Sapling is trial decrypted in Rust.
// Orchard
// TODO: Trial decryption of Orchard notes alongside Sprout and Sapling will
// be implemented after batching is implemented, as then we can just handle
// everything in Rust.
return WalletDecryptedNotes {
.sproutNoteData = sproutNoteData,
};
}
/**
* Add a transaction to the wallet, or update it.
* pblock is optional, but should be provided if the transaction is known to be in a block.
@ -3431,6 +3454,7 @@ bool CWallet::AddToWalletIfInvolvingMe(
const CTransaction& tx,
const CBlock* pblock,
const int nHeight,
WalletDecryptedNotes decryptedNotes,
bool fUpdate)
{
{ // extra scope left in place for backport whitespace compatibility
@ -3441,12 +3465,11 @@ bool CWallet::AddToWalletIfInvolvingMe(
if (fExisted && !fUpdate) return false;
// Sprout
auto sproutNoteData = FindMySproutNotes(tx);
auto sproutNoteData = decryptedNotes.sproutNoteData;
// Sapling
auto saplingNoteDataAndAddressesToAdd = FindMySaplingNotes(tx, nHeight);
auto saplingNoteData = saplingNoteDataAndAddressesToAdd.first;
auto saplingAddressesToAdd = saplingNoteDataAndAddressesToAdd.second;
auto saplingNoteData = decryptedNotes.saplingNoteDataAndAddressesToAdd.first;
auto saplingAddressesToAdd = decryptedNotes.saplingNoteDataAndAddressesToAdd.second;
for (const auto &addressToAdd : saplingAddressesToAdd) {
// Add mapping between address and IVK for easy future lookup.
if (!AddSaplingPaymentAddress(addressToAdd.second, addressToAdd.first)) {
@ -3494,13 +3517,120 @@ bool CWallet::AddToWalletIfInvolvingMe(
}
}
void CWallet::SyncTransaction(const CTransaction& tx, const CBlock* pblock, const int nHeight)
rust::Box<wallet::BatchScanner> WalletBatchScanner::CreateBatchScanner(CWallet* pwallet) {
LOCK(pwallet->cs_KeyStore);
auto chainParams = Params();
auto consensus = chainParams.GetConsensus();
auto network = wallet::network(
chainParams.NetworkIDString(),
consensus.vUpgrades[Consensus::UPGRADE_OVERWINTER].nActivationHeight,
consensus.vUpgrades[Consensus::UPGRADE_SAPLING].nActivationHeight,
consensus.vUpgrades[Consensus::UPGRADE_BLOSSOM].nActivationHeight,
consensus.vUpgrades[Consensus::UPGRADE_HEARTWOOD].nActivationHeight,
consensus.vUpgrades[Consensus::UPGRADE_CANOPY].nActivationHeight,
consensus.vUpgrades[Consensus::UPGRADE_NU5].nActivationHeight);
// TODO: Pass the map across the FFI once cxx supports it.
std::vector<std::array<uint8_t, 32>> ivks;
for (const auto& it : pwallet->mapSaplingFullViewingKeys) {
SaplingIncomingViewingKey ivk = it.first;
ivks.push_back(ivk.GetRawBytes());
}
return wallet::init_batch_scanner(
*network,
{ivks.data(), ivks.size()});
}
bool WalletBatchScanner::AddToWalletIfInvolvingMe(
const Consensus::Params& consensus,
const CTransaction& tx,
const CBlock* pblock,
const int nHeight,
bool fUpdate)
{
AssertLockHeld(pwallet->cs_wallet);
auto decryptedNotesForTx = decryptedNotes.find(tx.GetHash());
if (decryptedNotesForTx == decryptedNotes.end()) {
throw std::logic_error("Called WalletBatchScanner::AddToWalletIfInvolvingMe with a tx that wasn't passed to AddTransaction");
}
auto decryptedNotes = decryptedNotesForTx->second;
// Fill in the details about decrypted Sapling notes.
uint256 blockTag;
if (pblock) {
blockTag = pblock->GetHash();
}
auto batchResults = inner->collect_results(blockTag.GetRawBytes(), tx.GetHash().GetRawBytes());
for (auto decrypted : batchResults->get_sapling()) {
SaplingIncomingViewingKey ivk(uint256::FromRawBytes(decrypted.ivk));
libzcash::SaplingPaymentAddress addr(
decrypted.diversifier,
uint256::FromRawBytes(decrypted.pk_d));
decryptedNotes.saplingNoteDataAndAddressesToAdd.first.insert(
std::make_pair(
SaplingOutPoint(uint256::FromRawBytes(decrypted.txid), decrypted.output),
SaplingNoteData(ivk)));
// Only track the recipient -> ivk mappings the wallet doesn't have.
if (pwallet->mapSaplingIncomingViewingKeys.count(addr) == 0) {
decryptedNotes.saplingNoteDataAndAddressesToAdd.second.insert(
std::make_pair(addr, ivk));
}
}
return pwallet->AddToWalletIfInvolvingMe(
consensus, tx, pblock, nHeight, decryptedNotes, fUpdate);
}
//
// BatchScanner APIs
//
void WalletBatchScanner::AddTransaction(
const CTransaction &tx,
const std::vector<unsigned char> &txBytes,
const uint256 &blockTag,
const int nHeight)
{
// Decrypt Sprout outputs immediately.
decryptedNotes.insert(
std::make_pair(tx.GetHash(), pwallet->TryDecryptShieldedOutputs(tx)));
// Queue Sapling outputs for trial decryption.
inner->add_transaction(blockTag.GetRawBytes(), {txBytes.data(), txBytes.size()}, nHeight);
}
void WalletBatchScanner::Flush() {
inner->flush();
}
void WalletBatchScanner::SyncTransaction(
const CTransaction &tx,
const CBlock *pblock,
const int nHeight)
{
LOCK(pwallet->cs_wallet);
if (!AddToWalletIfInvolvingMe(Params().GetConsensus(), tx, pblock, nHeight, true)) {
return; // Not one of ours
}
pwallet->MarkAffectedTransactionsDirty(tx);
}
BatchScanner* CWallet::GetBatchScanner()
{
LOCK(cs_wallet);
if (!AddToWalletIfInvolvingMe(Params().GetConsensus(), tx, pblock, nHeight, true))
return; // Not one of ours
MarkAffectedTransactionsDirty(tx);
// Rebuild the batch scanner to update its set of IVKs.
delete validationInterfaceBatchScanner;
validationInterfaceBatchScanner = new WalletBatchScanner(this);
return validationInterfaceBatchScanner;
}
void CWallet::MarkAffectedTransactionsDirty(const CTransaction& tx)
@ -3637,7 +3767,10 @@ mapSproutNoteData_t CWallet::FindMySproutNotes(const CTransaction &tx) const
* the result of FindMySaplingNotes (for the addresses available at the time) will
* already have been cached in CWalletTx.mapSaplingNoteData.
*/
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySaplingNotes(const CTransaction &tx, int height) const
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySaplingNotes(
const Consensus::Params& consensus,
const CTransaction &tx,
int height) const
{
LOCK(cs_KeyStore);
uint256 hash = tx.GetHash();
@ -3651,7 +3784,7 @@ std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> CWallet::FindMySap
for (auto it = mapSaplingFullViewingKeys.begin(); it != mapSaplingFullViewingKeys.end(); ++it) {
SaplingIncomingViewingKey ivk = it->first;
auto result = SaplingNotePlaintext::decrypt(Params().GetConsensus(), height, output.encCiphertext, ivk, output.ephemeralKey, output.cmu);
auto result = SaplingNotePlaintext::decrypt(consensus, height, output.encCiphertext, ivk, output.ephemeralKey, output.cmu);
if (!result) {
continue;
}
@ -4599,6 +4732,9 @@ int CWallet::ScanForWalletTransactions(
performOrchardWalletUpdates = true;
}
// Create a rescan-specific batch scanner for the wallet.
auto batchScanner = WalletBatchScanner(this);
ShowProgress(_("Rescanning..."), 0); // show rescan progress in GUI as dialog or on splashscreen, if -rescan on startup
double dProgressStart = Checkpoints::GuessVerificationProgress(chainParams.Checkpoints(), pindex, false);
double dProgressTip = Checkpoints::GuessVerificationProgress(chainParams.Checkpoints(), chainActive.Tip(), false);
@ -4612,9 +4748,16 @@ int CWallet::ScanForWalletTransactions(
throw std::runtime_error(
strprintf("Can't read block %d from disk (%s)", pindex->nHeight, pindex->GetBlockHash().GetHex()));
}
for (CTransaction& tx : block.vtx) {
CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION);
ssTx << tx;
std::vector<unsigned char> txBytes(ssTx.begin(), ssTx.end());
batchScanner.AddTransaction(tx, txBytes, pindex->GetBlockHash(), pindex->nHeight);
}
batchScanner.Flush();
for (CTransaction& tx : block.vtx)
{
if (AddToWalletIfInvolvingMe(consensus, tx, &block, pindex->nHeight, fUpdate)) {
if (batchScanner.AddToWalletIfInvolvingMe(consensus, tx, &block, pindex->nHeight, fUpdate)) {
myTxHashes.push_back(tx.GetHash());
ret++;
}

View File

@ -42,6 +42,8 @@
#include <utility>
#include <vector>
#include <rust/wallet_scanner.h>
#include <boost/shared_ptr.hpp>
extern CWallet* pwalletMain;
@ -1038,6 +1040,59 @@ public:
}
};
typedef struct WalletDecryptedNotes {
mapSproutNoteData_t sproutNoteData;
/**
* The decrypted Sapling notes, and any newly-discovered addresses that
* should be added to the keystore.
*
* NOTE: Adding every recipient address to this map will cause the
* transaction to not be added to the wallet, as the address write will
* attempt to overwrite an existing entry and fail.
*/
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> saplingNoteDataAndAddressesToAdd;
} WalletDecryptedNotes;
class WalletBatchScanner : public BatchScanner {
private:
CWallet* pwallet;
rust::Box<wallet::BatchScanner> inner;
std::map<uint256, WalletDecryptedNotes> decryptedNotes;
static rust::Box<wallet::BatchScanner> CreateBatchScanner(CWallet* pwallet);
WalletBatchScanner(CWallet* pwalletIn) : pwallet(pwalletIn), inner(CreateBatchScanner(pwalletIn)) {}
friend class CWallet;
public:
void AddTransactionToBatch(const CTransaction &tx, const int nHeight);
bool AddToWalletIfInvolvingMe(
const Consensus::Params& consensus,
const CTransaction& tx,
const CBlock* pblock,
const int nHeight,
bool fUpdate);
//
// BatchScanner APIs
//
void AddTransaction(
const CTransaction &tx,
const std::vector<unsigned char> &txBytes,
const uint256 &blockTag,
const int nHeight);
void Flush();
void SyncTransaction(
const CTransaction &tx,
const CBlock *pblock,
const int nHeight);
};
/**
* A CWallet is an extension of a keystore, which also maintains a set of transactions and balances,
* and provides the ability to create new transactions.
@ -1046,6 +1101,7 @@ class CWallet : public CCryptoKeyStore, public CValidationInterface
{
private:
friend class CWalletTx;
friend class WalletBatchScanner;
/**
* Select a set of coins such that nValueRet >= nTargetValue and at least
@ -1215,6 +1271,17 @@ protected:
*/
OrchardWallet orchardWallet;
/**
* The batch scanner for this wallet's CValidationInterface listener.
*
* This is stored in the wallet so that the wallet can manage its memory and
* the CValidationInterface provider uses pointers so the vtables work.
* We rely on the CValidationInterface provider only using its pointer to
* the batch scanner synchronously, and we use a separate batch scanner
* inside ScanForWalletTransactions so they don't collide.
*/
WalletBatchScanner* validationInterfaceBatchScanner;
public:
/*
* Main wallet lock.
@ -1257,6 +1324,8 @@ public:
{
delete pwalletdbEncryption;
pwalletdbEncryption = NULL;
delete validationInterfaceBatchScanner;
validationInterfaceBatchScanner = nullptr;
}
void SetNull(const CChainParams& params)
@ -1275,6 +1344,7 @@ public:
fBroadcastTransactions = false;
nWitnessCacheSize = 0;
networkIdString = params.NetworkIDString();
validationInterfaceBatchScanner = new WalletBatchScanner(this);
}
/**
@ -1691,6 +1761,8 @@ public:
DBErrors ReorderTransactions();
WalletDecryptedNotes TryDecryptShieldedOutputs(const CTransaction& tx);
void MarkDirty();
bool UpdateNullifierNoteMap();
void UpdateNullifierNoteMapWithTx(const CWalletTx& wtx);
@ -1698,12 +1770,13 @@ public:
void UpdateSaplingNullifierNoteMapForBlock(const CBlock* pblock);
void LoadWalletTx(const CWalletTx& wtxIn);
bool AddToWallet(const CWalletTx& wtxIn, CWalletDB* pwalletdb);
void SyncTransaction(const CTransaction& tx, const CBlock* pblock, const int nHeight);
BatchScanner* GetBatchScanner();
bool AddToWalletIfInvolvingMe(
const Consensus::Params& consensus,
const CTransaction& tx,
const CBlock* pblock,
const int nHeight,
WalletDecryptedNotes decryptedNotes,
bool fUpdate
);
void EraseFromWallet(const uint256 &hash);
@ -1806,7 +1879,10 @@ public:
const uint256& hSig,
uint8_t n) const;
mapSproutNoteData_t FindMySproutNotes(const CTransaction& tx) const;
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> FindMySaplingNotes(const CTransaction& tx, int height) const;
std::pair<mapSaplingNoteData_t, SaplingIncomingViewingKeyMap> FindMySaplingNotes(
const Consensus::Params& consensus,
const CTransaction& tx,
int height) const;
bool IsSproutNullifierFromMe(const uint256& nullifier) const;
bool IsSaplingNullifierFromMe(const uint256& nullifier) const;

View File

@ -310,7 +310,7 @@ double benchmark_try_decrypt_sapling_notes(size_t nKeys)
struct timeval tv_start;
timer_start(tv_start);
auto noteDataMapAndAddressesToAdd = wallet.FindMySaplingNotes(tx, 1);
auto noteDataMapAndAddressesToAdd = wallet.FindMySaplingNotes(Params().GetConsensus(), tx, 1);
assert(noteDataMapAndAddressesToAdd.first.empty());
return timer_stop(tv_start);
}