wallet: Domain-separate batched txids with a "block tag"

Previously when a transaction was queried for batch trial decryption, we
identified it by its txid. This is sufficient to uniquely identify the
transaction within the wallet, but was _not_ sufficient to uniquely
identify it within a `ThreadNotifyWallets` loop. In particular, when a
reorg occurs, and the same transaction is present in blocks on both
sides of the reorg (or is reorged into the mempool and then conflicted
out):

- The first occurrence would batch the transaction's outputs and store a
  result receiver.
- The second occurrence would overwrite the first occurrence's result
  receiver with its own.
- The first occurrence would read the second's result receiver (which
  has identical results to the first batch), removing it from the
  `pending_results` map.
- The second occurrence would not find any receiver in the map, and
  would mark the transaction as having no decrypted results.

We fix this by annotating each batched transaction with the hash of the
block that triggered it being trial-decrypted: either the block being
disconnected, the block being connected, or the null hash to indicate
a new transaction in the mempool. This is sufficient to domain-separate
all possible sources of duplicate txids:

- If a transaction is moved to the mempool via a block disconnection, or
  from the mempool (either mined or conflicted) via a block connection,
  its txid will appear twice: once with the block in question's hash,
  and once with the null hash.
- If a transaction is present in both a disconnected and a connected
  block (mined on both sides of the fork), its txid will appear twice:
  once each with the two block's txids.

Both of the above rely on the assumption that block hashes are collision
resistant, which in turn relies on SHA-256 being collision resistant.
This commit is contained in:
Jack Grigg 2022-07-22 14:38:34 +00:00
parent 20e6710fc6
commit 1caf6f70df
5 changed files with 80 additions and 18 deletions

View File

@ -7,6 +7,7 @@ use std::sync::mpsc;
use group::GroupEncoding;
use zcash_note_encryption::{batch, BatchDomain, Domain, ShieldedOutput, ENC_CIPHERTEXT_SIZE};
use zcash_primitives::{
block::BlockHash,
consensus, constants,
sapling::{self, note_encryption::SaplingDomain},
transaction::{
@ -46,9 +47,18 @@ mod ffi {
network: &Network,
sapling_ivks: &[[u8; 32]],
) -> Result<Box<BatchScanner>>;
fn add_transaction(self: &mut BatchScanner, tx_bytes: &[u8], height: u32) -> Result<()>;
fn add_transaction(
self: &mut BatchScanner,
block_tag: [u8; 32],
tx_bytes: &[u8],
height: u32,
) -> Result<()>;
fn flush(self: &mut BatchScanner);
fn collect_results(self: &mut BatchScanner, txid: [u8; 32]) -> Box<BatchResult>;
fn collect_results(
self: &mut BatchScanner,
block_tag: [u8; 32],
txid: [u8; 32],
) -> Box<BatchResult>;
fn get_sapling(self: &BatchResult) -> Vec<SaplingDecryptionResult>;
}
@ -298,10 +308,12 @@ impl<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE> + Clone> Bat
}
}
type ResultKey = (BlockHash, TxId);
/// Logic to run batches of trial decryptions on the global threadpool.
struct BatchRunner<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> {
acc: Batch<D, Output>,
pending_results: HashMap<TxId, mpsc::Receiver<OutputIndex<Option<DecryptedNote<D>>>>>,
pending_results: HashMap<ResultKey, mpsc::Receiver<OutputIndex<Option<DecryptedNote<D>>>>>,
}
impl<D, Output> BatchRunner<D, Output>
@ -330,13 +342,23 @@ where
{
/// Batches the given outputs for trial decryption.
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
///
/// If after adding the given outputs, the accumulated batch size is at least
/// `BATCH_SIZE_THRESHOLD`, `Self::flush` is called. Subsequent calls to
/// `Self::add_outputs` will be accumulated into a new batch.
fn add_outputs(&mut self, txid: TxId, domain: impl Fn() -> D, outputs: &[Output]) {
fn add_outputs(
&mut self,
block_tag: BlockHash,
txid: TxId,
domain: impl Fn() -> D,
outputs: &[Output],
) {
let (tx, rx) = mpsc::channel();
self.acc.add_outputs(domain, outputs, tx);
self.pending_results.insert(txid, rx);
self.pending_results.insert((block_tag, txid), rx);
if self.acc.outputs.len() >= BATCH_SIZE_THRESHOLD {
self.flush();
@ -355,9 +377,17 @@ where
}
/// Collects the pending decryption results for the given transaction.
fn collect_results(&mut self, txid: TxId) -> HashMap<(TxId, usize), DecryptedNote<D>> {
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
fn collect_results(
&mut self,
block_tag: BlockHash,
txid: TxId,
) -> HashMap<(TxId, usize), DecryptedNote<D>> {
self.pending_results
.remove(&txid)
.remove(&(block_tag, txid))
// We won't have a pending result if the transaction didn't have outputs of
// this runner's kind.
.map(|rx| {
@ -409,10 +439,20 @@ fn init_batch_scanner(
impl BatchScanner {
/// Adds the given transaction's shielded outputs to the various batch runners.
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
///
/// After adding the outputs, any accumulated batch of sufficient size is run on the
/// global threadpool. Subsequent calls to `Self::add_transaction` will accumulate
/// those output kinds into new batches.
fn add_transaction(&mut self, tx_bytes: &[u8], height: u32) -> Result<(), io::Error> {
fn add_transaction(
&mut self,
block_tag: [u8; 32],
tx_bytes: &[u8],
height: u32,
) -> Result<(), io::Error> {
let block_tag = BlockHash(block_tag);
// The consensusBranchId parameter is ignored; it is not used in trial decryption.
let tx = Transaction::read(tx_bytes, consensus::BranchId::Sprout)?;
let txid = tx.txid();
@ -423,6 +463,7 @@ impl BatchScanner {
if let Some((runner, bundle)) = self.sapling_runner.as_mut().zip(tx.sapling_bundle()) {
let params = self.params;
runner.add_outputs(
block_tag,
txid,
|| SaplingDomain::for_height(params, height),
&bundle.shielded_outputs,
@ -443,14 +484,19 @@ impl BatchScanner {
/// Collects the pending decryption results for the given transaction.
///
/// `block_tag` is the hash of the block that triggered this txid being added to the
/// batch, or the all-zeros hash to indicate that no block triggered it (i.e. it was a
/// mempool change).
///
/// TODO: Return the `HashMap`s directly once `cxx` supports it.
fn collect_results(&mut self, txid: [u8; 32]) -> Box<BatchResult> {
fn collect_results(&mut self, block_tag: [u8; 32], txid: [u8; 32]) -> Box<BatchResult> {
let block_tag = BlockHash(block_tag);
let txid = TxId::from_bytes(txid);
let sapling = self
.sapling_runner
.as_mut()
.map(|runner| runner.collect_results(txid))
.map(|runner| runner.collect_results(block_tag, txid))
.unwrap_or_default();
Box::new(BatchResult { sapling })

View File

@ -71,13 +71,14 @@ void UnregisterAllValidationInterfaces() {
void AddTxToBatches(
std::vector<BatchScanner*> &batchScanners,
const CTransaction &tx,
const uint256 &blockTag,
const int nHeight)
{
CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION);
ssTx << tx;
std::vector<unsigned char> txBytes(ssTx.begin(), ssTx.end());
for (auto& batchScanner : batchScanners) {
batchScanner->AddTransaction(tx, txBytes, nHeight);
batchScanner->AddTransaction(tx, txBytes, blockTag, nHeight);
}
}
@ -319,7 +320,7 @@ void ThreadNotifyWallets(CBlockIndex *pindexLastTip)
// Batch transactions that went from 1-confirmed to 0-confirmed
// or conflicted.
for (const CTransaction &tx : block.vtx) {
AddTxToBatches(batchScanners, tx, pindexScan->nHeight);
AddTxToBatches(batchScanners, tx, block.GetHash(), pindexScan->nHeight);
}
// On to the next block!
@ -347,17 +348,25 @@ void ThreadNotifyWallets(CBlockIndex *pindexLastTip)
// Batch transactions that went from mempool to conflicted:
for (const CTransaction &tx : blockData.txConflicted) {
AddTxToBatches(batchScanners, tx, blockData.pindex->nHeight + 1);
AddTxToBatches(
batchScanners,
tx,
blockData.pindex->GetBlockHash(),
blockData.pindex->nHeight + 1);
}
// ... and transactions that got confirmed:
for (const CTransaction &tx : block.vtx) {
AddTxToBatches(batchScanners, tx, blockData.pindex->nHeight);
AddTxToBatches(
batchScanners,
tx,
blockData.pindex->GetBlockHash(),
blockData.pindex->nHeight);
}
}
// Batch transactions in the mempool.
for (auto tx : recentlyAdded.first) {
AddTxToBatches(batchScanners, tx, pindexLastTip->nHeight + 1);
AddTxToBatches(batchScanners, tx, uint256(), pindexLastTip->nHeight + 1);
}
}

View File

@ -36,6 +36,7 @@ public:
virtual void AddTransaction(
const CTransaction &tx,
const std::vector<unsigned char> &txBytes,
const uint256 &blockTag,
const int nHeight) = 0;
/**

View File

@ -3559,7 +3559,11 @@ bool WalletBatchScanner::AddToWalletIfInvolvingMe(
auto decryptedNotes = decryptedNotesForTx->second;
// Fill in the details about decrypted Sapling notes.
auto batchResults = inner->collect_results(tx.GetHash().GetRawBytes());
uint256 blockTag;
if (pblock) {
blockTag = pblock->GetHash();
}
auto batchResults = inner->collect_results(blockTag.GetRawBytes(), tx.GetHash().GetRawBytes());
for (auto decrypted : batchResults->get_sapling()) {
SaplingIncomingViewingKey ivk(uint256::FromRawBytes(decrypted.ivk));
libzcash::SaplingPaymentAddress addr(
@ -3589,6 +3593,7 @@ bool WalletBatchScanner::AddToWalletIfInvolvingMe(
void WalletBatchScanner::AddTransaction(
const CTransaction &tx,
const std::vector<unsigned char> &txBytes,
const uint256 &blockTag,
const int nHeight)
{
// Decrypt Sprout outputs immediately.
@ -3596,7 +3601,7 @@ void WalletBatchScanner::AddTransaction(
std::make_pair(tx.GetHash(), pwallet->TryDecryptShieldedOutputs(tx)));
// Queue Sapling outputs for trial decryption.
inner->add_transaction({txBytes.data(), txBytes.size()}, nHeight);
inner->add_transaction(blockTag.GetRawBytes(), {txBytes.data(), txBytes.size()}, nHeight);
}
void WalletBatchScanner::Flush() {
@ -4747,7 +4752,7 @@ int CWallet::ScanForWalletTransactions(
CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION);
ssTx << tx;
std::vector<unsigned char> txBytes(ssTx.begin(), ssTx.end());
batchScanner.AddTransaction(tx, txBytes, pindex->nHeight);
batchScanner.AddTransaction(tx, txBytes, pindex->GetBlockHash(), pindex->nHeight);
}
batchScanner.Flush();
for (CTransaction& tx : block.vtx)

View File

@ -1082,6 +1082,7 @@ public:
void AddTransaction(
const CTransaction &tx,
const std::vector<unsigned char> &txBytes,
const uint256 &blockTag,
const int nHeight);
void Flush();