Optimized batch verification (#36)

* Pull in some traits and methods from curve25519-dalek for variable-time
multiscalar multiplication.

* Move the scalar multiplication code we want to upstream to jubjub into its own module

* Make Verify agnostic to the SigType

Co-authored-by: Henry de Valence <hdevalence@hdevalence.ca>
Co-authored-by: Jane Lusby <jlusby42@gmail.com>
Deirdre Connolly 2020-07-03 18:23:28 -04:00 committed by GitHub
parent f27b9c3c77
commit ba256655dd
9 changed files with 559 additions and 11 deletions
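
For orientation before the per-file diff: the new batch API is driven by queueing (verification key bytes, signature, message) triples and then consuming the verifier. A minimal sketch mirroring the tests added in tests/batch.rs below (not itself part of the diff):

use rand::thread_rng;
use redjubjub::*;

fn batch_sketch() -> Result<(), Error> {
    // Queue a few SpendAuth signatures, then check them all at once.
    let mut batch = batch::Verifier::new();
    for _ in 0..4 {
        let sk = SigningKey::<SpendAuth>::new(thread_rng());
        let sig = sk.sign(thread_rng(), b"example");
        batch.queue((VerificationKey::from(&sk).into(), sig, b"example"));
    }
    // A single multiscalar multiplication checks the whole batch.
    batch.verify(thread_rng())
}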

.gitignore vendored (1 line changed)

@@ -1,3 +1,4 @@
/target
**/*.rs.bk
Cargo.lock
*~

Cargo.toml

@@ -18,19 +18,26 @@ description = "A mostly-standalone implementation of the RedJubjub signature sch
features = ["nightly"]
[dependencies]
-rand_core = "0.5"
-thiserror = "1.0"
blake2b_simd = "0.5"
+byteorder = "1.3"
digest = "0.9"
jubjub = "0.3"
+rand_core = "0.5"
serde = { version = "1", optional = true, features = ["derive"] }
+thiserror = "1.0"
[dev-dependencies]
+bincode = "1"
+criterion = "0.3"
+lazy_static = "1.4"
+proptest = "0.10"
rand = "0.7"
rand_chacha = "0.2"
-proptest = "0.10"
-lazy_static = "1.4"
-bincode = "1"
[features]
nightly = []
default = ["serde"]
+[[bench]]
+name = "bench"
+harness = false

benches/bench.rs (new file, 55 lines)

@@ -0,0 +1,55 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use rand::thread_rng;
use redjubjub::*;
use std::convert::TryFrom;
fn sigs_with_distinct_keys<T: SigType>(
) -> impl Iterator<Item = (VerificationKeyBytes<T>, Signature<T>)> {
std::iter::repeat_with(|| {
let sk = SigningKey::<T>::new(thread_rng());
let vk_bytes = VerificationKey::from(&sk).into();
let sig = sk.sign(thread_rng(), b"");
(vk_bytes, sig)
})
}
fn bench_batch_verify(c: &mut Criterion) {
let mut group = c.benchmark_group("Batch Verification");
for &n in [8usize, 16, 24, 32, 40, 48, 56, 64].iter() {
group.throughput(Throughput::Elements(n as u64));
let sigs = sigs_with_distinct_keys::<SpendAuth>()
.take(n)
.collect::<Vec<_>>();
group.bench_with_input(
BenchmarkId::new("Unbatched verification", n),
&sigs,
|b, sigs| {
b.iter(|| {
for (vk_bytes, sig) in sigs.iter() {
let _ =
VerificationKey::try_from(*vk_bytes).and_then(|vk| vk.verify(b"", sig));
}
})
},
);
group.bench_with_input(
BenchmarkId::new("Batched verification", n),
&sigs,
|b, sigs| {
b.iter(|| {
let mut batch = batch::Verifier::new();
for (vk_bytes, sig) in sigs.iter().cloned() {
batch.queue((vk_bytes, sig, b""));
}
batch.verify(thread_rng())
})
},
);
}
group.finish();
}
criterion_group!(benches, bench_batch_verify);
criterion_main!(benches);

src/batch.rs (new file, 231 lines)

@@ -0,0 +1,231 @@
//! Performs batch RedJubjub signature verification.
//!
//! Batch verification asks whether *all* signatures in some set are valid,
//! rather than asking whether *each* of them is valid. This allows sharing
//! computations among all signature verifications, performing less work overall
//! at the cost of higher latency (the entire batch must complete), complexity of
//! caller code (which must assemble a batch of signatures across work-items),
//! and loss of the ability to easily pinpoint failing signatures.
//!
use std::convert::TryFrom;
use jubjub::*;
use rand_core::{CryptoRng, RngCore};
use crate::{private::Sealed, scalar_mul::VartimeMultiscalarMul, *};
// Shim to generate a random 128-bit value in a [u64; 4] (the two high limbs
// are left zero), without importing `rand`.
fn gen_128_bits<R: RngCore + CryptoRng>(mut rng: R) -> [u64; 4] {
let mut bytes = [0u64; 4];
bytes[0] = rng.next_u64();
bytes[1] = rng.next_u64();
bytes
}
enum Inner {
SpendAuth {
vk_bytes: VerificationKeyBytes<SpendAuth>,
sig: Signature<SpendAuth>,
c: Scalar,
},
Binding {
vk_bytes: VerificationKeyBytes<Binding>,
sig: Signature<Binding>,
c: Scalar,
},
}
/// A batch verification item.
///
/// This struct exists to allow batch processing to be decoupled from the
/// lifetime of the message. This is useful when using the batch verification API
/// in an async context.
pub struct Item {
inner: Inner,
}
impl<'msg, M: AsRef<[u8]>>
From<(
VerificationKeyBytes<SpendAuth>,
Signature<SpendAuth>,
&'msg M,
)> for Item
{
fn from(
(vk_bytes, sig, msg): (
VerificationKeyBytes<SpendAuth>,
Signature<SpendAuth>,
&'msg M,
),
) -> Self {
// Compute c now to avoid dependency on the msg lifetime.
let c = HStar::default()
.update(&sig.r_bytes[..])
.update(&vk_bytes.bytes[..])
.update(msg)
.finalize();
Self {
inner: Inner::SpendAuth { vk_bytes, sig, c },
}
}
}
impl<'msg, M: AsRef<[u8]>> From<(VerificationKeyBytes<Binding>, Signature<Binding>, &'msg M)>
for Item
{
fn from(
(vk_bytes, sig, msg): (VerificationKeyBytes<Binding>, Signature<Binding>, &'msg M),
) -> Self {
// Compute c now to avoid dependency on the msg lifetime.
let c = HStar::default()
.update(&sig.r_bytes[..])
.update(&vk_bytes.bytes[..])
.update(msg)
.finalize();
Self {
inner: Inner::Binding { vk_bytes, sig, c },
}
}
}
#[derive(Default)]
/// A batch verification context.
pub struct Verifier {
/// Signature data queued for verification.
signatures: Vec<Item>,
}
impl Verifier {
/// Construct a new batch verifier.
pub fn new() -> Verifier {
Verifier::default()
}
/// Queue an Item for verification.
pub fn queue<I: Into<Item>>(&mut self, item: I) {
self.signatures.push(item.into());
}
/// Perform batch verification, returning `Ok(())` if all signatures were
/// valid and `Err` otherwise.
///
/// Each valid signature individually satisfies
///
/// h_G * ( -[s_i]P_G + R_i + [c_i]VK_i ) = 0_G,
///
/// so a random linear combination of these equations with 128-bit scalars
/// z_i gives the batch verification equation:
///
/// h_G * -[sum(z_i * s_i)]P_G + sum([z_i]R_i + [z_i * c_i]VK_i) = 0_G
///
/// which we split out into:
///
/// h_G * -[sum(z_i * s_i)]P_G + sum([z_i]R_i) + sum([z_i * c_i]VK_i) = 0_G
///
/// so that we can use multiscalar multiplication speedups.
///
/// where for each signature i,
/// - VK_i is the verification key;
/// - R_i is the signature's R value;
/// - s_i is the signature's s value;
/// - c_i is the hash of the message and other data;
/// - z_i is a random 128-bit Scalar;
/// - h_G is the cofactor of the group;
/// - P_G is the generator of the subgroup;
///
/// Since RedJubjub uses a different basepoint for each signature type
/// (SpendAuth and Binding), we need a separate basepoint term, with its
/// own scalar accumulator, for each signature type in the batch, but we
/// can still amortize computation nicely in one multiscalar multiplication:
///
/// h_G * ( [-sum(z_i * s_i): i_type == SpendAuth]P_SpendAuth + [-sum(z_i * s_i): i_type == Binding]P_Binding + sum([z_i]R_i) + sum([z_i * c_i]VK_i) ) = 0_G
///
/// Following the elliptic curve scalar multiplication convention,
/// scalar variables are lowercase and group point variables
/// are uppercase. This does not exactly match the RedDSA
/// notation in the [protocol specification §B.1][ps].
///
/// [ps]: https://zips.z.cash/protocol/protocol.pdf#reddsabatchverify
#[allow(non_snake_case)]
pub fn verify<R: RngCore + CryptoRng>(self, mut rng: R) -> Result<(), Error> {
let n = self.signatures.len();
let mut VK_coeffs = Vec::with_capacity(n);
let mut VKs = Vec::with_capacity(n);
let mut R_coeffs = Vec::with_capacity(self.signatures.len());
let mut Rs = Vec::with_capacity(self.signatures.len());
let mut P_spendauth_coeff = Scalar::zero();
let mut P_binding_coeff = Scalar::zero();
for item in self.signatures.iter() {
let (s_bytes, r_bytes, c) = match item.inner {
Inner::SpendAuth { sig, c, .. } => (sig.s_bytes, sig.r_bytes, c),
Inner::Binding { sig, c, .. } => (sig.s_bytes, sig.r_bytes, c),
};
let s = {
// XXX-jubjub: should not use CtOption here
let maybe_scalar = Scalar::from_bytes(&s_bytes);
if maybe_scalar.is_some().into() {
maybe_scalar.unwrap()
} else {
return Err(Error::InvalidSignature);
}
};
let R = {
// XXX-jubjub: should not use CtOption here
// XXX-jubjub: inconsistent ownership in from_bytes
let maybe_point = AffinePoint::from_bytes(r_bytes);
if maybe_point.is_some().into() {
jubjub::ExtendedPoint::from(maybe_point.unwrap())
} else {
return Err(Error::InvalidSignature);
}
};
let VK = match item.inner {
Inner::SpendAuth { vk_bytes, .. } => {
VerificationKey::<SpendAuth>::try_from(vk_bytes.bytes)?.point
}
Inner::Binding { vk_bytes, .. } => {
VerificationKey::<Binding>::try_from(vk_bytes.bytes)?.point
}
};
let z = Scalar::from_raw(gen_128_bits(&mut rng));
let P_coeff = z * s;
match item.inner {
Inner::SpendAuth { .. } => {
P_spendauth_coeff -= P_coeff;
}
Inner::Binding { .. } => {
P_binding_coeff -= P_coeff;
}
};
R_coeffs.push(z);
Rs.push(R);
VK_coeffs.push(Scalar::zero() + (z * c));
VKs.push(VK);
}
use std::iter::once;
let scalars = once(&P_spendauth_coeff)
.chain(once(&P_binding_coeff))
.chain(VK_coeffs.iter())
.chain(R_coeffs.iter());
let basepoints = [SpendAuth::basepoint(), Binding::basepoint()];
let points = basepoints.iter().chain(VKs.iter()).chain(Rs.iter());
let check = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
if check.is_small_order().into() {
Ok(())
} else {
Err(Error::InvalidSignature)
}
}
}
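
Because a failed batch does not say which signature was invalid (see the module docs above), callers that need attribution can fall back to checking signatures one at a time, as the unbatched path in benches/bench.rs above does. A sketch of a hypothetical helper, assuming VerificationKey::verify is available for any SigType as the commit message states:

use std::convert::TryFrom;

fn verify_each<T: SigType>(
    items: &[(VerificationKeyBytes<T>, Signature<T>)],
    msg: &[u8],
) -> Result<(), Error> {
    for (vk_bytes, sig) in items {
        // Stop at the first signature that fails to verify.
        VerificationKey::try_from(*vk_bytes).and_then(|vk| vk.verify(msg, sig))?;
    }
    Ok(())
}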

src/hash.rs

@@ -1,6 +1,5 @@
-use blake2b_simd::{Params, State};
use crate::Scalar;
+use blake2b_simd::{Params, State};
/// Provides H^star, the hash-to-scalar function used by RedJubjub.
pub struct HStar {
@@ -19,13 +18,13 @@ impl Default for HStar {
impl HStar {
/// Add `data` to the hash, and return `Self` for chaining.
-pub fn update(mut self, data: &[u8]) -> Self {
-self.state.update(data);
+pub fn update(&mut self, data: impl AsRef<[u8]>) -> &mut Self {
+self.state.update(data.as_ref());
self
}
/// Consume `self` to compute the hash output.
-pub fn finalize(self) -> Scalar {
+pub fn finalize(&self) -> Scalar {
Scalar::from_bytes_wide(self.state.finalize().as_array())
}
}
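
The new &mut self / impl AsRef<[u8]> signature keeps the chained style used when computing c in src/batch.rs above, and also allows feeding data incrementally. A crate-internal sketch with placeholder variables (r_bytes, vk_bytes, msg), not part of the diff:

// Chained, as in batch::Item construction above:
let c = HStar::default()
    .update(&r_bytes[..])
    .update(&vk_bytes[..])
    .update(msg)
    .finalize();

// Or incrementally, now that `update` borrows mutably:
let mut h = HStar::default();
for chunk in msg.chunks(64) {
    h.update(chunk);
}
let c2 = h.finalize();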

src/lib.rs

@@ -5,9 +5,11 @@
//! Docs require the `nightly` feature until RFC 1990 lands.
+pub mod batch;
mod constants;
mod error;
mod hash;
+mod scalar_mul;
mod signature;
mod signing_key;
mod verification_key;

src/scalar_mul.rs (new file, 185 lines)

@@ -0,0 +1,185 @@
use std::{borrow::Borrow, fmt::Debug};
use jubjub::*;
use crate::Scalar;
pub trait NonAdjacentForm {
fn non_adjacent_form(&self, w: usize) -> [i8; 256];
}
/// A trait for variable-time multiscalar multiplication without precomputation.
pub trait VartimeMultiscalarMul {
/// The type of point being multiplied, e.g., `AffinePoint`.
type Point;
/// Given an iterator of public scalars and an iterator of
/// `Option`s of points, compute either `Some(Q)`, where
/// $$
/// Q = c\_1 P\_1 + \cdots + c\_n P\_n,
/// $$
/// if all points were `Some(P_i)`, or else return `None`.
fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<Self::Point>
where
I: IntoIterator,
I::Item: Borrow<Scalar>,
J: IntoIterator<Item = Option<Self::Point>>;
/// Given an iterator of public scalars and an iterator of
/// public points, compute
/// $$
/// Q = c\_1 P\_1 + \cdots + c\_n P\_n,
/// $$
/// using variable-time operations.
///
/// It is an error to call this function with two iterators of different lengths.
fn vartime_multiscalar_mul<I, J>(scalars: I, points: J) -> Self::Point
where
I: IntoIterator,
I::Item: Borrow<Scalar>,
J: IntoIterator,
J::Item: Borrow<Self::Point>,
Self::Point: Clone,
{
Self::optional_multiscalar_mul(
scalars,
points.into_iter().map(|p| Some(p.borrow().clone())),
)
.unwrap()
}
}
impl NonAdjacentForm for Scalar {
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
///
/// Thanks to curve25519-dalek
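///
/// For example, the width-2 NAF of 7, least-significant digit first, is
/// [-1, 0, 0, 1, 0, ...], encoding 7 = -1 + 2^3: every nonzero digit is odd,
/// and no two adjacent digits are nonzero.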
fn non_adjacent_form(&self, w: usize) -> [i8; 256] {
// required by the NAF definition
debug_assert!(w >= 2);
// required so that the NAF digits fit in i8
debug_assert!(w <= 8);
use byteorder::{ByteOrder, LittleEndian};
let mut naf = [0i8; 256];
let mut x_u64 = [0u64; 5];
LittleEndian::read_u64_into(&self.to_bytes(), &mut x_u64[0..4]);
let width = 1 << w;
let window_mask = width - 1;
let mut pos = 0;
let mut carry = 0;
while pos < 256 {
// Construct a buffer of bits of the scalar, starting at bit `pos`
let u64_idx = pos / 64;
let bit_idx = pos % 64;
let bit_buf: u64;
if bit_idx < 64 - w {
// This window's bits are contained in a single u64
bit_buf = x_u64[u64_idx] >> bit_idx;
} else {
// Combine the current u64's bits with the bits from the next u64
bit_buf = (x_u64[u64_idx] >> bit_idx) | (x_u64[1 + u64_idx] << (64 - bit_idx));
}
// Add the carry into the current window
let window = carry + (bit_buf & window_mask);
if window & 1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0, then the next carry should be 0
// If carry == 1 and window & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1
pos += 1;
continue;
}
if window < width / 2 {
carry = 0;
naf[pos] = window as i8;
} else {
carry = 1;
naf[pos] = (window as i8).wrapping_sub(width as i8);
}
pos += w;
}
naf
}
}
/// Holds odd multiples 1A, 3A, ..., 15A of a point A.
#[derive(Copy, Clone)]
pub(crate) struct LookupTable5<T>(pub(crate) [T; 8]);
impl<T: Copy> LookupTable5<T> {
/// Given public, odd \\( x \\) with \\( 0 < x < 2^4 \\), return \\(xA\\).
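/// For example, `select(7)` returns `7A`, which is stored at index `7 / 2 == 3`.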
pub fn select(&self, x: usize) -> T {
debug_assert_eq!(x & 1, 1);
debug_assert!(x < 16);
self.0[x / 2]
}
}
impl<T: Debug> Debug for LookupTable5<T> {
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
write!(f, "LookupTable5({:?})", self.0)
}
}
impl<'a> From<&'a ExtendedPoint> for LookupTable5<ExtendedNielsPoint> {
#[allow(non_snake_case)]
fn from(A: &'a ExtendedPoint) -> Self {
let mut Ai = [A.to_niels(); 8];
let A2 = A.double();
for i in 0..7 {
Ai[i + 1] = (&A2 + &Ai[i]).to_niels();
}
// Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A]
LookupTable5(Ai)
}
}
impl VartimeMultiscalarMul for ExtendedPoint {
type Point = ExtendedPoint;
#[allow(non_snake_case)]
fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<ExtendedPoint>
where
I: IntoIterator,
I::Item: Borrow<Scalar>,
J: IntoIterator<Item = Option<ExtendedPoint>>,
{
let nafs: Vec<_> = scalars
.into_iter()
.map(|c| c.borrow().non_adjacent_form(5))
.collect();
let lookup_tables = points
.into_iter()
.map(|P_opt| P_opt.map(|P| LookupTable5::<ExtendedNielsPoint>::from(&P)))
.collect::<Option<Vec<_>>>()?;
let mut r = ExtendedPoint::identity();
for i in (0..256).rev() {
let mut t = r.double();
for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
if naf[i] > 0 {
t = &t + &lookup_table.select(naf[i] as usize);
} else if naf[i] < 0 {
t = &t - &lookup_table.select(-naf[i] as usize);
}
}
r = t;
}
Some(r)
}
}
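
A quick sanity check for the new routine is agreement with naive scalar multiplication. A sketch of a crate-internal test (not part of the diff; basepoint() comes from the crate's sealed SigType trait, as used in src/batch.rs above):

#[cfg(test)]
mod vartime_msm_sketch {
    use super::*;
    use crate::{private::Sealed, Binding, SpendAuth};

    #[test]
    fn matches_naive_double_scalar_mul() {
        let a = Scalar::from_raw([7, 0, 0, 0]);
        let b = Scalar::from_raw([11, 0, 0, 0]);
        let p = SpendAuth::basepoint();
        let q = Binding::basepoint();
        // Naive: compute a*P + b*Q with two plain scalar multiplications.
        let naive = (p * a) + (q * b);
        let fast = ExtendedPoint::vartime_multiscalar_mul([a, b].iter(), [p, q].iter());
        assert!(naive == fast);
    }
}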

src/verification_key.rs

@@ -1,4 +1,8 @@
-use std::{convert::TryFrom, marker::PhantomData};
+use std::{
+convert::TryFrom,
+hash::{Hash, Hasher},
+marker::PhantomData,
+};
use crate::{Error, Randomizer, Scalar, SigType, Signature, SpendAuth};
@@ -30,6 +34,13 @@ impl<T: SigType> From<VerificationKeyBytes<T>> for [u8; 32] {
}
}
+impl<T: SigType> Hash for VerificationKeyBytes<T> {
+fn hash<H: Hasher>(&self, state: &mut H) {
+self.bytes.hash(state);
+self._marker.hash(state);
+}
+}
/// A valid RedJubJub verification key.
///
/// This type holds decompressed state used in signature verification; if the
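
The new Hash impl makes the compact byte form usable with hashers and hashed collections. A minimal illustration with a hypothetical helper, not part of the commit:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use redjubjub::{SpendAuth, VerificationKeyBytes};

// Hash a verification key's byte form, e.g. to key a cache of
// already-checked keys.
fn key_digest(vk_bytes: &VerificationKeyBytes<SpendAuth>) -> u64 {
    let mut hasher = DefaultHasher::new();
    vk_bytes.hash(&mut hasher);
    hasher.finish()
}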

tests/batch.rs (new file, 57 lines)

@@ -0,0 +1,57 @@
use rand::thread_rng;
use redjubjub::*;
#[test]
fn spendauth_batch_verify() {
let rng = thread_rng();
let mut batch = batch::Verifier::new();
for _ in 0..32 {
let sk = SigningKey::<SpendAuth>::new(rng);
let vk = VerificationKey::from(&sk);
let msg = b"BatchVerifyTest";
let sig = sk.sign(rng, &msg[..]);
batch.queue((vk.into(), sig, msg));
}
assert!(batch.verify(rng).is_ok());
}
#[test]
fn binding_batch_verify() {
let rng = thread_rng();
let mut batch = batch::Verifier::new();
for _ in 0..32 {
let sk = SigningKey::<Binding>::new(rng);
let vk = VerificationKey::from(&sk);
let msg = b"BatchVerifyTest";
let sig = sk.sign(rng, &msg[..]);
batch.queue((vk.into(), sig, msg));
}
assert!(batch.verify(rng).is_ok());
}
#[test]
fn alternating_batch_verify() {
let rng = thread_rng();
let mut batch = batch::Verifier::new();
for i in 0..32 {
match i % 2 {
0 => {
let sk = SigningKey::<SpendAuth>::new(rng);
let vk = VerificationKey::from(&sk);
let msg = b"BatchVerifyTest";
let sig = sk.sign(rng, &msg[..]);
batch.queue((vk.into(), sig, msg));
}
1 => {
let sk = SigningKey::<Binding>::new(rng);
let vk = VerificationKey::from(&sk);
let msg = b"BatchVerifyTest";
let sig = sk.sign(rng, &msg[..]);
batch.queue((vk.into(), sig, msg));
}
_ => panic!(),
}
}
assert!(batch.verify(rng).is_ok());
}