zcash_client_backend: Factor out common note decryption from `scan_block_with_runner`

This commit is contained in:
Kris Nuttycombe 2024-02-22 18:26:30 -07:00
parent ba568f47ad
commit c7df76f7d1
3 changed files with 205 additions and 123 deletions

View File

@ -23,6 +23,8 @@ and this library adheres to Rust's notion of
- `zcash_client_backend::fees::ChangeValue::orchard`
- `zcash_client_backend::wallet`:
- `Note::Orchard`
- `zcash_client_backend::proto`:
- `service::TreeState::orchard_tree`
### Changed
- `zcash_client_backend::data_api`:
@ -81,7 +83,7 @@ and this library adheres to Rust's notion of
- `PROPOSAL_SER_V1`
- `ProposalDecodingError`
- `proposal` module, for parsing and serializing transaction proposals.
- `service::TreeState::orchard_tree` (behind the `orchard` feature flag)
- `impl TryFrom<&CompactSaplingOutput> for CompactOutputDescription`
- `impl Clone for zcash_client_backend::{
zip321::{Payment, TransactionRequest, Zip321Error, parse::Param, parse::IndexedParam},
wallet::WalletTransparentOutput,

View File

@ -161,10 +161,20 @@ impl TryFrom<compact_formats::CompactSaplingOutput>
type Error = ();
fn try_from(value: compact_formats::CompactSaplingOutput) -> Result<Self, Self::Error> {
(&value).try_into()
}
}
impl TryFrom<&compact_formats::CompactSaplingOutput>
for sapling::note_encryption::CompactOutputDescription
{
type Error = ();
fn try_from(value: &compact_formats::CompactSaplingOutput) -> Result<Self, Self::Error> {
Ok(sapling::note_encryption::CompactOutputDescription {
cmu: value.cmu()?,
ephemeral_key: value.ephemeral_key()?,
enc_ciphertext: value.ciphertext.try_into().map_err(|_| ())?,
enc_ciphertext: value.ciphertext[..].try_into().map_err(|_| ())?,
})
}
}

View File

@ -12,14 +12,17 @@ use sapling::{
SaplingIvk,
};
use subtle::{ConditionallySelectable, ConstantTimeEq, CtOption};
use zcash_note_encryption::batch;
use zcash_primitives::consensus::{self, BlockHeight, NetworkUpgrade};
use zcash_note_encryption::{batch, BatchDomain, ShieldedOutput};
use zcash_primitives::{
consensus::{self, BlockHeight, NetworkUpgrade},
transaction::TxId,
};
use zip32::Scope;
use crate::data_api::{BlockMetadata, ScannedBlock, ScannedBundles};
use crate::{
proto::compact_formats::CompactBlock,
scan::{Batch, BatchRunner, CompactDecryptor, Tasks},
scan::{Batch, BatchRunner, CompactDecryptor, DecryptedOutput, Tasks},
wallet::{WalletSaplingOutput, WalletSaplingSpend, WalletTx},
ShieldedProtocol,
};
@ -331,7 +334,7 @@ pub(crate) fn add_block_to_runner<P, S, T, A>(
let txid = tx.txid();
let outputs = tx
.outputs
.into_iter()
.iter()
.map(|output| {
CompactOutputDescription::try_from(output)
.expect("Invalid output found in compact block decoding.")
@ -494,7 +497,7 @@ where
let tx_index =
u16::try_from(tx.index).expect("Cannot fit more than 2^16 transactions in a block");
let (sapling_spends, sapling_unlinked_nullifiers) = check_nullifiers(
let (sapling_spends, sapling_unlinked_nullifiers) = find_spent(
&tx.spends,
sapling_nullifiers,
|spend| {
@ -508,23 +511,21 @@ where
sapling_nullifier_map.push((txid, tx_index, sapling_unlinked_nullifiers));
// Collect the set of accounts that were spent from in this transaction
let spent_from_accounts: HashSet<_> =
sapling_spends.iter().map(|spend| spend.account()).collect();
let spent_from_accounts: HashSet<_> = sapling_spends
.iter()
.map(|spend| *spend.account())
.collect();
// We keep track of the number of outputs and actions here because tx.outputs
// and tx.actions end up being moved.
let tx_outputs_len =
u32::try_from(tx.outputs.len()).expect("Sapling output count cannot exceed a u32");
#[cfg(feature = "orchard")]
let tx_actions_len =
u32::try_from(tx.actions.len()).expect("Orchard action count cannot exceed a u32");
// Check for incoming notes while incrementing tree and witnesses
let mut shielded_outputs: Vec<WalletSaplingOutput<SK::Nf, SK::Scope, A>> = vec![];
{
let decoded = &tx
.outputs
.into_iter()
let (sapling_outputs, mut sapling_nc) = find_received(
cur_height,
compact_block_tx_count,
txid,
tx_idx,
sapling_commitment_tree_size,
sapling_keys,
&spent_from_accounts,
&tx.outputs
.iter()
.map(|output| {
(
SaplingDomain::new(zip212_enforcement),
@ -532,115 +533,43 @@ where
.expect("Invalid output found in compact block decoding."),
)
})
.collect::<Vec<_>>();
let decrypted: Vec<_> = if let Some(runner) = sapling_batch_runner.as_mut() {
let sapling_keys = sapling_keys
.iter()
.flat_map(|(a, k)| {
k.to_ivks()
.into_iter()
.map(move |(scope, _, nk)| ((**a, scope), nk))
})
.collect::<HashMap<_, _>>();
let mut decrypted = runner.collect_results(cur_hash, txid);
(0..decoded.len())
.map(|i| {
decrypted.remove(&(txid, i)).map(|d_out| {
let a = d_out.ivk_tag.0;
let nk = sapling_keys.get(&d_out.ivk_tag).expect(
"The batch runner and scan_block must use the same set of IVKs.",
);
(d_out.note, a, d_out.ivk_tag.1, (*nk).clone())
})
})
.collect()
} else {
let sapling_keys = sapling_keys
.iter()
.flat_map(|(a, k)| {
k.to_ivks()
.into_iter()
.map(move |(scope, ivk, nk)| (**a, scope, ivk, nk))
})
.collect::<Vec<_>>();
let ivks = sapling_keys
.iter()
.map(|(_, _, ivk, _)| PreparedIncomingViewingKey::new(ivk))
.collect::<Vec<_>>();
batch::try_compact_note_decryption(&ivks, &decoded[..])
.into_iter()
.map(|v| {
v.map(|((note, _), ivk_idx)| {
let (account, scope, _, nk) = &sapling_keys[ivk_idx];
(note, *account, scope.clone(), (*nk).clone())
})
})
.collect()
};
for (output_idx, ((_, output), dec_output)) in decoded.iter().zip(decrypted).enumerate()
{
// Collect block note commitments
let node = sapling::Node::from_cmu(&output.cmu);
let is_checkpoint =
output_idx + 1 == decoded.len() && tx_idx + 1 == compact_block_tx_count;
let retention = match (dec_output.is_some(), is_checkpoint) {
(is_marked, true) => Retention::Checkpoint {
id: cur_height,
is_marked,
},
(true, false) => Retention::Marked,
(false, false) => Retention::Ephemeral,
};
if let Some((note, account, scope, nk)) = dec_output {
// A note is marked as "change" if the account that received it
// also spent notes in the same transaction. This will catch,
// for instance:
// - Change created by spending fractions of notes.
// - Notes created by consolidation transactions.
// - Notes sent from one account to itself.
let is_change = spent_from_accounts.contains(&account);
let note_commitment_tree_position = Position::from(u64::from(
sapling_commitment_tree_size + u32::try_from(output_idx).unwrap(),
));
let nf = SK::nf(&nk, &note, note_commitment_tree_position);
shielded_outputs.push(WalletSaplingOutput::from_parts(
.collect::<Vec<_>>(),
sapling_batch_runner
.as_mut()
.map(|runner| |txid| runner.collect_results(cur_hash, txid)),
PreparedIncomingViewingKey::new,
|output| sapling::Node::from_cmu(&output.cmu),
|output_idx, output, account, note, is_change, position, nf, scope| {
WalletSaplingOutput::from_parts(
output_idx,
output.cmu,
output.ephemeral_key.clone(),
account,
note,
is_change,
note_commitment_tree_position,
position,
nf,
scope,
));
}
)
},
);
sapling_note_commitments.append(&mut sapling_nc);
sapling_note_commitments.push((node, retention));
}
}
if !(sapling_spends.is_empty() && shielded_outputs.is_empty()) {
if !(sapling_spends.is_empty() && sapling_outputs.is_empty()) {
wtxs.push(WalletTx {
txid,
index: tx_index as usize,
sapling_spends,
sapling_outputs: shielded_outputs,
sapling_outputs,
});
}
sapling_commitment_tree_size += tx_outputs_len;
sapling_commitment_tree_size +=
u32::try_from(tx.outputs.len()).expect("Sapling output count cannot exceed a u32");
#[cfg(feature = "orchard")]
{
orchard_commitment_tree_size += tx_actions_len;
orchard_commitment_tree_size +=
u32::try_from(tx.actions.len()).expect("Orchard action count cannot exceed a u32");
}
}
@ -686,7 +615,7 @@ where
// Check for spent notes. The comparison against known-unspent nullifiers is done
// in constant time.
fn check_nullifiers<A: ConditionallySelectable + Default, Spend, Nf: ConstantTimeEq + Copy, WS>(
fn find_spent<A: ConditionallySelectable + Default, Spend, Nf: ConstantTimeEq + Copy, WS>(
spends: &[Spend],
nullifiers: &[(A, Nf)],
extract_nf: impl Fn(&Spend) -> Nf,
@ -720,6 +649,147 @@ fn check_nullifiers<A: ConditionallySelectable + Default, Spend, Nf: ConstantTim
(found_spent, unlinked_nullifiers)
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::type_complexity)]
fn find_received<
A: Copy + Eq + Hash,
D: BatchDomain,
SK: ScanningKey<Note = D::Note>,
Output: ShieldedOutput<D, 52>,
WalletOutput,
NoteCommitment,
>(
block_height: BlockHeight,
block_tx_count: usize,
txid: TxId,
tx_idx: usize,
commitment_tree_size: u32,
keys: &[(&A, SK)],
spent_from_accounts: &HashSet<A>,
decoded: &[(D, Output)],
batch_results: Option<
impl FnOnce(TxId) -> HashMap<(TxId, usize), DecryptedOutput<(A, SK::Scope), D, ()>>,
>,
prepare_key: impl Fn(&SK::IncomingViewingKey) -> D::IncomingViewingKey,
extract_note_commitment: impl Fn(&Output) -> NoteCommitment,
new_wallet_output: impl Fn(
usize,
&Output,
A,
SK::Note,
bool,
Position,
SK::Nf,
SK::Scope,
) -> WalletOutput,
) -> (
Vec<WalletOutput>,
Vec<(NoteCommitment, Retention<BlockHeight>)>,
) {
// Check for incoming notes while incrementing tree and witnesses
let (decrypted_opts, decrypted_len) = if let Some(collect_results) = batch_results {
let tagged_keys = keys
.iter()
.flat_map(|(a, k)| {
k.to_ivks()
.into_iter()
.map(move |(scope, _, nk)| ((**a, scope), nk))
})
.collect::<HashMap<_, _>>();
let mut decrypted = collect_results(txid);
let decrypted_len = decrypted.len();
(
(0..decoded.len())
.map(|i| {
decrypted.remove(&(txid, i)).map(|d_out| {
let nk = tagged_keys.get(&d_out.ivk_tag).expect(
"The batch runner and scan_block must use the same set of IVKs.",
);
(d_out.note, d_out.ivk_tag.0, d_out.ivk_tag.1, (*nk).clone())
})
})
.collect::<Vec<_>>(),
decrypted_len,
)
} else {
let tagged_keys = keys
.iter()
.flat_map(|(a, k)| {
k.to_ivks()
.into_iter()
.map(move |(scope, ivk, nk)| (**a, scope, ivk, nk))
})
.collect::<Vec<_>>();
let ivks = tagged_keys
.iter()
.map(|(_, _, ivk, _)| prepare_key(ivk))
.collect::<Vec<_>>();
let mut decrypted_len = 0;
(
batch::try_compact_note_decryption(&ivks, decoded)
.into_iter()
.map(|v| {
v.map(|((note, _), ivk_idx)| {
decrypted_len += 1;
let (a, scope, _, nk) = &tagged_keys[ivk_idx];
(note, *a, scope.clone(), (*nk).clone())
})
})
.collect::<Vec<_>>(),
decrypted_len,
)
};
let mut shielded_outputs = Vec::with_capacity(decrypted_len);
let mut note_commitments = Vec::with_capacity(decoded.len());
for (output_idx, ((_, output), dec_output)) in decoded.iter().zip(decrypted_opts).enumerate() {
// Collect block note commitments
let node = extract_note_commitment(output);
let is_checkpoint = output_idx + 1 == decoded.len() && tx_idx + 1 == block_tx_count;
let retention = match (dec_output.is_some(), is_checkpoint) {
(is_marked, true) => Retention::Checkpoint {
id: block_height,
is_marked,
},
(true, false) => Retention::Marked,
(false, false) => Retention::Ephemeral,
};
if let Some((note, account, scope, nk)) = dec_output {
// A note is marked as "change" if the account that received it
// also spent notes in the same transaction. This will catch,
// for instance:
// - Change created by spending fractions of notes.
// - Notes created by consolidation transactions.
// - Notes sent from one account to itself.
let is_change = spent_from_accounts.contains(&account);
let note_commitment_tree_position = Position::from(u64::from(
commitment_tree_size + u32::try_from(output_idx).unwrap(),
));
let nf = SK::nf(&nk, &note, note_commitment_tree_position);
shielded_outputs.push(new_wallet_output(
output_idx,
output,
account,
note,
is_change,
note_commitment_tree_position,
nf,
scope,
));
}
note_commitments.push((node, retention))
}
(shielded_outputs, note_commitments)
}
#[cfg(test)]
mod tests {
use group::{