zcash_primitives: Implement batched trial decryption optimisation

This commit is contained in:
Jack Grigg 2021-08-06 16:21:37 +01:00
parent 8a615c4393
commit d0026b460b
2 changed files with 166 additions and 22 deletions

View File

@ -1,14 +1,17 @@
use criterion::{criterion_group, criterion_main, Criterion}; use std::iter;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ff::Field; use ff::Field;
use group::GroupEncoding; use group::GroupEncoding;
use rand_core::OsRng; use rand_core::OsRng;
use zcash_note_encryption::batch;
use zcash_primitives::{ use zcash_primitives::{
consensus::{NetworkUpgrade::Canopy, Parameters, TestNetwork, TEST_NETWORK}, consensus::{NetworkUpgrade::Canopy, Parameters, TestNetwork, TEST_NETWORK},
memo::MemoBytes, memo::MemoBytes,
sapling::{ sapling::{
note_encryption::{ note_encryption::{
sapling_note_encryption, try_sapling_compact_note_decryption, sapling_note_encryption, try_sapling_compact_note_decryption,
try_sapling_note_decryption, try_sapling_note_decryption, SaplingDomain,
}, },
util::generate_random_rseed, util::generate_random_rseed,
Diversifier, PaymentAddress, SaplingIvk, ValueCommitment, Diversifier, PaymentAddress, SaplingIvk, ValueCommitment,
@ -64,30 +67,77 @@ fn bench_note_decryption(c: &mut Criterion) {
} }
}; };
let mut group = c.benchmark_group("sapling-note-decryption"); {
let mut group = c.benchmark_group("sapling-note-decryption");
group.throughput(Throughput::Elements(1));
group.bench_function("valid", |b| { group.bench_function("valid", |b| {
b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &valid_ivk, &output).unwrap()) b.iter(|| {
}); try_sapling_note_decryption(&TEST_NETWORK, height, &valid_ivk, &output).unwrap()
})
});
group.bench_function("invalid", |b| { group.bench_function("invalid", |b| {
b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &output)) b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &output))
}); });
let compact = CompactOutputDescription::from(output); let compact = CompactOutputDescription::from(output.clone());
group.bench_function("compact-valid", |b| { group.bench_function("compact-valid", |b| {
b.iter(|| { b.iter(|| {
try_sapling_compact_note_decryption(&TEST_NETWORK, height, &valid_ivk, &compact) try_sapling_compact_note_decryption(&TEST_NETWORK, height, &valid_ivk, &compact)
.unwrap() .unwrap()
}) })
}); });
group.bench_function("compact-invalid", |b| { group.bench_function("compact-invalid", |b| {
b.iter(|| { b.iter(|| {
try_sapling_compact_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &compact) try_sapling_compact_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &compact)
}) })
}); });
}
{
let valid_ivks = vec![valid_ivk];
let invalid_ivks = vec![invalid_ivk];
// We benchmark with one IVK so the overall batch size is equal to the number of
// outputs.
let size = 10;
let outputs: Vec<_> = iter::repeat(output)
.take(size)
.map(|output| {
(
SaplingDomain::for_height(TEST_NETWORK.clone(), height),
output,
)
})
.collect();
let mut group = c.benchmark_group("sapling-batch-note-decryption");
group.throughput(Throughput::Elements(size as u64));
group.bench_function(BenchmarkId::new("valid", size), |b| {
b.iter(|| batch::try_note_decryption(&valid_ivks, &outputs))
});
group.bench_function(BenchmarkId::new("invalid", size), |b| {
b.iter(|| batch::try_note_decryption(&invalid_ivks, &outputs))
});
let compact: Vec<_> = outputs
.into_iter()
.map(|(domain, output)| (domain, CompactOutputDescription::from(output.clone())))
.collect();
group.bench_function(BenchmarkId::new("compact-valid", size), |b| {
b.iter(|| batch::try_compact_note_decryption(&valid_ivks, &compact))
});
group.bench_function(BenchmarkId::new("compact-invalid", size), |b| {
b.iter(|| batch::try_compact_note_decryption(&invalid_ivks, &compact))
});
}
} }
#[cfg(unix)] #[cfg(unix)]

View File

@ -3,6 +3,7 @@ use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams};
use byteorder::{LittleEndian, WriteBytesExt}; use byteorder::{LittleEndian, WriteBytesExt};
use ff::PrimeField; use ff::PrimeField;
use group::{cofactor::CofactorGroup, GroupEncoding}; use group::{cofactor::CofactorGroup, GroupEncoding};
use jubjub::{AffinePoint, ExtendedPoint};
use rand_core::RngCore; use rand_core::RngCore;
use std::convert::TryInto; use std::convert::TryInto;
@ -119,6 +120,12 @@ pub struct SaplingDomain<P: consensus::Parameters> {
height: BlockHeight, height: BlockHeight,
} }
impl<P: consensus::Parameters> SaplingDomain<P> {
pub fn for_height(params: P, height: BlockHeight) -> Self {
Self { params, height }
}
}
impl<P: consensus::Parameters> Domain for SaplingDomain<P> { impl<P: consensus::Parameters> Domain for SaplingDomain<P> {
type EphemeralSecretKey = jubjub::Scalar; type EphemeralSecretKey = jubjub::Scalar;
// It is acceptable for this to be a point because we enforce by consensus that // It is acceptable for this to be a point because we enforce by consensus that
@ -178,6 +185,37 @@ impl<P: consensus::Parameters> Domain for SaplingDomain<P> {
kdf_sapling(dhsecret, epk) kdf_sapling(dhsecret, epk)
} }
fn batch_kdf<'a>(
items: impl Iterator<Item = (Option<Self::SharedSecret>, &'a EphemeralKeyBytes)>,
) -> Vec<Option<Self::SymmetricKey>> {
let (shared_secrets, ephemeral_keys): (Vec<_>, Vec<_>) = items.unzip();
let secrets: Vec<_> = shared_secrets
.iter()
.filter_map(|s| s.map(ExtendedPoint::from))
.collect();
let mut secrets_affine = vec![AffinePoint::identity(); shared_secrets.len()];
group::Curve::batch_normalize(&secrets, &mut secrets_affine);
let mut secrets_affine = secrets_affine.into_iter();
shared_secrets
.into_iter()
.map(|s| s.and_then(|_| secrets_affine.next()))
.zip(ephemeral_keys.into_iter())
.map(|(secret, ephemeral_key)| {
secret.map(|dhsecret| {
Blake2bParams::new()
.hash_length(32)
.personal(KDF_SAPLING_PERSONALIZATION)
.to_state()
.update(&dhsecret.to_bytes())
.update(ephemeral_key.as_ref())
.finalize()
})
})
.collect()
}
fn note_plaintext_bytes( fn note_plaintext_bytes(
note: &Self::Note, note: &Self::Note,
to: &Self::Recipient, to: &Self::Recipient,
@ -436,7 +474,7 @@ mod tests {
use std::convert::TryInto; use std::convert::TryInto;
use zcash_note_encryption::{ use zcash_note_encryption::{
EphemeralKeyBytes, NoteEncryption, OutgoingCipherKey, ENC_CIPHERTEXT_SIZE, batch, EphemeralKeyBytes, NoteEncryption, OutgoingCipherKey, ENC_CIPHERTEXT_SIZE,
NOTE_PLAINTEXT_SIZE, OUT_CIPHERTEXT_SIZE, OUT_PLAINTEXT_SIZE, NOTE_PLAINTEXT_SIZE, OUT_CIPHERTEXT_SIZE, OUT_PLAINTEXT_SIZE,
}; };
@ -1340,6 +1378,37 @@ mod tests {
None => panic!("Output recovery failed"), None => panic!("Output recovery failed"),
} }
match &batch::try_note_decryption(
&[ivk.clone()],
&[(
SaplingDomain::for_height(TEST_NETWORK, height),
output.clone(),
)],
)[..]
{
[Some((decrypted_note, decrypted_to, decrypted_memo))] => {
assert_eq!(decrypted_note, &note);
assert_eq!(decrypted_to, &to);
assert_eq!(&decrypted_memo.as_array()[..], &tv.memo[..]);
}
_ => panic!("Note decryption failed"),
}
match &batch::try_compact_note_decryption(
&[ivk.clone()],
&[(
SaplingDomain::for_height(TEST_NETWORK, height),
CompactOutputDescription::from(output.clone()),
)],
)[..]
{
[Some((decrypted_note, decrypted_to))] => {
assert_eq!(decrypted_note, &note);
assert_eq!(decrypted_to, &to);
}
_ => panic!("Note decryption failed"),
}
// //
// Test encryption // Test encryption
// //
@ -1359,4 +1428,29 @@ mod tests {
); );
} }
} }
#[test]
fn batching() {
let mut rng = OsRng;
let height = TEST_NETWORK.activation_height(Canopy).unwrap();
// Test batch trial-decryption with multiple IVKs and outputs.
let invalid_ivk = SaplingIvk(jubjub::Fr::random(rng));
let valid_ivk = SaplingIvk(jubjub::Fr::random(rng));
let outputs: Vec<_> = (0..10)
.map(|_| {
(
SaplingDomain::for_height(TEST_NETWORK, height),
random_enc_ciphertext_with(height, &valid_ivk, &mut rng).2,
)
})
.collect();
let res = batch::try_note_decryption(&[invalid_ivk, valid_ivk], &outputs);
assert_eq!(res.len(), 20);
assert_eq!(&res[..10], &vec![None; 10][..]);
for result in &res[10..] {
assert!(result.is_some());
}
}
} }