rust: Implement multithreaded batched trial decryption for Sapling

This commit is contained in:
Jack Grigg 2022-07-19 00:12:04 +00:00
parent f447e33aad
commit 8190815526
3 changed files with 486 additions and 3 deletions

View File

@ -50,19 +50,22 @@ CXXBRIDGE_RS = \
rust/src/bundlecache.rs \
rust/src/equihash.rs \
rust/src/orchard_bundle.rs \
rust/src/sapling.rs
rust/src/sapling.rs \
rust/src/wallet_scanner.rs
CXXBRIDGE_H = \
rust/gen/include/rust/blake2b.h \
rust/gen/include/rust/bundlecache.h \
rust/gen/include/rust/equihash.h \
rust/gen/include/rust/orchard_bundle.h \
rust/gen/include/rust/sapling.h
rust/gen/include/rust/sapling.h \
rust/gen/include/rust/wallet_scanner.h
CXXBRIDGE_CPP = \
rust/gen/src/blake2b.cpp \
rust/gen/src/bundlecache.cpp \
rust/gen/src/equihash.cpp \
rust/gen/src/orchard_bundle.cpp \
rust/gen/src/sapling.cpp
rust/gen/src/sapling.cpp \
rust/gen/src/wallet_scanner.cpp
# We add a rust/cxx.h include to indicate that we provide this (via the rustcxx depends
# package), so that cxxbridge doesn't include it within the generated headers and code.

View File

@ -81,6 +81,7 @@ mod sapling;
mod transaction_ffi;
mod unified_keys_ffi;
mod wallet;
mod wallet_scanner;
mod zip339_ffi;
mod test_harness_ffi;

View File

@ -0,0 +1,479 @@
use core::fmt;
use std::collections::HashMap;
use std::io;
use std::mem;
use std::sync::mpsc;
use group::GroupEncoding;
use zcash_note_encryption::{batch, BatchDomain, Domain, ShieldedOutput, ENC_CIPHERTEXT_SIZE};
use zcash_primitives::{
consensus, constants,
sapling::{self, note_encryption::SaplingDomain},
transaction::{
components::{sapling::GrothProofBytes, OutputDescription},
Transaction, TxId,
},
};
// FFI surface shared with C++ through `cxx`. Both the shared struct and the
// `extern "Rust"` items below are placed in the C++ `wallet` namespace.
#[cxx::bridge]
mod ffi {
// A successful Sapling trial decryption, flattened into plain byte arrays so
// that it can cross the FFI boundary.
#[namespace = "wallet"]
struct SaplingDecryptionResult {
// ID of the transaction that contains the decrypted output.
txid: [u8; 32],
// Index of the output within the transaction's Sapling outputs.
output: u32,
// Encoding of the incoming viewing key that decrypted the output.
ivk: [u8; 32],
// Diversifier of the recipient address.
diversifier: [u8; 11],
// Encoding of the recipient address's `pk_d`.
pk_d: [u8; 32],
}
#[namespace = "wallet"]
extern "Rust" {
// Opaque Rust types; C++ only ever holds them behind a `Box` or a reference.
type Network;
type BatchScanner;
type BatchResult;
// Builds chain parameters from a network name ("main", "test" or "regtest").
// The heights are only used for regtest; a negative height means the
// corresponding upgrade has no activation height.
fn network(
network: &str,
overwinter: i32,
sapling: i32,
blossom: i32,
heartwood: i32,
canopy: i32,
nu5: i32,
) -> Result<Box<Network>>;
// Creates a scanner for the given Sapling incoming viewing keys. Fails if
// any of the keys is not a valid encoding.
fn init_batch_scanner(
network: &Network,
sapling_ivks: &[[u8; 32]],
) -> Result<Box<BatchScanner>>;
// Parses the serialized transaction and queues its Sapling outputs for
// trial decryption. Fails if the transaction cannot be parsed.
fn add_transaction(self: &mut BatchScanner, tx_bytes: &[u8], height: u32) -> Result<()>;
// Dispatches whatever has been accumulated so far to the threadpool.
fn flush(self: &mut BatchScanner);
// Removes and returns the pending results for the given transaction.
fn collect_results(self: &mut BatchScanner, txid: [u8; 32]) -> Box<BatchResult>;
// Flattens the Sapling results into FFI-friendly structs.
fn get_sapling(self: &BatchResult) -> Vec<SaplingDecryptionResult>;
}
}
/// The minimum number of outputs to trial decrypt in a batch.
///
/// Once the accumulated batch holds at least this many outputs,
/// `BatchRunner::add_outputs` dispatches it to the threadpool.
///
/// TODO: Tune this.
const BATCH_SIZE_THRESHOLD: usize = 20;
/// Chain parameters for the networks supported by `zcashd`.
#[derive(Clone, Copy)]
pub enum Network {
/// One of the networks whose parameters are fixed by `zcash_primitives`
/// (constructed here for `"main"` and `"test"`).
Consensus(consensus::Network),
/// A regtest network, whose activation heights are supplied by the caller.
///
/// Each field is the activation height of the upgrade of the same name;
/// `None` is reported as "no activation height" for that upgrade.
RegTest {
overwinter: Option<consensus::BlockHeight>,
sapling: Option<consensus::BlockHeight>,
blossom: Option<consensus::BlockHeight>,
heartwood: Option<consensus::BlockHeight>,
canopy: Option<consensus::BlockHeight>,
nu5: Option<consensus::BlockHeight>,
},
}
/// Builds the chain parameters for the network named by `network`.
///
/// Accepted names are `"main"`, `"test"` and `"regtest"`; any other name is
/// rejected with an error.
///
/// The six heights are consulted only when building a regtest network. A
/// negative height means the corresponding upgrade has no activation height.
fn network(
    network: &str,
    overwinter: i32,
    sapling: i32,
    blossom: i32,
    heartwood: i32,
    canopy: i32,
    nu5: i32,
) -> Result<Box<Network>, &'static str> {
    // Maps the C++-side convention (negative == "not set") onto an `Option`.
    fn optional_height(n: i32) -> Option<consensus::BlockHeight> {
        if n < 0 {
            return None;
        }
        Some(consensus::BlockHeight::from_u32(n.unsigned_abs()))
    }

    match network {
        "main" => Ok(Box::new(Network::Consensus(
            consensus::Network::MainNetwork,
        ))),
        "test" => Ok(Box::new(Network::Consensus(
            consensus::Network::TestNetwork,
        ))),
        "regtest" => Ok(Box::new(Network::RegTest {
            overwinter: optional_height(overwinter),
            sapling: optional_height(sapling),
            blossom: optional_height(blossom),
            heartwood: optional_height(heartwood),
            canopy: optional_height(canopy),
            nu5: optional_height(nu5),
        })),
        _ => Err("Unsupported network kind"),
    }
}
/// Chain parameters: fixed networks delegate to `zcash_primitives`, while regtest
/// uses the caller-supplied activation heights and the regtest constants.
impl consensus::Parameters for Network {
    /// Returns the activation height of `nu`, if it has one on this network.
    fn activation_height(&self, nu: consensus::NetworkUpgrade) -> Option<consensus::BlockHeight> {
        match (self, nu) {
            (Self::Consensus(inner), upgrade) => inner.activation_height(upgrade),
            (Self::RegTest { overwinter, .. }, consensus::NetworkUpgrade::Overwinter) => {
                *overwinter
            }
            (Self::RegTest { sapling, .. }, consensus::NetworkUpgrade::Sapling) => *sapling,
            (Self::RegTest { blossom, .. }, consensus::NetworkUpgrade::Blossom) => *blossom,
            (Self::RegTest { heartwood, .. }, consensus::NetworkUpgrade::Heartwood) => *heartwood,
            (Self::RegTest { canopy, .. }, consensus::NetworkUpgrade::Canopy) => *canopy,
            (Self::RegTest { nu5, .. }, consensus::NetworkUpgrade::Nu5) => *nu5,
        }
    }

    /// Returns the coin type for this network.
    fn coin_type(&self) -> u32 {
        if let Self::Consensus(inner) = self {
            inner.coin_type()
        } else {
            constants::regtest::COIN_TYPE
        }
    }

    /// Returns the human-readable prefix for Sapling extended spending keys.
    fn hrp_sapling_extended_spending_key(&self) -> &str {
        if let Self::Consensus(inner) = self {
            inner.hrp_sapling_extended_spending_key()
        } else {
            constants::regtest::HRP_SAPLING_EXTENDED_SPENDING_KEY
        }
    }

    /// Returns the human-readable prefix for Sapling extended full viewing keys.
    fn hrp_sapling_extended_full_viewing_key(&self) -> &str {
        if let Self::Consensus(inner) = self {
            inner.hrp_sapling_extended_full_viewing_key()
        } else {
            constants::regtest::HRP_SAPLING_EXTENDED_FULL_VIEWING_KEY
        }
    }

    /// Returns the human-readable prefix for Sapling payment addresses.
    fn hrp_sapling_payment_address(&self) -> &str {
        if let Self::Consensus(inner) = self {
            inner.hrp_sapling_payment_address()
        } else {
            constants::regtest::HRP_SAPLING_PAYMENT_ADDRESS
        }
    }

    /// Returns the Base58 prefix for transparent public-key addresses.
    fn b58_pubkey_address_prefix(&self) -> [u8; 2] {
        if let Self::Consensus(inner) = self {
            inner.b58_pubkey_address_prefix()
        } else {
            constants::regtest::B58_PUBKEY_ADDRESS_PREFIX
        }
    }

    /// Returns the Base58 prefix for transparent script addresses.
    fn b58_script_address_prefix(&self) -> [u8; 2] {
        if let Self::Consensus(inner) = self {
            inner.b58_script_address_prefix()
        } else {
            constants::regtest::B58_SCRIPT_ADDRESS_PREFIX
        }
    }
}
/// A decrypted note, together with the key that decrypted it.
struct DecryptedNote<D: Domain> {
/// The incoming viewing key used to decrypt the note.
ivk: D::IncomingViewingKey,
/// The recipient of the note.
recipient: D::Recipient,
/// The decrypted note itself.
note: D::Note,
/// The memo sent with the note.
memo: D::Memo,
}
/// Manual `Debug` implementation, bounded on the domain's associated types
/// being `Debug` rather than on `D` itself.
impl<D: Domain> fmt::Debug for DecryptedNote<D>
where
    D::IncomingViewingKey: fmt::Debug,
    D::Recipient: fmt::Debug,
    D::Note: fmt::Debug,
    D::Memo: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            ivk,
            recipient,
            note,
            memo,
        } = self;

        let mut out = f.debug_struct("DecryptedNote");
        out.field("ivk", ivk);
        out.field("recipient", recipient);
        out.field("note", note);
        out.field("memo", memo);
        out.finish()
    }
}
/// A value correlated with an output index.
///
/// Used both for the per-output reply channels held by a batch and for the
/// messages sent over those channels.
struct OutputIndex<V> {
/// The index of the output within the corresponding shielded bundle.
output_index: usize,
/// The value for the output index.
value: V,
}
/// A reply channel for one output: the output's index within its bundle, paired
/// with the sender over which that output's trial decryption results are reported.
type OutputReplier<D> = OutputIndex<mpsc::Sender<OutputIndex<Option<DecryptedNote<D>>>>>;
/// A batch of outputs to trial decrypt.
struct Batch<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> {
/// The incoming viewing keys that every output in the batch is tried against.
ivks: Vec<D::IncomingViewingKey>,
/// The outputs to trial decrypt, each paired with the domain to decrypt it in.
outputs: Vec<(D, Output)>,
/// Where to report results: `repliers[i]` belongs to `outputs[i]`, so the two
/// vectors must always have the same length.
repliers: Vec<OutputReplier<D>>,
}
impl<D, Output> Batch<D, Output>
where
    D: BatchDomain,
    Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>,
    D::IncomingViewingKey: Clone,
{
    /// Constructs a new, empty batch that will trial decrypt with the given
    /// incoming viewing keys.
    fn new(ivks: Vec<D::IncomingViewingKey>) -> Self {
        Self {
            ivks,
            outputs: vec![],
            repliers: vec![],
        }
    }

    /// Returns `true` if the batch is currently empty.
    fn is_empty(&self) -> bool {
        self.outputs.is_empty()
    }

    /// Runs the batch of trial decryptions, and reports the results.
    ///
    /// One message is sent per `(ivk, output)` pair, over the reply channel that was
    /// registered for that output; the message carries `None` when the pair did not
    /// decrypt.
    ///
    /// Every transaction has its own reply channel, so a closed channel only means
    /// that *that* transaction's results are no longer wanted. Results for the
    /// remaining outputs are still delivered.
    ///
    /// # Panics
    ///
    /// Panics if the number of outputs and repliers differ.
    fn run(self) {
        assert_eq!(self.outputs.len(), self.repliers.len());

        let decrypted = batch::try_note_decryption(&self.ivks, &self.outputs);

        // The output of `batch::try_note_decryption` corresponds to the stream of
        // trial decryptions:
        //     (ivk0, out0), (ivk0, out1), ..., (ivk0, outN), (ivk1, out0), ...
        // So we can use the position in the stream to figure out which output was
        // decrypted and which ivk decrypted it.
        let trials = self.ivks.iter().flat_map(|ivk| {
            self.repliers
                .iter()
                .map(move |replier| (ivk, replier))
        });

        let mut undelivered = 0usize;
        for (decrypted_note, (ivk, replier)) in decrypted.into_iter().zip(trials) {
            let value = decrypted_note.map(|(note, recipient, memo)| DecryptedNote {
                ivk: ivk.clone(),
                memo,
                note,
                recipient,
            });

            let message = OutputIndex {
                output_index: replier.output_index,
                value,
            };

            // `send` only fails when the receiver of this particular channel is gone.
            // Other transactions in the batch may still be waiting on their own
            // channels, so keep going instead of abandoning the whole batch.
            if replier.value.send(message).is_err() {
                undelivered += 1;
            }
        }

        if undelivered > 0 {
            // Log once per batch rather than once per failed send.
            tracing::debug!(
                "{} trial decryption result(s) discarded: receiver was dropped before batch finished",
                undelivered
            );
        }
    }
}
impl<D, Output> Batch<D, Output>
where
    D: BatchDomain,
    Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE> + Clone,
{
    /// Appends the given outputs to this batch.
    ///
    /// `domain` is invoked once per output to produce the domain it is decrypted in.
    /// Every output gets its own clone of `replier`, tagged with the output's
    /// position in `outputs`; the results for that output are sent over it when the
    /// batch is run.
    fn add_outputs(
        &mut self,
        domain: impl Fn() -> D,
        outputs: &[Output],
        replier: mpsc::Sender<OutputIndex<Option<DecryptedNote<D>>>>,
    ) {
        self.outputs.reserve(outputs.len());
        self.repliers.reserve(outputs.len());

        // Push to both vectors in lockstep so that they stay index-aligned.
        for (output_index, output) in outputs.iter().enumerate() {
            self.outputs.push((domain(), output.clone()));
            self.repliers.push(OutputIndex {
                output_index,
                value: replier.clone(),
            });
        }
    }
}
/// Logic to run batches of trial decryptions on the global threadpool.
struct BatchRunner<D: BatchDomain, Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>> {
/// The batch currently being accumulated; it is swapped for a fresh one on flush.
acc: Batch<D, Output>,
/// The receiving end of each transaction's reply channel, keyed by txid, held
/// until the results for that transaction are collected.
pending_results: HashMap<TxId, mpsc::Receiver<OutputIndex<Option<DecryptedNote<D>>>>>,
}
impl<D, Output> BatchRunner<D, Output>
where
    D: BatchDomain,
    Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE>,
    D::IncomingViewingKey: Clone,
{
    /// Creates a batch runner that trial decrypts with the given incoming viewing
    /// keys. It starts with an empty batch and no pending results.
    fn new(ivks: Vec<D::IncomingViewingKey>) -> Self {
        let acc = Batch::new(ivks);
        let pending_results = HashMap::new();
        Self {
            acc,
            pending_results,
        }
    }
}
impl<D, Output> BatchRunner<D, Output>
where
    D: BatchDomain + Send + 'static,
    D::IncomingViewingKey: Clone + Send,
    D::Memo: Send,
    D::Note: Send,
    D::Recipient: Send,
    Output: ShieldedOutput<D, ENC_CIPHERTEXT_SIZE> + Clone + Send + 'static,
{
    /// Queues the given outputs of transaction `txid` for trial decryption.
    ///
    /// A fresh reply channel is created for the transaction; its receiving end
    /// replaces any receiver previously registered under the same `txid`.
    ///
    /// If the accumulated batch then holds at least `BATCH_SIZE_THRESHOLD` outputs,
    /// `Self::flush` is called, and later calls accumulate into a new batch.
    fn add_outputs(&mut self, txid: TxId, domain: impl Fn() -> D, outputs: &[Output]) {
        let (reply_tx, reply_rx) = mpsc::channel();
        self.acc.add_outputs(domain, outputs, reply_tx);
        self.pending_results.insert(txid, reply_rx);

        let batch_is_full = self.acc.outputs.len() >= BATCH_SIZE_THRESHOLD;
        if batch_is_full {
            self.flush();
        }
    }

    /// Hands the currently accumulated batch to the global threadpool.
    ///
    /// Does nothing if the batch is empty. Otherwise the batch is replaced by an
    /// empty one with the same incoming viewing keys, so subsequent calls to
    /// `Self::add_outputs` accumulate into a new batch.
    fn flush(&mut self) {
        if self.acc.is_empty() {
            return;
        }

        let fresh = Batch::new(self.acc.ivks.clone());
        let full = mem::replace(&mut self.acc, fresh);
        rayon::spawn_fifo(move || full.run());
    }

    /// Removes and returns the decryption results for the given transaction.
    ///
    /// Only successful decryptions are returned, keyed by `(txid, output_index)`.
    /// The map is empty if nothing is pending for `txid`.
    ///
    /// The channel is drained until all of its senders are gone, so this blocks
    /// until the batch holding the transaction's outputs has been run.
    /// NOTE(review): if that batch is still being accumulated (never flushed), the
    /// senders stay alive and this call would not return; callers presumably flush
    /// first. Confirm against the C++ call sites.
    fn collect_results(&mut self, txid: TxId) -> HashMap<(TxId, usize), DecryptedNote<D>> {
        // We won't have a pending result if the transaction didn't have outputs of
        // this runner's kind.
        let replies = match self.pending_results.remove(&txid) {
            Some(replies) => replies,
            None => return HashMap::new(),
        };

        let mut decrypted = HashMap::new();
        for OutputIndex {
            output_index,
            value,
        } in replies
        {
            if let Some(decrypted_note) = value {
                decrypted.insert((txid, output_index), decrypted_note);
            }
        }
        decrypted
    }
}
/// A batch scanner for the `zcashd` wallet.
struct BatchScanner {
/// The chain parameters used to construct the Sapling decryption domain.
params: Network,
/// The runner for Sapling outputs, or `None` if the scanner was created without
/// any Sapling incoming viewing keys.
sapling_runner: Option<BatchRunner<SaplingDomain<Network>, OutputDescription<GrothProofBytes>>>,
}
/// Creates a batch scanner for the given network and Sapling incoming viewing keys.
///
/// No Sapling runner is created when `sapling_ivks` is empty. Otherwise every key
/// must be a valid encoding; the first invalid one aborts construction with an
/// error.
fn init_batch_scanner(
    network: &Network,
    sapling_ivks: &[[u8; 32]],
) -> Result<Box<BatchScanner>, &'static str> {
    let mut sapling_runner = None;

    if !sapling_ivks.is_empty() {
        let mut ivks = Vec::with_capacity(sapling_ivks.len());
        for encoded in sapling_ivks {
            let scalar: Option<jubjub::Fr> = jubjub::Fr::from_bytes(encoded).into();
            match scalar {
                Some(scalar) => ivks.push(sapling::SaplingIvk(scalar)),
                None => {
                    return Err("Invalid Sapling ivk passed to wallet::init_batch_scanner()")
                }
            }
        }
        sapling_runner = Some(BatchRunner::new(ivks));
    }

    Ok(Box::new(BatchScanner {
        params: *network,
        sapling_runner,
    }))
}
impl BatchScanner {
/// Adds the given transaction's shielded outputs to the various batch runners.
///
/// After adding the outputs, any accumulated batch of sufficient size is run on the
/// global threadpool. Subsequent calls to `Self::add_transaction` will accumulate
/// those output kinds into new batches.
fn add_transaction(&mut self, tx_bytes: &[u8], height: u32) -> Result<(), io::Error> {
// The consensusBranchId parameter is ignored; it is not used in trial decryption.
let tx = Transaction::read(tx_bytes, consensus::BranchId::Sprout)?;
let txid = tx.txid();
let height = consensus::BlockHeight::from_u32(height);
// If we have any Sapling IVKs, and the transaction has any Sapling outputs, queue
// the outputs for trial decryption.
if let Some((runner, bundle)) = self.sapling_runner.as_mut().zip(tx.sapling_bundle()) {
let params = self.params;
runner.add_outputs(
txid,
|| SaplingDomain::for_height(params, height),
&bundle.shielded_outputs,
);
}
Ok(())
}
/// Runs the currently accumulated batches on the global threadpool.
///
/// Subsequent calls to `Self::add_transaction` will be accumulated into new batches.
fn flush(&mut self) {
if let Some(runner) = &mut self.sapling_runner {
runner.flush();
}
}
/// Collects the pending decryption results for the given transaction.
///
/// TODO: Return the `HashMap`s directly once `cxx` supports it.
fn collect_results(&mut self, txid: [u8; 32]) -> Box<BatchResult> {
let txid = TxId::from_bytes(txid);
let sapling = self
.sapling_runner
.as_mut()
.map(|runner| runner.collect_results(txid))
.unwrap_or_default();
Box::new(BatchResult { sapling })
}
}
/// The decryption results collected for a single transaction.
struct BatchResult {
/// Successfully decrypted Sapling notes, keyed by `(txid, output index)`.
sapling: HashMap<(TxId, usize), DecryptedNote<SaplingDomain<Network>>>,
}
impl BatchResult {
    /// Flattens the decrypted Sapling notes into FFI-friendly structs.
    ///
    /// Entries are produced in the map's iteration order, so callers should not
    /// rely on any particular ordering.
    fn get_sapling(&self) -> Vec<ffi::SaplingDecryptionResult> {
        let mut results = Vec::with_capacity(self.sapling.len());

        for ((txid, output_index), decrypted_note) in &self.sapling {
            let recipient = &decrypted_note.recipient;
            results.push(ffi::SaplingDecryptionResult {
                txid: *txid.as_ref(),
                // NOTE(review): narrowing cast; assumes output indices fit in 32 bits.
                output: *output_index as u32,
                ivk: decrypted_note.ivk.to_repr(),
                diversifier: recipient.diversifier().0,
                pk_d: recipient.pk_d().to_bytes(),
            });
        }

        results
    }
}