Optimized batch verification (#36)
* Pulls in some traits and methods from curve25519-dalek around the vartime multiscalar multiplication. * Move scalar mul things we want to upstream to jubjub to their own crate * Make Verify agnostic to the SigType Co-authored-by: Henry de Valence <hdevalence@hdevalence.ca> Co-authored-by: Jane Lusby <jlusby42@gmail.com>pull/38/head
parent
f27b9c3c77
commit
ba256655dd
@ -1,3 +1,4 @@ |
||||
/target |
||||
**/*.rs.bk |
||||
Cargo.lock |
||||
*~ |
||||
|
@ -0,0 +1,55 @@ |
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; |
||||
|
||||
use rand::thread_rng; |
||||
use redjubjub::*; |
||||
use std::convert::TryFrom; |
||||
|
||||
fn sigs_with_distinct_keys<T: SigType>( |
||||
) -> impl Iterator<Item = (VerificationKeyBytes<T>, Signature<T>)> { |
||||
std::iter::repeat_with(|| { |
||||
let sk = SigningKey::<T>::new(thread_rng()); |
||||
let vk_bytes = VerificationKey::from(&sk).into(); |
||||
let sig = sk.sign(thread_rng(), b""); |
||||
(vk_bytes, sig) |
||||
}) |
||||
} |
||||
|
||||
fn bench_batch_verify(c: &mut Criterion) { |
||||
let mut group = c.benchmark_group("Batch Verification"); |
||||
for &n in [8usize, 16, 24, 32, 40, 48, 56, 64].iter() { |
||||
group.throughput(Throughput::Elements(*n as u64)); |
||||
|
||||
let sigs = sigs_with_distinct_keys().take(*n).collect::<Vec<_>>(); |
||||
|
||||
group.bench_with_input( |
||||
BenchmarkId::new("Unbatched verification", n), |
||||
&sigs, |
||||
|b, sigs| { |
||||
b.iter(|| { |
||||
for (vk_bytes, sig) in sigs.iter() { |
||||
let _ = |
||||
VerificationKey::try_from(*vk_bytes).and_then(|vk| vk.verify(b"", sig)); |
||||
} |
||||
}) |
||||
}, |
||||
); |
||||
|
||||
group.bench_with_input( |
||||
BenchmarkId::new("Batched verification", n), |
||||
&sigs, |
||||
|b, sigs| { |
||||
b.iter(|| { |
||||
let mut batch = batch::Verifier::<SpendAuth>::new(); |
||||
for (vk_bytes, sig) in sigs.iter().cloned() { |
||||
batch.queue((vk_bytes, sig, b"")); |
||||
} |
||||
batch.verify(thread_rng()) |
||||
}) |
||||
}, |
||||
); |
||||
} |
||||
group.finish(); |
||||
} |
||||
|
||||
criterion_group!(benches, bench_batch_verify); |
||||
criterion_main!(benches); |
@ -0,0 +1,231 @@ |
||||
//! Performs batch RedJubjub signature verification.
|
||||
//!
|
||||
//! Batch verification asks whether *all* signatures in some set are valid,
|
||||
//! rather than asking whether *each* of them is valid. This allows sharing
|
||||
//! computations among all signature verifications, performing less work overall
|
||||
//! at the cost of higher latency (the entire batch must complete), complexity of
|
||||
//! caller code (which must assemble a batch of signatures across work-items),
|
||||
//! and loss of the ability to easily pinpoint failing signatures.
|
||||
//!
|
||||
|
||||
use std::convert::TryFrom; |
||||
|
||||
use jubjub::*; |
||||
use rand_core::{CryptoRng, RngCore}; |
||||
|
||||
use crate::{private::Sealed, scalar_mul::VartimeMultiscalarMul, *}; |
||||
|
||||
// Shim to generate a random 128bit value in a [u64; 4], without
|
||||
// importing `rand`.
|
||||
fn gen_128_bits<R: RngCore + CryptoRng>(mut rng: R) -> [u64; 4] { |
||||
let mut bytes = [0u64; 4]; |
||||
bytes[0] = rng.next_u64(); |
||||
bytes[1] = rng.next_u64(); |
||||
bytes |
||||
} |
||||
|
||||
enum Inner { |
||||
SpendAuth { |
||||
vk_bytes: VerificationKeyBytes<SpendAuth>, |
||||
sig: Signature<SpendAuth>, |
||||
c: Scalar, |
||||
}, |
||||
Binding { |
||||
vk_bytes: VerificationKeyBytes<Binding>, |
||||
sig: Signature<Binding>, |
||||
c: Scalar, |
||||
}, |
||||
} |
||||
|
||||
/// A batch verification item.
|
||||
///
|
||||
/// This struct exists to allow batch processing to be decoupled from the
|
||||
/// lifetime of the message. This is useful when using the batch verification API
|
||||
/// in an async context.
|
||||
pub struct Item { |
||||
inner: Inner, |
||||
} |
||||
|
||||
impl<'msg, M: AsRef<[u8]>> |
||||
From<( |
||||
VerificationKeyBytes<SpendAuth>, |
||||
Signature<SpendAuth>, |
||||
&'msg M, |
||||
)> for Item |
||||
{ |
||||
fn from( |
||||
(vk_bytes, sig, msg): ( |
||||
VerificationKeyBytes<SpendAuth>, |
||||
Signature<SpendAuth>, |
||||
&'msg M, |
||||
), |
||||
) -> Self { |
||||
// Compute c now to avoid dependency on the msg lifetime.
|
||||
let c = HStar::default() |
||||
.update(&sig.r_bytes[..]) |
||||
.update(&vk_bytes.bytes[..]) |
||||
.update(msg) |
||||
.finalize(); |
||||
Self { |
||||
inner: Inner::SpendAuth { vk_bytes, sig, c }, |
||||
} |
||||
} |
||||
} |
||||
|
||||
impl<'msg, M: AsRef<[u8]>> From<(VerificationKeyBytes<Binding>, Signature<Binding>, &'msg M)> |
||||
for Item |
||||
{ |
||||
fn from( |
||||
(vk_bytes, sig, msg): (VerificationKeyBytes<Binding>, Signature<Binding>, &'msg M), |
||||
) -> Self { |
||||
// Compute c now to avoid dependency on the msg lifetime.
|
||||
let c = HStar::default() |
||||
.update(&sig.r_bytes[..]) |
||||
.update(&vk_bytes.bytes[..]) |
||||
.update(msg) |
||||
.finalize(); |
||||
Self { |
||||
inner: Inner::Binding { vk_bytes, sig, c }, |
||||
} |
||||
} |
||||
} |
||||
|
||||
#[derive(Default)] |
||||
/// A batch verification context.
|
||||
pub struct Verifier { |
||||
/// Signature data queued for verification.
|
||||
signatures: Vec<Item>, |
||||
} |
||||
|
||||
impl Verifier { |
||||
/// Construct a new batch verifier.
|
||||
pub fn new() -> Verifier { |
||||
Verifier::default() |
||||
} |
||||
|
||||
/// Queue an Item for verification.
|
||||
pub fn queue<I: Into<Item>>(&mut self, item: I) { |
||||
self.signatures.push(item.into()); |
||||
} |
||||
|
||||
/// Perform batch verification, returning `Ok(())` if all signatures were
|
||||
/// valid and `Err` otherwise.
|
||||
///
|
||||
/// The batch verification equation is:
|
||||
///
|
||||
/// h_G * -[sum(z_i * s_i)]P_G + sum([z_i]R_i + [z_i * c_i]VK_i) = 0_G
|
||||
///
|
||||
/// which we split out into:
|
||||
///
|
||||
/// h_G * -[sum(z_i * s_i)]P_G + sum([z_i]R_i) + sum([z_i * c_i]VK_i) = 0_G
|
||||
///
|
||||
/// so that we can use multiscalar multiplication speedups.
|
||||
///
|
||||
/// where for each signature i,
|
||||
/// - VK_i is the verification key;
|
||||
/// - R_i is the signature's R value;
|
||||
/// - s_i is the signature's s value;
|
||||
/// - c_i is the hash of the message and other data;
|
||||
/// - z_i is a random 128-bit Scalar;
|
||||
/// - h_G is the cofactor of the group;
|
||||
/// - P_G is the generator of the subgroup;
|
||||
///
|
||||
/// Since RedJubjub uses different subgroups for different types
|
||||
/// of signatures, SpendAuth's and Binding's, we need to have yet
|
||||
/// another point and associated scalar accumulator for all the
|
||||
/// signatures of each type in our batch, but we can still
|
||||
/// amortize computation nicely in one multiscalar multiplication:
|
||||
///
|
||||
/// h_G * ( [-sum(z_i * s_i): i_type == SpendAuth]P_SpendAuth + [-sum(z_i * s_i): i_type == Binding]P_Binding + sum([z_i]R_i) + sum([z_i * c_i]VK_i) ) = 0_G
|
||||
///
|
||||
/// As follows elliptic curve scalar multiplication convention,
|
||||
/// scalar variables are lowercase and group point variables
|
||||
/// are uppercase. This does not exactly match the RedDSA
|
||||
/// notation in the [protocol specification §B.1][ps].
|
||||
///
|
||||
/// [ps]: https://zips.z.cash/protocol/protocol.pdf#reddsabatchverify
|
||||
#[allow(non_snake_case)] |
||||
pub fn verify<R: RngCore + CryptoRng>(self, mut rng: R) -> Result<(), Error> { |
||||
let n = self.signatures.len(); |
||||
|
||||
let mut VK_coeffs = Vec::with_capacity(n); |
||||
let mut VKs = Vec::with_capacity(n); |
||||
let mut R_coeffs = Vec::with_capacity(self.signatures.len()); |
||||
let mut Rs = Vec::with_capacity(self.signatures.len()); |
||||
let mut P_spendauth_coeff = Scalar::zero(); |
||||
let mut P_binding_coeff = Scalar::zero(); |
||||
|
||||
for item in self.signatures.iter() { |
||||
let (s_bytes, r_bytes, c) = match item.inner { |
||||
Inner::SpendAuth { sig, c, .. } => (sig.s_bytes, sig.r_bytes, c), |
||||
Inner::Binding { sig, c, .. } => (sig.s_bytes, sig.r_bytes, c), |
||||
}; |
||||
|
||||
let s = { |
||||
// XXX-jubjub: should not use CtOption here
|
||||
let maybe_scalar = Scalar::from_bytes(&s_bytes); |
||||
if maybe_scalar.is_some().into() { |
||||
maybe_scalar.unwrap() |
||||
} else { |
||||
return Err(Error::InvalidSignature); |
||||
} |
||||
}; |
||||
|
||||
let R = { |
||||
// XXX-jubjub: should not use CtOption here
|
||||
// XXX-jubjub: inconsistent ownership in from_bytes
|
||||
let maybe_point = AffinePoint::from_bytes(r_bytes); |
||||
if maybe_point.is_some().into() { |
||||
jubjub::ExtendedPoint::from(maybe_point.unwrap()) |
||||
} else { |
||||
return Err(Error::InvalidSignature); |
||||
} |
||||
}; |
||||
|
||||
let VK = match item.inner { |
||||
Inner::SpendAuth { vk_bytes, .. } => { |
||||
VerificationKey::<SpendAuth>::try_from(vk_bytes.bytes)?.point |
||||
} |
||||
Inner::Binding { vk_bytes, .. } => { |
||||
VerificationKey::<Binding>::try_from(vk_bytes.bytes)?.point |
||||
} |
||||
}; |
||||
|
||||
let z = Scalar::from_raw(gen_128_bits(&mut rng)); |
||||
|
||||
let P_coeff = z * s; |
||||
match item.inner { |
||||
Inner::SpendAuth { .. } => { |
||||
P_spendauth_coeff -= P_coeff; |
||||
} |
||||
Inner::Binding { .. } => { |
||||
P_binding_coeff -= P_coeff; |
||||
} |
||||
}; |
||||
|
||||
R_coeffs.push(z); |
||||
Rs.push(R); |
||||
|
||||
VK_coeffs.push(Scalar::zero() + (z * c)); |
||||
VKs.push(VK); |
||||
} |
||||
|
||||
use std::iter::once; |
||||
|
||||
let scalars = once(&P_spendauth_coeff) |
||||
.chain(once(&P_binding_coeff)) |
||||
.chain(VK_coeffs.iter()) |
||||
.chain(R_coeffs.iter()); |
||||
|
||||
let basepoints = [SpendAuth::basepoint(), Binding::basepoint()]; |
||||
let points = basepoints.iter().chain(VKs.iter()).chain(Rs.iter()); |
||||
|
||||
let check = ExtendedPoint::vartime_multiscalar_mul(scalars, points); |
||||
|
||||
if check.is_small_order().into() { |
||||
Ok(()) |
||||
} else { |
||||
Err(Error::InvalidSignature) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,185 @@ |
||||
use std::{borrow::Borrow, fmt::Debug}; |
||||
|
||||
use jubjub::*; |
||||
|
||||
use crate::Scalar; |
||||
|
||||
pub trait NonAdjacentForm { |
||||
fn non_adjacent_form(&self, w: usize) -> [i8; 256]; |
||||
} |
||||
|
||||
/// A trait for variable-time multiscalar multiplication without precomputation.
|
||||
pub trait VartimeMultiscalarMul { |
||||
/// The type of point being multiplied, e.g., `AffinePoint`.
|
||||
type Point; |
||||
|
||||
/// Given an iterator of public scalars and an iterator of
|
||||
/// `Option`s of points, compute either `Some(Q)`, where
|
||||
/// $$
|
||||
/// Q = c\_1 P\_1 + \cdots + c\_n P\_n,
|
||||
/// $$
|
||||
/// if all points were `Some(P_i)`, or else return `None`.
|
||||
fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<Self::Point> |
||||
where |
||||
I: IntoIterator, |
||||
I::Item: Borrow<Scalar>, |
||||
J: IntoIterator<Item = Option<Self::Point>>; |
||||
|
||||
/// Given an iterator of public scalars and an iterator of
|
||||
/// public points, compute
|
||||
/// $$
|
||||
/// Q = c\_1 P\_1 + \cdots + c\_n P\_n,
|
||||
/// $$
|
||||
/// using variable-time operations.
|
||||
///
|
||||
/// It is an error to call this function with two iterators of different lengths.
|
||||
fn vartime_multiscalar_mul<I, J>(scalars: I, points: J) -> Self::Point |
||||
where |
||||
I: IntoIterator, |
||||
I::Item: Borrow<Scalar>, |
||||
J: IntoIterator, |
||||
J::Item: Borrow<Self::Point>, |
||||
Self::Point: Clone, |
||||
{ |
||||
Self::optional_multiscalar_mul( |
||||
scalars, |
||||
points.into_iter().map(|p| Some(p.borrow().clone())), |
||||
) |
||||
.unwrap() |
||||
} |
||||
} |
||||
|
||||
impl NonAdjacentForm for Scalar { |
||||
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
|
||||
///
|
||||
/// Thanks to curve25519-dalek
|
||||
fn non_adjacent_form(&self, w: usize) -> [i8; 256] { |
||||
// required by the NAF definition
|
||||
debug_assert!(w >= 2); |
||||
// required so that the NAF digits fit in i8
|
||||
debug_assert!(w <= 8); |
||||
|
||||
use byteorder::{ByteOrder, LittleEndian}; |
||||
|
||||
let mut naf = [0i8; 256]; |
||||
|
||||
let mut x_u64 = [0u64; 5]; |
||||
LittleEndian::read_u64_into(&self.to_bytes(), &mut x_u64[0..4]); |
||||
|
||||
let width = 1 << w; |
||||
let window_mask = width - 1; |
||||
|
||||
let mut pos = 0; |
||||
let mut carry = 0; |
||||
while pos < 256 { |
||||
// Construct a buffer of bits of the scalar, starting at bit `pos`
|
||||
let u64_idx = pos / 64; |
||||
let bit_idx = pos % 64; |
||||
let bit_buf: u64; |
||||
if bit_idx < 64 - w { |
||||
// This window's bits are contained in a single u64
|
||||
bit_buf = x_u64[u64_idx] >> bit_idx; |
||||
} else { |
||||
// Combine the current u64's bits with the bits from the next u64
|
||||
bit_buf = (x_u64[u64_idx] >> bit_idx) | (x_u64[1 + u64_idx] << (64 - bit_idx)); |
||||
} |
||||
|
||||
// Add the carry into the current window
|
||||
let window = carry + (bit_buf & window_mask); |
||||
|
||||
if window & 1 == 0 { |
||||
// If the window value is even, preserve the carry and continue.
|
||||
// Why is the carry preserved?
|
||||
// If carry == 0 and window & 1 == 0, then the next carry should be 0
|
||||
// If carry == 1 and window & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1
|
||||
pos += 1; |
||||
continue; |
||||
} |
||||
|
||||
if window < width / 2 { |
||||
carry = 0; |
||||
naf[pos] = window as i8; |
||||
} else { |
||||
carry = 1; |
||||
naf[pos] = (window as i8).wrapping_sub(width as i8); |
||||
} |
||||
|
||||
pos += w; |
||||
} |
||||
|
||||
naf |
||||
} |
||||
} |
||||
|
||||
/// Holds odd multiples 1A, 3A, ..., 15A of a point A.
|
||||
#[derive(Copy, Clone)] |
||||
pub(crate) struct LookupTable5<T>(pub(crate) [T; 8]); |
||||
|
||||
impl<T: Copy> LookupTable5<T> { |
||||
/// Given public, odd \\( x \\) with \\( 0 < x < 2^4 \\), return \\(xA\\).
|
||||
pub fn select(&self, x: usize) -> T { |
||||
debug_assert_eq!(x & 1, 1); |
||||
debug_assert!(x < 16); |
||||
|
||||
self.0[x / 2] |
||||
} |
||||
} |
||||
|
||||
impl<T: Debug> Debug for LookupTable5<T> { |
||||
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { |
||||
write!(f, "LookupTable5({:?})", self.0) |
||||
} |
||||
} |
||||
|
||||
impl<'a> From<&'a ExtendedPoint> for LookupTable5<ExtendedNielsPoint> { |
||||
#[allow(non_snake_case)] |
||||
fn from(A: &'a ExtendedPoint) -> Self { |
||||
let mut Ai = [A.to_niels(); 8]; |
||||
let A2 = A.double(); |
||||
for i in 0..7 { |
||||
Ai[i + 1] = (&A2 + &Ai[i]).to_niels(); |
||||
} |
||||
// Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A]
|
||||
LookupTable5(Ai) |
||||
} |
||||
} |
||||
|
||||
impl VartimeMultiscalarMul for ExtendedPoint { |
||||
type Point = ExtendedPoint; |
||||
|
||||
#[allow(non_snake_case)] |
||||
fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<ExtendedPoint> |
||||
where |
||||
I: IntoIterator, |
||||
I::Item: Borrow<Scalar>, |
||||
J: IntoIterator<Item = Option<ExtendedPoint>>, |
||||
{ |
||||
let nafs: Vec<_> = scalars |
||||
.into_iter() |
||||
.map(|c| c.borrow().non_adjacent_form(5)) |
||||
.collect(); |
||||
|
||||
let lookup_tables = points |
||||
.into_iter() |
||||
.map(|P_opt| P_opt.map(|P| LookupTable5::<ExtendedNielsPoint>::from(&P))) |
||||
.collect::<Option<Vec<_>>>()?; |
||||
|
||||
let mut r = ExtendedPoint::identity(); |
||||
|
||||
for i in (0..256).rev() { |
||||
let mut t = r.double(); |
||||
|
||||
for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) { |
||||
if naf[i] > 0 { |
||||
t = &t + &lookup_table.select(naf[i] as usize); |
||||
} else if naf[i] < 0 { |
||||
t = &t - &lookup_table.select(-naf[i] as usize); |
||||
} |
||||
} |
||||
|
||||
r = t; |
||||
} |
||||
|
||||
Some(r) |
||||
} |
||||
} |
@ -0,0 +1,57 @@ |
||||
use rand::thread_rng; |
||||
|
||||
use redjubjub::*; |
||||
|
||||
#[test] |
||||
fn spendauth_batch_verify() { |
||||
let rng = thread_rng(); |
||||
let mut batch = batch::Verifier::new(); |
||||
for _ in 0..32 { |
||||
let sk = SigningKey::<SpendAuth>::new(rng); |
||||
let vk = VerificationKey::from(&sk); |
||||
let msg = b"BatchVerifyTest"; |
||||
let sig = sk.sign(rng, &msg[..]); |
||||
batch.queue((vk.into(), sig, msg)); |
||||
} |
||||
assert!(batch.verify(rng).is_ok()); |
||||
} |
||||
|
||||
#[test] |
||||
fn binding_batch_verify() { |
||||
let rng = thread_rng(); |
||||
let mut batch = batch::Verifier::new(); |
||||
for _ in 0..32 { |
||||
let sk = SigningKey::<SpendAuth>::new(rng); |
||||
let vk = VerificationKey::from(&sk); |
||||
let msg = b"BatchVerifyTest"; |
||||
let sig = sk.sign(rng, &msg[..]); |
||||
batch.queue((vk.into(), sig, msg)); |
||||
} |
||||
assert!(batch.verify(rng).is_ok()); |
||||
} |
||||
|
||||
#[test] |
||||
fn alternating_batch_verify() { |
||||
let rng = thread_rng(); |
||||
let mut batch = batch::Verifier::new(); |
||||
for i in 0..32 { |
||||
match i % 2 { |
||||
0 => { |
||||
let sk = SigningKey::<SpendAuth>::new(rng); |
||||
let vk = VerificationKey::from(&sk); |
||||
let msg = b"BatchVerifyTest"; |
||||
let sig = sk.sign(rng, &msg[..]); |
||||
batch.queue((vk.into(), sig, msg)); |
||||
} |
||||
1 => { |
||||
let sk = SigningKey::<Binding>::new(rng); |
||||
let vk = VerificationKey::from(&sk); |
||||
let msg = b"BatchVerifyTest"; |
||||
let sig = sk.sign(rng, &msg[..]); |
||||
batch.queue((vk.into(), sig, msg)); |
||||
} |
||||
_ => panic!(), |
||||
} |
||||
} |
||||
assert!(batch.verify(rng).is_ok()); |
||||
} |
Loading…
Reference in new issue