From e275d78c7d1b5b834a94384d25f9fe5129de2fa2 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Tue, 29 Sep 2020 08:51:00 -0600 Subject: [PATCH 1/5] Simplify permutations field of ConstraintSystem Co-authored-by: therealyingtong --- src/plonk/circuit.rs | 34 ++++++++++++++++++++-------------- src/plonk/prover.rs | 11 +++++++---- src/plonk/verifier.rs | 4 ++-- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/plonk/circuit.rs b/src/plonk/circuit.rs index 5330a1b3..c12e386a 100644 --- a/src/plonk/circuit.rs +++ b/src/plonk/circuit.rs @@ -170,13 +170,8 @@ pub struct ConstraintSystem { pub(crate) rotations: BTreeMap, // Vector of permutation arguments, where each corresponds to a set of wires - // that are involved in a permutation argument, as well as the corresponding - // query index for each wire. As an example, we could have a permutation - // argument between wires (A, B, C) which allows copy constraints to be - // enforced between advice wire values in A, B and C, and another - // permutation between wires (B, C, D) which allows the same with D instead - // of A. - pub(crate) permutations: Vec>, + // that are involved in a permutation argument. + pub(crate) permutations: Vec>, } impl Default for ConstraintSystem { @@ -202,16 +197,16 @@ impl ConstraintSystem { /// Add a permutation argument for some advice wires pub fn permutation(&mut self, wires: &[AdviceWire]) -> usize { let index = self.permutations.len(); - if index == 0 { + if self.permutations.is_empty() { let at = Rotation(-1); let len = self.rotations.len(); self.rotations.entry(at).or_insert(PointIndex(len)); } - let wires = wires - .iter() - .map(|&wire| (wire, self.query_advice_index(wire, 0))) - .collect(); - self.permutations.push(wires); + + for wire in wires { + self.query_advice_index(*wire, 0); + } + self.permutations.push(wires.to_vec()); index } @@ -242,7 +237,18 @@ impl ConstraintSystem { Expression::Fixed(self.query_fixed_index(wire, at)) } - fn query_advice_index(&mut self, wire: AdviceWire, at: i32) -> usize { + pub(crate) fn get_advice_query_index(&self, wire: AdviceWire, at: i32) -> usize { + let at = Rotation(at); + for (index, advice_query) in self.advice_queries.iter().enumerate() { + if advice_query == &(wire, at) { + return index; + } + } + + panic!("get_advice_query_index called for non-existant query"); + } + + pub(crate) fn query_advice_index(&mut self, wire: AdviceWire, at: i32) -> usize { let at = Rotation(at); { let len = self.rotations.len(); diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 215094ce..0805238e 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -187,7 +187,7 @@ impl Proof { let mut modified_advice = vec![C::Scalar::one(); params.n as usize]; // Iterate over each wire of the permutation - for (&(wire, _), permuted_wire_values) in wires.iter().zip(permuted_values.iter()) { + for (&wire, permuted_wire_values) in wires.iter().zip(permuted_values.iter()) { parallelize(&mut modified_advice, |modified_advice, start| { for ((modified_advice, advice_value), permuted_advice_value) in modified_advice .iter_mut() @@ -219,7 +219,7 @@ impl Proof { // Iterate over each wire again, this time finishing the computation // of the entire fraction by computing the numerators let mut deltaomega = C::Scalar::one(); - for &(wire, _) in wires.iter() { + for &wire in wires.iter() { let omega = domain.get_omega(); parallelize(&mut modified_advice, |modified_advice, start| { let mut deltaomega = deltaomega * &omega.pow_vartime(&[start as u64, 0, 0, 0]); @@ -320,7 +320,7 @@ impl Proof { let mut left = permutation_product_cosets[permutation_index].clone(); for (advice, permutation) in wires .iter() - .map(|&(_, index)| &advice_cosets[index]) + .map(|&wire| &advice_cosets[pk.vk.cs.get_advice_query_index(wire, 0)]) .zip(pk.permutation_cosets[permutation_index].iter()) { parallelize(&mut left, |left, start| { @@ -337,7 +337,10 @@ impl Proof { let mut right = permutation_product_cosets_inv[permutation_index].clone(); let mut current_delta = x_0 * &C::Scalar::ZETA; let step = domain.get_extended_omega(); - for advice in wires.iter().map(|&(_, index)| &advice_cosets[index]) { + for advice in wires + .iter() + .map(|&wire| &advice_cosets[pk.vk.cs.get_advice_query_index(wire, 0)]) + { parallelize(&mut right, move |right, start| { let mut beta_term = current_delta * &step.pow_vartime(&[start as u64, 0, 0, 0]); for (right, advice) in right.iter_mut().zip(advice[start..].iter()) { diff --git a/src/plonk/verifier.rs b/src/plonk/verifier.rs index 642a00b4..de910432 100644 --- a/src/plonk/verifier.rs +++ b/src/plonk/verifier.rs @@ -128,7 +128,7 @@ impl<'a, C: CurveAffine> Proof { let mut left = self.permutation_product_evals[permutation_index]; for (advice_eval, permutation_eval) in wires .iter() - .map(|&(_, query_index)| self.advice_evals[query_index]) + .map(|&wire| self.advice_evals[vk.cs.get_advice_query_index(wire, 0)]) .zip(self.permutation_evals[permutation_index].iter()) { left *= &(advice_eval + &(x_0 * permutation_eval) + &x_1); @@ -138,7 +138,7 @@ impl<'a, C: CurveAffine> Proof { let mut current_delta = x_0 * &x_3; for advice_eval in wires .iter() - .map(|&(_, query_index)| self.advice_evals[query_index]) + .map(|&wire| self.advice_evals[vk.cs.get_advice_query_index(wire, 0)]) { right *= &(advice_eval + ¤t_delta + &x_1); current_delta *= &C::Scalar::DELTA; From 7d8daa5d0560daae8aca87f13bd90260e7ea84ef Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Tue, 29 Sep 2020 16:56:21 -0600 Subject: [PATCH 2/5] Refactor h_eval computation into separate, more functional code. Co-authored-by: str4d --- src/plonk/verifier.rs | 173 +++++++++++++++++++++++------------------- 1 file changed, 95 insertions(+), 78 deletions(-) diff --git a/src/plonk/verifier.rs b/src/plonk/verifier.rs index de910432..4ee3e378 100644 --- a/src/plonk/verifier.rs +++ b/src/plonk/verifier.rs @@ -63,7 +63,10 @@ impl<'a, C: CurveAffine> Proof { // Sample x_3 challenge, which is used to ensure the circuit is // satisfied with high probability. let x_3: C::Scalar = get_challenge_scalar(Challenge(transcript.squeeze().get_lower_128())); - let x_3n = x_3.pow(&[params.n as u64, 0, 0, 0]); + + // This check ensures the circuit is satisfied so long as the polynomial + // commitments open to the correct values. + self.check_hx(params, vk, x_0, x_1, x_2, x_3)?; // Hash together all the openings provided by the prover into a new // transcript on the scalar field. @@ -86,83 +89,6 @@ impl<'a, C: CurveAffine> Proof { C::Base::from_bytes(&(transcript_scalar.squeeze()).to_bytes()).unwrap(); transcript.absorb(transcript_scalar_point); - // Evaluate the circuit using the custom gates provided - let mut h_eval = C::Scalar::zero(); - for poly in vk.cs.gates.iter() { - h_eval *= &x_2; - - let evaluation: C::Scalar = poly.evaluate( - &|index| self.fixed_evals[index], - &|index| self.advice_evals[index], - &|index| self.aux_evals[index], - &|a, b| a + &b, - &|a, b| a * &b, - &|a, scalar| a * &scalar, - ); - - h_eval += &evaluation; - } - - // First element in each permutation product should be 1 - // l_0(X) * (1 - z(X)) = 0 - { - // TODO: bubble this error up - let denominator = (x_3 - &C::Scalar::one()).invert().unwrap(); - - for eval in self.permutation_product_evals.iter() { - h_eval *= &x_2; - - let mut tmp = denominator; // 1 / (x_3 - 1) - tmp *= &(x_3n - &C::Scalar::one()); // (x_3^n - 1) / (x_3 - 1) - tmp *= &vk.domain.get_barycentric_weight(); // l_0(x_3) - tmp *= &(C::Scalar::one() - &eval); // l_0(X) * (1 - z(X)) - - h_eval += &tmp; - } - } - - // z(X) \prod (p(X) + \beta s_i(X) + \gamma) - z(omega^{-1} X) \prod (p(X) + \delta^i \beta X + \gamma) - for (permutation_index, wires) in vk.cs.permutations.iter().enumerate() { - h_eval *= &x_2; - - let mut left = self.permutation_product_evals[permutation_index]; - for (advice_eval, permutation_eval) in wires - .iter() - .map(|&wire| self.advice_evals[vk.cs.get_advice_query_index(wire, 0)]) - .zip(self.permutation_evals[permutation_index].iter()) - { - left *= &(advice_eval + &(x_0 * permutation_eval) + &x_1); - } - - let mut right = self.permutation_product_inv_evals[permutation_index]; - let mut current_delta = x_0 * &x_3; - for advice_eval in wires - .iter() - .map(|&wire| self.advice_evals[vk.cs.get_advice_query_index(wire, 0)]) - { - right *= &(advice_eval + ¤t_delta + &x_1); - current_delta *= &C::Scalar::DELTA; - } - - h_eval += &left; - h_eval -= &right; - } - - // Compute the expected h(x) value - let mut expected_h_eval = C::Scalar::zero(); - let mut cur = C::Scalar::one(); - for eval in &self.h_evals { - expected_h_eval += &(cur * eval); - cur *= &x_3n; - } - - if h_eval != (expected_h_eval * &(x_3n - &C::Scalar::one())) { - return Err(Error::ConstraintSystemFailure); - } - - // We are now convinced the circuit is satisfied so long as the - // polynomial commitments open to the correct values. - // Sample x_4 for compressing openings at the same points together let x_4: C::Scalar = get_challenge_scalar(Challenge(transcript.squeeze().get_lower_128())); @@ -293,4 +219,95 @@ impl<'a, C: CurveAffine> Proof { .verify(params, msm, &mut transcript, x_6, commitment_msm, msm_eval) .map_err(|_| Error::OpeningError) } + + /// Checks that this proof's h_evals are correct, and thus that all of the + /// rules are satisfied. + fn check_hx( + &self, + params: &'a Params, + vk: &VerifyingKey, + x_0: C::Scalar, + x_1: C::Scalar, + x_2: C::Scalar, + x_3: C::Scalar, + ) -> Result<(), Error> { + // x_3^n + let x_3n = x_3.pow(&[params.n as u64, 0, 0, 0]); + + // TODO: bubble this error up + // l_0(x_3) + let l_0 = (x_3 - &C::Scalar::one()).invert().unwrap() // 1 / (x_3 - 1) + * &(x_3n - &C::Scalar::one()) // (x_3^n - 1) / (x_3 - 1) + * &vk.domain.get_barycentric_weight(); // l_0(x_3) + + // Compute the expected value of h(x_3) + let h_eval = std::iter::empty() + // Evaluate the circuit using the custom gates provided + .chain(vk.cs.gates.iter().map(|poly| { + poly.evaluate( + &|index| self.fixed_evals[index], + &|index| self.advice_evals[index], + &|index| self.aux_evals[index], + &|a, b| a + &b, + &|a, b| a * &b, + &|a, scalar| a * &scalar, + ) + })) + // l_0(X) * (1 - z(X)) = 0 + .chain( + self.permutation_product_evals + .iter() + .map(|product_eval| l_0 * &(C::Scalar::one() - &product_eval)), + ) + // z(X) \prod (p(X) + \beta s_i(X) + \gamma) + // - z(omega^{-1} X) \prod (p(X) + \delta^i \beta X + \gamma) + .chain( + vk.cs + .permutations + .iter() + .zip(self.permutation_evals.iter()) + .zip(self.permutation_product_evals.iter()) + .zip(self.permutation_product_inv_evals.iter()) + .map( + |(((wires, permutation_evals), product_eval), product_inv_eval)| { + let mut left = *product_eval; + for (advice_eval, permutation_eval) in wires + .iter() + .map(|&wire| { + self.advice_evals[vk.cs.get_advice_query_index(wire, 0)] + }) + .zip(permutation_evals.iter()) + { + left *= &(advice_eval + &(x_0 * permutation_eval) + &x_1); + } + + let mut right = *product_inv_eval; + let mut current_delta = x_0 * &x_3; + for advice_eval in wires.iter().map(|&wire| { + self.advice_evals[vk.cs.get_advice_query_index(wire, 0)] + }) { + right *= &(advice_eval + ¤t_delta + &x_1); + current_delta *= &C::Scalar::DELTA; + } + + left - &right + }, + ), + ) + .fold(C::Scalar::zero(), |h_eval, v| h_eval * &x_2 + &v); + + // Compute the expected h(x_3) value + let mut expected_h_eval = C::Scalar::zero(); + let mut cur = C::Scalar::one(); + for eval in &self.h_evals { + expected_h_eval += &(cur * eval); + cur *= &x_3n; + } + + if h_eval != (expected_h_eval * &(x_3n - &C::Scalar::one())) { + return Err(Error::ConstraintSystemFailure); + } + + Ok(()) + } } From 9672bf9725196ab9f03897e6d23557b2c6b5ad1b Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Tue, 29 Sep 2020 17:14:37 -0600 Subject: [PATCH 3/5] Minor improvements to check_hx() --- src/plonk/verifier.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/plonk/verifier.rs b/src/plonk/verifier.rs index 4ee3e378..794ae013 100644 --- a/src/plonk/verifier.rs +++ b/src/plonk/verifier.rs @@ -241,7 +241,7 @@ impl<'a, C: CurveAffine> Proof { * &vk.domain.get_barycentric_weight(); // l_0(x_3) // Compute the expected value of h(x_3) - let h_eval = std::iter::empty() + let expected_h_eval = std::iter::empty() // Evaluate the circuit using the custom gates provided .chain(vk.cs.gates.iter().map(|poly| { poly.evaluate( @@ -296,15 +296,16 @@ impl<'a, C: CurveAffine> Proof { ) .fold(C::Scalar::zero(), |h_eval, v| h_eval * &x_2 + &v); - // Compute the expected h(x_3) value - let mut expected_h_eval = C::Scalar::zero(); - let mut cur = C::Scalar::one(); - for eval in &self.h_evals { - expected_h_eval += &(cur * eval); - cur *= &x_3n; - } + // Compute h(x_3) from the prover + let (_, h_eval) = self + .h_evals + .iter() + .fold((C::Scalar::one(), C::Scalar::zero()), |(cur, acc), eval| { + (cur * &x_3n, acc + &(cur * eval)) + }); - if h_eval != (expected_h_eval * &(x_3n - &C::Scalar::one())) { + // Did the prover commit to the correct polynomial? + if expected_h_eval != (h_eval * &(x_3n - &C::Scalar::one())) { return Err(Error::ConstraintSystemFailure); } From 2ccddac674524e3d0e91d092f3d1b76c5aaf66b2 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Tue, 29 Sep 2020 17:35:24 -0600 Subject: [PATCH 4/5] Split proof/input length checks into separate method of verifier --- src/plonk/verifier.rs | 66 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/src/plonk/verifier.rs b/src/plonk/verifier.rs index 794ae013..e7866a35 100644 --- a/src/plonk/verifier.rs +++ b/src/plonk/verifier.rs @@ -15,13 +15,7 @@ impl<'a, C: CurveAffine> Proof { mut msm: MSM<'a, C>, aux_commitments: &[C], ) -> Result, Error> { - // Check that aux_commitments matches the expected number of aux_wires - // and self.aux_evals - if aux_commitments.len() != vk.cs.num_aux_wires - || self.aux_evals.len() != vk.cs.num_aux_wires - { - return Err(Error::IncompatibleParams); - } + self.check_lengths(vk, aux_commitments)?; // Scale the MSM by a random factor to ensure that if the existing MSM // has is_zero() == false then this argument won't be able to interfere @@ -220,6 +214,64 @@ impl<'a, C: CurveAffine> Proof { .map_err(|_| Error::OpeningError) } + /// Checks that the lengths of vectors are consistent with the constraint + /// system + fn check_lengths(&self, vk: &VerifyingKey, aux_commitments: &[C]) -> Result<(), Error> { + // Check that aux_commitments matches the expected number of aux_wires + // and self.aux_evals + if aux_commitments.len() != vk.cs.num_aux_wires + || self.aux_evals.len() != vk.cs.num_aux_wires + { + return Err(Error::IncompatibleParams); + } + + if self.q_evals.len() != vk.cs.rotations.len() { + return Err(Error::IncompatibleParams); + } + + // TODO: check h_evals + + if self.fixed_evals.len() != vk.cs.fixed_queries.len() { + return Err(Error::IncompatibleParams); + } + + if self.advice_evals.len() != vk.cs.advice_queries.len() { + return Err(Error::IncompatibleParams); + } + + if self.permutation_evals.len() != vk.cs.permutations.len() { + return Err(Error::IncompatibleParams); + } + + for (permutation_evals, permutation) in + self.permutation_evals.iter().zip(vk.cs.permutations.iter()) + { + if permutation_evals.len() != permutation.len() { + return Err(Error::IncompatibleParams); + } + } + + if self.permutation_product_inv_evals.len() != vk.cs.permutations.len() { + return Err(Error::IncompatibleParams); + } + + if self.permutation_product_evals.len() != vk.cs.permutations.len() { + return Err(Error::IncompatibleParams); + } + + if self.permutation_product_commitments.len() != vk.cs.permutations.len() { + return Err(Error::IncompatibleParams); + } + + // TODO: check h_commitments + + if self.advice_commitments.len() != vk.cs.num_advice_wires { + return Err(Error::IncompatibleParams); + } + + Ok(()) + } + /// Checks that this proof's h_evals are correct, and thus that all of the /// rules are satisfied. fn check_hx( From 67b35954f4fd5636e0706390e0b77735534972d4 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Tue, 13 Oct 2020 08:16:20 -0600 Subject: [PATCH 5/5] Move MSM into submodule. --- src/poly/commitment.rs | 126 ++++--------------------------------- src/poly/commitment/msm.rs | 111 ++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 114 deletions(-) create mode 100644 src/poly/commitment/msm.rs diff --git a/src/poly/commitment.rs b/src/poly/commitment.rs index 73572ed1..15e089e3 100644 --- a/src/poly/commitment.rs +++ b/src/poly/commitment.rs @@ -8,113 +8,13 @@ use crate::arithmetic::{best_fft, best_multiexp, parallelize, Curve, CurveAffine use crate::transcript::Hasher; use std::ops::{Add, AddAssign, Mul, MulAssign}; +mod msm; mod prover; mod verifier; +pub use msm::MSM; pub use verifier::{Accumulator, Guard}; -/// This is a proof object for the polynomial commitment scheme opening. -#[derive(Debug, Clone)] -pub struct Proof { - rounds: Vec<(C, C)>, - delta: C, - z1: C::Scalar, - z2: C::Scalar, -} - -/// A multiscalar multiplication in the polynomial commitment scheme -#[derive(Debug, Clone)] -pub struct MSM<'a, C: CurveAffine> { - params: &'a Params, - g_scalars: Option>, - h_scalar: Option, - other_scalars: Vec, - other_bases: Vec, -} - -impl<'a, C: CurveAffine> MSM<'a, C> { - /// Add another multiexp into this one - pub fn add_msm(&mut self, other: &Self) { - self.other_scalars.extend(other.other_scalars.iter()); - self.other_bases.extend(other.other_bases.iter()); - - if let Some(g_scalars) = &other.g_scalars { - self.add_to_g(&g_scalars); - } - - if let Some(h_scalar) = &other.h_scalar { - self.add_to_h(*h_scalar); - } - } - - /// Add arbitrary term (the scalar and the point) - pub fn add_term(&mut self, scalar: C::Scalar, point: C) { - self.other_scalars.push(scalar); - self.other_bases.push(point); - } - - /// Add a vector of scalars to `g_scalars`. This function will panic if the - /// caller provides a slice of scalars that is not of length `params.n`. - // TODO: parallelize - pub fn add_to_g(&mut self, scalars: &[C::Scalar]) { - assert_eq!(scalars.len(), self.params.n as usize); - if let Some(g_scalars) = &mut self.g_scalars { - for (g_scalar, scalar) in g_scalars.iter_mut().zip(scalars.iter()) { - *g_scalar += &scalar; - } - } else { - self.g_scalars = Some(scalars.to_vec()); - } - } - - /// Add term to h - pub fn add_to_h(&mut self, scalar: C::Scalar) { - self.h_scalar = self.h_scalar.map_or(Some(scalar), |a| Some(a + &scalar)); - } - - /// Scale all scalars in the MSM by some scaling factor - // TODO: parallelize - pub fn scale(&mut self, factor: C::Scalar) { - if let Some(g_scalars) = &mut self.g_scalars { - for g_scalar in g_scalars.iter_mut() { - *g_scalar *= &factor; - } - } - - // TODO: parallelize - for other_scalar in self.other_scalars.iter_mut() { - *other_scalar *= &factor; - } - self.h_scalar = self.h_scalar.map(|a| a * &factor); - } - - /// Perform multiexp and check that it results in zero - pub fn eval(self) -> bool { - let len = self.g_scalars.as_ref().map(|v| v.len()).unwrap_or(0) - + self.h_scalar.map(|_| 1).unwrap_or(0) - + self.other_scalars.len(); - let mut scalars: Vec = Vec::with_capacity(len); - let mut bases: Vec = Vec::with_capacity(len); - - scalars.extend(&self.other_scalars); - bases.extend(&self.other_bases); - - if let Some(h_scalar) = self.h_scalar { - scalars.push(h_scalar); - bases.push(self.params.h); - } - - if let Some(g_scalars) = &self.g_scalars { - scalars.extend(g_scalars); - bases.extend(self.params.g.iter()); - } - - assert_eq!(scalars.len(), len); - - bool::from(best_multiexp(&scalars, &bases).is_zero()) - } -} - /// These are the public parameters for the polynomial commitment scheme. #[derive(Debug)] pub struct Params { @@ -125,6 +25,15 @@ pub struct Params { pub(crate) h: C, } +/// This is a proof object for the polynomial commitment scheme opening. +#[derive(Debug, Clone)] +pub struct Proof { + rounds: Vec<(C, C)>, + delta: C, + z1: C::Scalar, + z2: C::Scalar, +} + impl Params { /// Initializes parameters for the curve, given a random oracle to draw /// points from. @@ -250,18 +159,7 @@ impl Params { /// Generates an empty multiscalar multiplication struct using the /// appropriate params. pub fn empty_msm(&self) -> MSM { - let g_scalars = None; - let h_scalar = None; - let other_scalars = vec![]; - let other_bases = vec![]; - - MSM { - params: &self, - g_scalars, - h_scalar, - other_scalars, - other_bases, - } + MSM::new(self) } /// Getter for g generators diff --git a/src/poly/commitment/msm.rs b/src/poly/commitment/msm.rs new file mode 100644 index 00000000..7e04d923 --- /dev/null +++ b/src/poly/commitment/msm.rs @@ -0,0 +1,111 @@ +use super::Params; +use crate::arithmetic::{best_multiexp, Curve, CurveAffine}; + +/// A multiscalar multiplication in the polynomial commitment scheme +#[derive(Debug, Clone)] +pub struct MSM<'a, C: CurveAffine> { + pub(crate) params: &'a Params, + g_scalars: Option>, + h_scalar: Option, + other_scalars: Vec, + other_bases: Vec, +} + +impl<'a, C: CurveAffine> MSM<'a, C> { + /// Create a new, empty MSM using the provided parameters. + pub fn new(params: &'a Params) -> Self { + let g_scalars = None; + let h_scalar = None; + let other_scalars = vec![]; + let other_bases = vec![]; + + MSM { + params, + g_scalars, + h_scalar, + other_scalars, + other_bases, + } + } + + /// Add another multiexp into this one + pub fn add_msm(&mut self, other: &Self) { + self.other_scalars.extend(other.other_scalars.iter()); + self.other_bases.extend(other.other_bases.iter()); + + if let Some(g_scalars) = &other.g_scalars { + self.add_to_g(&g_scalars); + } + + if let Some(h_scalar) = &other.h_scalar { + self.add_to_h(*h_scalar); + } + } + + /// Add arbitrary term (the scalar and the point) + pub fn add_term(&mut self, scalar: C::Scalar, point: C) { + self.other_scalars.push(scalar); + self.other_bases.push(point); + } + + /// Add a vector of scalars to `g_scalars`. This function will panic if the + /// caller provides a slice of scalars that is not of length `params.n`. + // TODO: parallelize + pub fn add_to_g(&mut self, scalars: &[C::Scalar]) { + assert_eq!(scalars.len(), self.params.n as usize); + if let Some(g_scalars) = &mut self.g_scalars { + for (g_scalar, scalar) in g_scalars.iter_mut().zip(scalars.iter()) { + *g_scalar += &scalar; + } + } else { + self.g_scalars = Some(scalars.to_vec()); + } + } + + /// Add term to h + pub fn add_to_h(&mut self, scalar: C::Scalar) { + self.h_scalar = self.h_scalar.map_or(Some(scalar), |a| Some(a + &scalar)); + } + + /// Scale all scalars in the MSM by some scaling factor + // TODO: parallelize + pub fn scale(&mut self, factor: C::Scalar) { + if let Some(g_scalars) = &mut self.g_scalars { + for g_scalar in g_scalars.iter_mut() { + *g_scalar *= &factor; + } + } + + // TODO: parallelize + for other_scalar in self.other_scalars.iter_mut() { + *other_scalar *= &factor; + } + self.h_scalar = self.h_scalar.map(|a| a * &factor); + } + + /// Perform multiexp and check that it results in zero + pub fn eval(self) -> bool { + let len = self.g_scalars.as_ref().map(|v| v.len()).unwrap_or(0) + + self.h_scalar.map(|_| 1).unwrap_or(0) + + self.other_scalars.len(); + let mut scalars: Vec = Vec::with_capacity(len); + let mut bases: Vec = Vec::with_capacity(len); + + scalars.extend(&self.other_scalars); + bases.extend(&self.other_bases); + + if let Some(h_scalar) = self.h_scalar { + scalars.push(h_scalar); + bases.push(self.params.h); + } + + if let Some(g_scalars) = &self.g_scalars { + scalars.extend(g_scalars); + bases.extend(self.params.g.iter()); + } + + assert_eq!(scalars.len(), len); + + bool::from(best_multiexp(&scalars, &bases).is_zero()) + } +}