From ff8f9eb20ecc04893f4d75705c4a1cab84ec3ea3 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Sun, 6 Sep 2020 12:24:55 -0600 Subject: [PATCH] Reduce number of inversions by batch inverting when possible. --- src/arithmetic.rs | 37 +++++++++++++++++++++++++++++++ src/plonk/prover.rs | 54 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 78 insertions(+), 13 deletions(-) diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 49f13c81..312cfe57 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -32,6 +32,43 @@ pub trait Group: Copy + Clone + Send + Sync + 'static { fn group_scale(&mut self, by: &Self::Scalar); } +/// Extension trait for iterators over mutable field elements which allows those +/// field elements to be inverted in a batch. +pub trait BatchInvert { + /// Consume this iterator and invert each field element (when nonzero), + /// returning the inverse of all nonzero field elements. + fn batch_invert(self) -> F; +} + +impl<'a, F, I> BatchInvert for I +where + F: Field, + I: IntoIterator, +{ + fn batch_invert(self) -> F { + let mut acc = F::one(); + let mut iter = self.into_iter(); + let mut tmp = Vec::with_capacity(iter.size_hint().0); + while let Some(p) = iter.next() { + let q = *p; + tmp.push((acc, p)); + acc = F::conditional_select(&(acc * q), &acc, q.is_zero()); + } + acc = acc.invert().unwrap(); + let allinv = acc; + + for (tmp, p) in tmp.into_iter().rev() { + let skip = p.is_zero(); + + let tmp = tmp * acc; + acc = F::conditional_select(&(acc * *p), &acc, skip); + *p = F::conditional_select(&tmp, p, skip); + } + + allinv + } +} + /// This is a 128-bit verifier challenge. #[derive(Copy, Clone, Debug)] pub struct Challenge(pub(crate) u128); diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 42367b46..5eb49210 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -4,8 +4,8 @@ use super::{ hash_point, Error, Proof, SRS, }; use crate::arithmetic::{ - eval_polynomial, get_challenge_scalar, kate_division, parallelize, Challenge, Curve, - CurveAffine, Field, + eval_polynomial, get_challenge_scalar, kate_division, parallelize, BatchInvert, Challenge, + Curve, CurveAffine, Field, }; use crate::polycommit::Params; use crate::transcript::Hasher; @@ -83,12 +83,16 @@ impl Proof { // Compute commitments to advice wire polynomials let advice_blinds: Vec<_> = witness.advice.iter().map(|_| C::Scalar::random()).collect(); - let advice_commitments = witness + let advice_commitments_projective: Vec<_> = witness .advice .iter() .zip(advice_blinds.iter()) - .map(|(poly, blind)| params.commit_lagrange(poly, *blind).to_affine()) + .map(|(poly, blind)| params.commit_lagrange(poly, *blind)) .collect(); + let mut advice_commitments = vec![C::zero(); advice_commitments_projective.len()]; + C::Projective::batch_to_affine(&advice_commitments_projective, &mut advice_commitments); + let advice_commitments = advice_commitments; + drop(advice_commitments_projective); for commitment in &advice_commitments { hash_point(&mut transcript, commitment)?; @@ -122,10 +126,11 @@ impl Proof { let mut permutation_product_polys = vec![]; let mut permutation_product_cosets = vec![]; let mut permutation_product_cosets_inv = vec![]; - let mut permutation_product_commitments = vec![]; + let mut permutation_product_commitments_projective = vec![]; let mut permutation_product_blinds = vec![]; // Iterate over each permutation + let mut permutation_modified_advice = vec![]; for (wires, permuted_values) in srs.meta.permutations.iter().zip(srs.permutations.iter()) { // Goal is to compute the fraction // @@ -154,12 +159,23 @@ impl Proof { modified_advice.push(tmp_advice_values); } - // Batch invert to obtain the denominators for the permutation product - // polynomial - for v in &mut modified_advice { - C::Scalar::batch_invert(v); - } + permutation_modified_advice.push(modified_advice); + } + // Batch invert to obtain the denominators for the permutation product + // polynomials + permutation_modified_advice + .iter_mut() + .flat_map(|v| v.iter_mut()) + .flat_map(|v| v.iter_mut()) + .batch_invert(); + + for (wires, mut modified_advice) in srs + .meta + .permutations + .iter() + .zip(permutation_modified_advice.into_iter()) + { // Iterate over each wire again, this time finishing the computation // of the entire fraction by computing the numerators let mut deltaomega = C::Scalar::one(); @@ -210,13 +226,21 @@ impl Proof { let blind = C::Scalar::random(); - permutation_product_commitments.push(params.commit_lagrange(&z, blind).to_affine()); + permutation_product_commitments_projective.push(params.commit_lagrange(&z, blind)); permutation_product_blinds.push(blind); let z = domain.obtain_poly(z); permutation_product_polys.push(z.clone()); permutation_product_cosets.push(domain.obtain_coset(z.clone(), Rotation::default())); permutation_product_cosets_inv.push(domain.obtain_coset(z, Rotation(-1))); } + let mut permutation_product_commitments = + vec![C::zero(); permutation_product_commitments_projective.len()]; + C::Projective::batch_to_affine( + &permutation_product_commitments_projective, + &mut permutation_product_commitments, + ); + let permutation_product_commitments = permutation_product_commitments; + drop(permutation_product_commitments_projective); // Hash each permutation product commitment for c in &permutation_product_commitments { @@ -351,11 +375,15 @@ impl Proof { let h_blinds: Vec<_> = h_pieces.iter().map(|_| C::Scalar::random()).collect(); // Compute commitments to each h(X) piece - let h_commitments: Vec<_> = h_pieces + let h_commitments_projective: Vec<_> = h_pieces .iter() .zip(h_blinds.iter()) - .map(|(h_piece, blind)| params.commit(&h_piece, *blind).to_affine()) + .map(|(h_piece, blind)| params.commit(&h_piece, *blind)) .collect(); + let mut h_commitments = vec![C::zero(); h_commitments_projective.len()]; + C::Projective::batch_to_affine(&h_commitments_projective, &mut h_commitments); + let h_commitments = h_commitments; + drop(h_commitments_projective); // Hash each h(X) piece for c in h_commitments.iter() {