From d0026b460b0675441536865f85ad7b2b063e78bc Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Fri, 6 Aug 2021 16:21:37 +0100 Subject: [PATCH] zcash_primitives: Implement batched trial decryption optimisation --- zcash_primitives/benches/note_decryption.rs | 92 ++++++++++++++---- .../src/sapling/note_encryption.rs | 96 ++++++++++++++++++- 2 files changed, 166 insertions(+), 22 deletions(-) diff --git a/zcash_primitives/benches/note_decryption.rs b/zcash_primitives/benches/note_decryption.rs index b32659195..e585c4fa7 100644 --- a/zcash_primitives/benches/note_decryption.rs +++ b/zcash_primitives/benches/note_decryption.rs @@ -1,14 +1,17 @@ -use criterion::{criterion_group, criterion_main, Criterion}; +use std::iter; + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use ff::Field; use group::GroupEncoding; use rand_core::OsRng; +use zcash_note_encryption::batch; use zcash_primitives::{ consensus::{NetworkUpgrade::Canopy, Parameters, TestNetwork, TEST_NETWORK}, memo::MemoBytes, sapling::{ note_encryption::{ sapling_note_encryption, try_sapling_compact_note_decryption, - try_sapling_note_decryption, + try_sapling_note_decryption, SaplingDomain, }, util::generate_random_rseed, Diversifier, PaymentAddress, SaplingIvk, ValueCommitment, @@ -64,30 +67,77 @@ fn bench_note_decryption(c: &mut Criterion) { } }; - let mut group = c.benchmark_group("sapling-note-decryption"); + { + let mut group = c.benchmark_group("sapling-note-decryption"); + group.throughput(Throughput::Elements(1)); - group.bench_function("valid", |b| { - b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &valid_ivk, &output).unwrap()) - }); + group.bench_function("valid", |b| { + b.iter(|| { + try_sapling_note_decryption(&TEST_NETWORK, height, &valid_ivk, &output).unwrap() + }) + }); - group.bench_function("invalid", |b| { - b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &output)) - }); + group.bench_function("invalid", |b| { + b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &output)) + }); - let compact = CompactOutputDescription::from(output); + let compact = CompactOutputDescription::from(output.clone()); - group.bench_function("compact-valid", |b| { - b.iter(|| { - try_sapling_compact_note_decryption(&TEST_NETWORK, height, &valid_ivk, &compact) - .unwrap() - }) - }); + group.bench_function("compact-valid", |b| { + b.iter(|| { + try_sapling_compact_note_decryption(&TEST_NETWORK, height, &valid_ivk, &compact) + .unwrap() + }) + }); - group.bench_function("compact-invalid", |b| { - b.iter(|| { - try_sapling_compact_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &compact) - }) - }); + group.bench_function("compact-invalid", |b| { + b.iter(|| { + try_sapling_compact_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &compact) + }) + }); + } + + { + let valid_ivks = vec![valid_ivk]; + let invalid_ivks = vec![invalid_ivk]; + + // We benchmark with one IVK so the overall batch size is equal to the number of + // outputs. + let size = 10; + let outputs: Vec<_> = iter::repeat(output) + .take(size) + .map(|output| { + ( + SaplingDomain::for_height(TEST_NETWORK.clone(), height), + output, + ) + }) + .collect(); + + let mut group = c.benchmark_group("sapling-batch-note-decryption"); + group.throughput(Throughput::Elements(size as u64)); + + group.bench_function(BenchmarkId::new("valid", size), |b| { + b.iter(|| batch::try_note_decryption(&valid_ivks, &outputs)) + }); + + group.bench_function(BenchmarkId::new("invalid", size), |b| { + b.iter(|| batch::try_note_decryption(&invalid_ivks, &outputs)) + }); + + let compact: Vec<_> = outputs + .into_iter() + .map(|(domain, output)| (domain, CompactOutputDescription::from(output.clone()))) + .collect(); + + group.bench_function(BenchmarkId::new("compact-valid", size), |b| { + b.iter(|| batch::try_compact_note_decryption(&valid_ivks, &compact)) + }); + + group.bench_function(BenchmarkId::new("compact-invalid", size), |b| { + b.iter(|| batch::try_compact_note_decryption(&invalid_ivks, &compact)) + }); + } } #[cfg(unix)] diff --git a/zcash_primitives/src/sapling/note_encryption.rs b/zcash_primitives/src/sapling/note_encryption.rs index b3e2e7f43..5ecef7bff 100644 --- a/zcash_primitives/src/sapling/note_encryption.rs +++ b/zcash_primitives/src/sapling/note_encryption.rs @@ -3,6 +3,7 @@ use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams}; use byteorder::{LittleEndian, WriteBytesExt}; use ff::PrimeField; use group::{cofactor::CofactorGroup, GroupEncoding}; +use jubjub::{AffinePoint, ExtendedPoint}; use rand_core::RngCore; use std::convert::TryInto; @@ -119,6 +120,12 @@ pub struct SaplingDomain { height: BlockHeight, } +impl SaplingDomain

{ + pub fn for_height(params: P, height: BlockHeight) -> Self { + Self { params, height } + } +} + impl Domain for SaplingDomain

{ type EphemeralSecretKey = jubjub::Scalar; // It is acceptable for this to be a point because we enforce by consensus that @@ -178,6 +185,37 @@ impl Domain for SaplingDomain

{ kdf_sapling(dhsecret, epk) } + fn batch_kdf<'a>( + items: impl Iterator, &'a EphemeralKeyBytes)>, + ) -> Vec> { + let (shared_secrets, ephemeral_keys): (Vec<_>, Vec<_>) = items.unzip(); + + let secrets: Vec<_> = shared_secrets + .iter() + .filter_map(|s| s.map(ExtendedPoint::from)) + .collect(); + let mut secrets_affine = vec![AffinePoint::identity(); shared_secrets.len()]; + group::Curve::batch_normalize(&secrets, &mut secrets_affine); + + let mut secrets_affine = secrets_affine.into_iter(); + shared_secrets + .into_iter() + .map(|s| s.and_then(|_| secrets_affine.next())) + .zip(ephemeral_keys.into_iter()) + .map(|(secret, ephemeral_key)| { + secret.map(|dhsecret| { + Blake2bParams::new() + .hash_length(32) + .personal(KDF_SAPLING_PERSONALIZATION) + .to_state() + .update(&dhsecret.to_bytes()) + .update(ephemeral_key.as_ref()) + .finalize() + }) + }) + .collect() + } + fn note_plaintext_bytes( note: &Self::Note, to: &Self::Recipient, @@ -436,7 +474,7 @@ mod tests { use std::convert::TryInto; use zcash_note_encryption::{ - EphemeralKeyBytes, NoteEncryption, OutgoingCipherKey, ENC_CIPHERTEXT_SIZE, + batch, EphemeralKeyBytes, NoteEncryption, OutgoingCipherKey, ENC_CIPHERTEXT_SIZE, NOTE_PLAINTEXT_SIZE, OUT_CIPHERTEXT_SIZE, OUT_PLAINTEXT_SIZE, }; @@ -1340,6 +1378,37 @@ mod tests { None => panic!("Output recovery failed"), } + match &batch::try_note_decryption( + &[ivk.clone()], + &[( + SaplingDomain::for_height(TEST_NETWORK, height), + output.clone(), + )], + )[..] + { + [Some((decrypted_note, decrypted_to, decrypted_memo))] => { + assert_eq!(decrypted_note, ¬e); + assert_eq!(decrypted_to, &to); + assert_eq!(&decrypted_memo.as_array()[..], &tv.memo[..]); + } + _ => panic!("Note decryption failed"), + } + + match &batch::try_compact_note_decryption( + &[ivk.clone()], + &[( + SaplingDomain::for_height(TEST_NETWORK, height), + CompactOutputDescription::from(output.clone()), + )], + )[..] + { + [Some((decrypted_note, decrypted_to))] => { + assert_eq!(decrypted_note, ¬e); + assert_eq!(decrypted_to, &to); + } + _ => panic!("Note decryption failed"), + } + // // Test encryption // @@ -1359,4 +1428,29 @@ mod tests { ); } } + + #[test] + fn batching() { + let mut rng = OsRng; + let height = TEST_NETWORK.activation_height(Canopy).unwrap(); + + // Test batch trial-decryption with multiple IVKs and outputs. + let invalid_ivk = SaplingIvk(jubjub::Fr::random(rng)); + let valid_ivk = SaplingIvk(jubjub::Fr::random(rng)); + let outputs: Vec<_> = (0..10) + .map(|_| { + ( + SaplingDomain::for_height(TEST_NETWORK, height), + random_enc_ciphertext_with(height, &valid_ivk, &mut rng).2, + ) + }) + .collect(); + + let res = batch::try_note_decryption(&[invalid_ivk, valid_ivk], &outputs); + assert_eq!(res.len(), 20); + assert_eq!(&res[..10], &vec![None; 10][..]); + for result in &res[10..] { + assert!(result.is_some()); + } + } }