Merge pull request #425 from str4d/batch-note-decryption

Batch note decryption
This commit is contained in:
str4d 2021-08-11 00:47:27 +01:00 committed by GitHub
commit 13b023387b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 319 additions and 26 deletions

View File

@ -26,3 +26,6 @@ nom = { git = "https://github.com/myrrlyn/nom.git", rev = "d6b81f5303b0a347726e1
halo2 = { git = "https://github.com/zcash/halo2.git", rev = "27c4187673a9c6ade13fbdbd4f20955530c22d7f" }
orchard = { git = "https://github.com/zcash/orchard.git", rev = "8454f86d423edbf0b53a1d5d32df1c691f8b7188" }
zcash_note_encryption = { path = "components/zcash_note_encryption" }
# Unreleased
jubjub = { git = "https://github.com/zkcrypto/jubjub.git", rev = "96ab4162b83303378eae32a326b54d88b75bffc2" }

View File

@ -0,0 +1,69 @@
//! APIs for batch trial decryption.
use std::iter;
use crate::{
try_compact_note_decryption_inner, try_note_decryption_inner, Domain, EphemeralKeyBytes,
ShieldedOutput,
};
/// Trial decryption of a batch of notes with a set of recipients.
///
/// This is the batched version of [`zcash_note_encryption::try_note_decryption`].
pub fn try_note_decryption<D: Domain, Output: ShieldedOutput<D>>(
ivks: &[D::IncomingViewingKey],
outputs: &[(D, Output)],
) -> Vec<Option<(D::Note, D::Recipient, D::Memo)>> {
batch_note_decryption(ivks, outputs, try_note_decryption_inner)
}
/// Trial decryption of a batch of notes for light clients with a set of recipients.
///
/// This is the batched version of [`zcash_note_encryption::try_compact_note_decryption`].
pub fn try_compact_note_decryption<D: Domain, Output: ShieldedOutput<D>>(
ivks: &[D::IncomingViewingKey],
outputs: &[(D, Output)],
) -> Vec<Option<(D::Note, D::Recipient)>> {
batch_note_decryption(ivks, outputs, try_compact_note_decryption_inner)
}
fn batch_note_decryption<D: Domain, Output: ShieldedOutput<D>, F, FR>(
ivks: &[D::IncomingViewingKey],
outputs: &[(D, Output)],
decrypt_inner: F,
) -> Vec<Option<FR>>
where
F: Fn(&D, &D::IncomingViewingKey, &EphemeralKeyBytes, &Output, D::SymmetricKey) -> Option<FR>,
{
// Fetch the ephemeral keys for each output and batch-parse them.
let ephemeral_keys = D::batch_epk(outputs.iter().map(|(_, output)| output.ephemeral_key()));
// Derive the shared secrets for all combinations of (ivk, output).
// The scalar multiplications cannot benefit from batching.
let items = ivks.iter().flat_map(|ivk| {
ephemeral_keys.iter().map(move |(epk, ephemeral_key)| {
(
epk.as_ref().map(|epk| D::ka_agree_dec(ivk, epk)),
ephemeral_key,
)
})
});
// Run the batch-KDF to obtain the symmetric keys from the shared secrets.
let keys = D::batch_kdf(items);
// Finish the trial decryption!
ivks.iter()
.flat_map(|ivk| {
// Reconstruct the matrix of (ivk, output) combinations.
iter::repeat(ivk)
.zip(ephemeral_keys.iter())
.zip(outputs.iter())
})
.zip(keys)
.map(|(((ivk, (_, ephemeral_key)), (domain, output)), key)| {
// The `and_then` propagates any potential rejection from `D::epk`.
key.and_then(|key| decrypt_inner(domain, ivk, ephemeral_key, output, key))
})
.collect()
}

View File

@ -7,6 +7,8 @@ use crypto_api_chachapoly::{ChaCha20Ietf, ChachaPolyIetf};
use rand_core::RngCore;
use subtle::{Choice, ConstantTimeEq};
pub mod batch;
pub const COMPACT_NOTE_SIZE: usize = 1 + // version
11 + // diversifier
8 + // value
@ -99,6 +101,19 @@ pub trait Domain {
fn kdf(secret: Self::SharedSecret, ephemeral_key: &EphemeralKeyBytes) -> Self::SymmetricKey;
/// Computes `Self::kdf` on a batch of items.
///
/// For each item in the batch, if the shared secret is `None`, this returns `None` at
/// that position.
fn batch_kdf<'a>(
items: impl Iterator<Item = (Option<Self::SharedSecret>, &'a EphemeralKeyBytes)>,
) -> Vec<Option<Self::SymmetricKey>> {
// Default implementation: do the non-batched thing.
items
.map(|(secret, ephemeral_key)| secret.map(|secret| Self::kdf(secret, ephemeral_key)))
.collect()
}
// for right now, we just need `recipient` to get `d`; in the future when we
// can get that from a Sapling note, the recipient parameter will be able
// to be removed.
@ -124,6 +139,22 @@ pub trait Domain {
fn epk(ephemeral_key: &EphemeralKeyBytes) -> Option<Self::EphemeralPublicKey>;
/// Computes `Self::epk` on a batch of ephemeral keys.
///
/// This is useful for protocols where the underlying curve requires an inversion to
/// parse an encoded point.
///
/// For usability, this returns tuples of the ephemeral keys and the result of parsing
/// them.
fn batch_epk(
ephemeral_keys: impl Iterator<Item = EphemeralKeyBytes>,
) -> Vec<(Option<Self::EphemeralPublicKey>, EphemeralKeyBytes)> {
// Default implementation: do the non-batched thing.
ephemeral_keys
.map(|ephemeral_key| (Self::epk(&ephemeral_key), ephemeral_key))
.collect()
}
fn check_epk_bytes<F: Fn(&Self::EphemeralSecretKey) -> NoteValidity>(
note: &Self::Note,
check: F,
@ -334,13 +365,23 @@ pub fn try_note_decryption<D: Domain, Output: ShieldedOutput<D>>(
ivk: &D::IncomingViewingKey,
output: &Output,
) -> Option<(D::Note, D::Recipient, D::Memo)> {
assert_eq!(output.enc_ciphertext().len(), ENC_CIPHERTEXT_SIZE);
let ephemeral_key = output.ephemeral_key();
let epk = D::epk(&ephemeral_key)?;
let shared_secret = D::ka_agree_dec(ivk, &epk);
let key = D::kdf(shared_secret, &ephemeral_key);
try_note_decryption_inner(domain, ivk, &ephemeral_key, output, key)
}
fn try_note_decryption_inner<D: Domain, Output: ShieldedOutput<D>>(
domain: &D,
ivk: &D::IncomingViewingKey,
ephemeral_key: &EphemeralKeyBytes,
output: &Output,
key: D::SymmetricKey,
) -> Option<(D::Note, D::Recipient, D::Memo)> {
assert_eq!(output.enc_ciphertext().len(), ENC_CIPHERTEXT_SIZE);
let mut plaintext = [0; ENC_CIPHERTEXT_SIZE];
assert_eq!(
ChachaPolyIetf::aead_cipher()
@ -358,7 +399,7 @@ pub fn try_note_decryption<D: Domain, Output: ShieldedOutput<D>>(
let (note, to) = parse_note_plaintext_without_memo_ivk(
domain,
ivk,
&ephemeral_key,
ephemeral_key,
&output.cmstar_bytes(),
&plaintext,
)?;
@ -419,13 +460,24 @@ pub fn try_compact_note_decryption<D: Domain, Output: ShieldedOutput<D>>(
ivk: &D::IncomingViewingKey,
output: &Output,
) -> Option<(D::Note, D::Recipient)> {
assert_eq!(output.enc_ciphertext().len(), COMPACT_NOTE_SIZE);
let ephemeral_key = output.ephemeral_key();
let epk = D::epk(&ephemeral_key)?;
let shared_secret = D::ka_agree_dec(&ivk, &epk);
let key = D::kdf(shared_secret, &ephemeral_key);
try_compact_note_decryption_inner(domain, ivk, &ephemeral_key, output, key)
}
fn try_compact_note_decryption_inner<D: Domain, Output: ShieldedOutput<D>>(
domain: &D,
ivk: &D::IncomingViewingKey,
ephemeral_key: &EphemeralKeyBytes,
output: &Output,
key: D::SymmetricKey,
) -> Option<(D::Note, D::Recipient)> {
assert_eq!(output.enc_ciphertext().len(), COMPACT_NOTE_SIZE);
// Start from block 1 to skip over Poly1305 keying output
let mut plaintext = [0; COMPACT_NOTE_SIZE];
plaintext.copy_from_slice(output.enc_ciphertext());
@ -434,7 +486,7 @@ pub fn try_compact_note_decryption<D: Domain, Output: ShieldedOutput<D>>(
parse_note_plaintext_without_memo_ivk(
domain,
ivk,
&ephemeral_key,
ephemeral_key,
&output.cmstar_bytes(),
&plaintext,
)

View File

@ -1,14 +1,17 @@
use criterion::{criterion_group, criterion_main, Criterion};
use std::iter;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ff::Field;
use group::GroupEncoding;
use rand_core::OsRng;
use zcash_note_encryption::batch;
use zcash_primitives::{
consensus::{NetworkUpgrade::Canopy, Parameters, TestNetwork, TEST_NETWORK},
memo::MemoBytes,
sapling::{
note_encryption::{
sapling_note_encryption, try_sapling_compact_note_decryption,
try_sapling_note_decryption,
try_sapling_note_decryption, SaplingDomain,
},
util::generate_random_rseed,
Diversifier, PaymentAddress, SaplingIvk, ValueCommitment,
@ -64,30 +67,77 @@ fn bench_note_decryption(c: &mut Criterion) {
}
};
let mut group = c.benchmark_group("sapling-note-decryption");
{
let mut group = c.benchmark_group("sapling-note-decryption");
group.throughput(Throughput::Elements(1));
group.bench_function("valid", |b| {
b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &valid_ivk, &output).unwrap())
});
group.bench_function("valid", |b| {
b.iter(|| {
try_sapling_note_decryption(&TEST_NETWORK, height, &valid_ivk, &output).unwrap()
})
});
group.bench_function("invalid", |b| {
b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &output))
});
group.bench_function("invalid", |b| {
b.iter(|| try_sapling_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &output))
});
let compact = CompactOutputDescription::from(output);
let compact = CompactOutputDescription::from(output.clone());
group.bench_function("compact-valid", |b| {
b.iter(|| {
try_sapling_compact_note_decryption(&TEST_NETWORK, height, &valid_ivk, &compact)
.unwrap()
})
});
group.bench_function("compact-valid", |b| {
b.iter(|| {
try_sapling_compact_note_decryption(&TEST_NETWORK, height, &valid_ivk, &compact)
.unwrap()
})
});
group.bench_function("compact-invalid", |b| {
b.iter(|| {
try_sapling_compact_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &compact)
})
});
group.bench_function("compact-invalid", |b| {
b.iter(|| {
try_sapling_compact_note_decryption(&TEST_NETWORK, height, &invalid_ivk, &compact)
})
});
}
{
let valid_ivks = vec![valid_ivk];
let invalid_ivks = vec![invalid_ivk];
// We benchmark with one IVK so the overall batch size is equal to the number of
// outputs.
let size = 10;
let outputs: Vec<_> = iter::repeat(output)
.take(size)
.map(|output| {
(
SaplingDomain::for_height(TEST_NETWORK.clone(), height),
output,
)
})
.collect();
let mut group = c.benchmark_group("sapling-batch-note-decryption");
group.throughput(Throughput::Elements(size as u64));
group.bench_function(BenchmarkId::new("valid", size), |b| {
b.iter(|| batch::try_note_decryption(&valid_ivks, &outputs))
});
group.bench_function(BenchmarkId::new("invalid", size), |b| {
b.iter(|| batch::try_note_decryption(&invalid_ivks, &outputs))
});
let compact: Vec<_> = outputs
.into_iter()
.map(|(domain, output)| (domain, CompactOutputDescription::from(output.clone())))
.collect();
group.bench_function(BenchmarkId::new("compact-valid", size), |b| {
b.iter(|| batch::try_compact_note_decryption(&valid_ivks, &compact))
});
group.bench_function(BenchmarkId::new("compact-invalid", size), |b| {
b.iter(|| batch::try_compact_note_decryption(&invalid_ivks, &compact))
});
}
}
#[cfg(unix)]

View File

@ -3,6 +3,7 @@ use blake2b_simd::{Hash as Blake2bHash, Params as Blake2bParams};
use byteorder::{LittleEndian, WriteBytesExt};
use ff::PrimeField;
use group::{cofactor::CofactorGroup, GroupEncoding};
use jubjub::{AffinePoint, ExtendedPoint};
use rand_core::RngCore;
use std::convert::TryInto;
@ -119,6 +120,12 @@ pub struct SaplingDomain<P: consensus::Parameters> {
height: BlockHeight,
}
impl<P: consensus::Parameters> SaplingDomain<P> {
pub fn for_height(params: P, height: BlockHeight) -> Self {
Self { params, height }
}
}
impl<P: consensus::Parameters> Domain for SaplingDomain<P> {
type EphemeralSecretKey = jubjub::Scalar;
// It is acceptable for this to be a point because we enforce by consensus that
@ -178,6 +185,37 @@ impl<P: consensus::Parameters> Domain for SaplingDomain<P> {
kdf_sapling(dhsecret, epk)
}
fn batch_kdf<'a>(
items: impl Iterator<Item = (Option<Self::SharedSecret>, &'a EphemeralKeyBytes)>,
) -> Vec<Option<Self::SymmetricKey>> {
let (shared_secrets, ephemeral_keys): (Vec<_>, Vec<_>) = items.unzip();
let secrets: Vec<_> = shared_secrets
.iter()
.filter_map(|s| s.map(ExtendedPoint::from))
.collect();
let mut secrets_affine = vec![AffinePoint::identity(); shared_secrets.len()];
group::Curve::batch_normalize(&secrets, &mut secrets_affine);
let mut secrets_affine = secrets_affine.into_iter();
shared_secrets
.into_iter()
.map(|s| s.and_then(|_| secrets_affine.next()))
.zip(ephemeral_keys.into_iter())
.map(|(secret, ephemeral_key)| {
secret.map(|dhsecret| {
Blake2bParams::new()
.hash_length(32)
.personal(KDF_SAPLING_PERSONALIZATION)
.to_state()
.update(&dhsecret.to_bytes())
.update(ephemeral_key.as_ref())
.finalize()
})
})
.collect()
}
fn note_plaintext_bytes(
note: &Self::Note,
to: &Self::Recipient,
@ -240,6 +278,19 @@ impl<P: consensus::Parameters> Domain for SaplingDomain<P> {
jubjub::ExtendedPoint::from_bytes(&ephemeral_key.0).into()
}
fn batch_epk(
ephemeral_keys: impl Iterator<Item = EphemeralKeyBytes>,
) -> Vec<(Option<Self::EphemeralPublicKey>, EphemeralKeyBytes)> {
let ephemeral_keys: Vec<_> = ephemeral_keys.collect();
let epks = jubjub::AffinePoint::batch_from_bytes(ephemeral_keys.iter().map(|b| b.0));
epks.into_iter()
.zip(ephemeral_keys.into_iter())
.map(|(epk, ephemeral_key)| {
(epk.map(jubjub::ExtendedPoint::from).into(), ephemeral_key)
})
.collect()
}
fn check_epk_bytes<F: FnOnce(&Self::EphemeralSecretKey) -> NoteValidity>(
note: &Note,
check: F,
@ -436,7 +487,7 @@ mod tests {
use std::convert::TryInto;
use zcash_note_encryption::{
EphemeralKeyBytes, NoteEncryption, OutgoingCipherKey, ENC_CIPHERTEXT_SIZE,
batch, EphemeralKeyBytes, NoteEncryption, OutgoingCipherKey, ENC_CIPHERTEXT_SIZE,
NOTE_PLAINTEXT_SIZE, OUT_CIPHERTEXT_SIZE, OUT_PLAINTEXT_SIZE,
};
@ -1340,6 +1391,37 @@ mod tests {
None => panic!("Output recovery failed"),
}
match &batch::try_note_decryption(
&[ivk.clone()],
&[(
SaplingDomain::for_height(TEST_NETWORK, height),
output.clone(),
)],
)[..]
{
[Some((decrypted_note, decrypted_to, decrypted_memo))] => {
assert_eq!(decrypted_note, &note);
assert_eq!(decrypted_to, &to);
assert_eq!(&decrypted_memo.as_array()[..], &tv.memo[..]);
}
_ => panic!("Note decryption failed"),
}
match &batch::try_compact_note_decryption(
&[ivk.clone()],
&[(
SaplingDomain::for_height(TEST_NETWORK, height),
CompactOutputDescription::from(output.clone()),
)],
)[..]
{
[Some((decrypted_note, decrypted_to))] => {
assert_eq!(decrypted_note, &note);
assert_eq!(decrypted_to, &to);
}
_ => panic!("Note decryption failed"),
}
//
// Test encryption
//
@ -1359,4 +1441,41 @@ mod tests {
);
}
}
#[test]
fn batching() {
let mut rng = OsRng;
let height = TEST_NETWORK.activation_height(Canopy).unwrap();
// Test batch trial-decryption with multiple IVKs and outputs.
let invalid_ivk = SaplingIvk(jubjub::Fr::random(rng));
let valid_ivk = SaplingIvk(jubjub::Fr::random(rng));
let outputs: Vec<_> = (0..10)
.map(|_| {
(
SaplingDomain::for_height(TEST_NETWORK, height),
random_enc_ciphertext_with(height, &valid_ivk, &mut rng).2,
)
})
.collect();
let res = batch::try_note_decryption(&[invalid_ivk.clone(), valid_ivk.clone()], &outputs);
assert_eq!(res.len(), 20);
// The batched trial decryptions with invalid_ivk failed.
assert_eq!(&res[..10], &vec![None; 10][..]);
for (result, (_, output)) in res[10..].iter().zip(outputs.iter()) {
// Confirm that the outputs should indeed have failed with invalid_ivk
assert_eq!(
try_sapling_note_decryption(&TEST_NETWORK, height, &invalid_ivk, output),
None
);
// Confirm the successful batched trial decryptions gave the same result.
assert!(result.is_some());
assert_eq!(
result,
&try_sapling_note_decryption(&TEST_NETWORK, height, &valid_ivk, output)
);
}
}
}