// halo2/src/poly/commitment/msm.rs (150 lines, 4.8 KiB, Rust)
// Last modified: 2020-10-13 07:16:20 -07:00
use super::Params;
use crate::arithmetic::{best_multiexp, parallelize, Curve, CurveAffine};
use ff::Field;
/// A multiscalar multiplication in the polynomial commitment scheme.
///
/// Terms against the fixed bases held by `params` (`g`, `h` and `u`) are kept
/// as coefficients only, so repeated additions against the same base are
/// summed in place rather than growing the term list. All other terms are
/// stored as explicit (scalar, base) pairs.
#[derive(Debug, Clone)]
pub struct MSM<'a, C: CurveAffine> {
    /// Parameters supplying the fixed bases `g`, `h` and `u`.
    pub(crate) params: &'a Params<C>,
    /// Coefficients for the bases in `params.g`; `None` until one is added.
    g_scalars: Option<Vec<C::Scalar>>,
    /// Coefficient for `params.h`; `None` until one is added.
    h_scalar: Option<C::Scalar>,
    /// Coefficient for `params.u`; `None` until one is added.
    u_scalar: Option<C::Scalar>,

    /// Scalars of arbitrary terms, index-aligned with `other_bases`.
    other_scalars: Vec<C::Scalar>,
    /// Bases of arbitrary terms, index-aligned with `other_scalars`.
    other_bases: Vec<C>,
}
impl<'a, C: CurveAffine> MSM<'a, C> {
/// Create a new, empty MSM using the provided parameters.
pub fn new(params: &'a Params<C>) -> Self {
let g_scalars = None;
let h_scalar = None;
let u_scalar = None;
2020-10-13 07:16:20 -07:00
let other_scalars = vec![];
let other_bases = vec![];
MSM {
params,
g_scalars,
h_scalar,
u_scalar,
2020-10-13 07:16:20 -07:00
other_scalars,
other_bases,
}
}
/// Add another multiexp into this one
pub fn add_msm(&mut self, other: &Self) {
self.other_scalars.extend(other.other_scalars.iter());
self.other_bases.extend(other.other_bases.iter());
if let Some(g_scalars) = &other.g_scalars {
self.add_to_g_scalars(&g_scalars);
2020-10-13 07:16:20 -07:00
}
if let Some(h_scalar) = &other.h_scalar {
self.add_to_h_scalar(*h_scalar);
2020-10-13 07:16:20 -07:00
}
if let Some(u_scalar) = &other.u_scalar {
self.add_to_u_scalar(*u_scalar);
}
2020-10-13 07:16:20 -07:00
}
/// Add arbitrary term (the scalar and the point)
pub fn append_term(&mut self, scalar: C::Scalar, point: C) {
2020-10-13 07:16:20 -07:00
self.other_scalars.push(scalar);
self.other_bases.push(point);
}
/// Add a value to the first entry of `g_scalars`.
pub fn add_constant_term(&mut self, constant: C::Scalar) {
if let Some(g_scalars) = self.g_scalars.as_mut() {
g_scalars[0] += &constant;
} else {
let mut g_scalars = vec![C::Scalar::zero(); self.params.n as usize];
g_scalars[0] += &constant;
self.g_scalars = Some(g_scalars);
}
}
2020-10-13 07:16:20 -07:00
/// Add a vector of scalars to `g_scalars`. This function will panic if the
/// caller provides a slice of scalars that is not of length `params.n`.
pub fn add_to_g_scalars(&mut self, scalars: &[C::Scalar]) {
2020-10-13 07:16:20 -07:00
assert_eq!(scalars.len(), self.params.n as usize);
if let Some(g_scalars) = &mut self.g_scalars {
parallelize(g_scalars, |g_scalars, start| {
for (g_scalar, scalar) in g_scalars.iter_mut().zip(scalars[start..].iter()) {
*g_scalar += scalar;
}
})
2020-10-13 07:16:20 -07:00
} else {
self.g_scalars = Some(scalars.to_vec());
}
}
/// Add to `h_scalar`
pub fn add_to_h_scalar(&mut self, scalar: C::Scalar) {
2020-10-13 07:16:20 -07:00
self.h_scalar = self.h_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
}
/// Add to `u_scalar`
pub fn add_to_u_scalar(&mut self, scalar: C::Scalar) {
self.u_scalar = self.u_scalar.map_or(Some(scalar), |a| Some(a + &scalar));
}
2020-10-13 07:16:20 -07:00
/// Scale all scalars in the MSM by some scaling factor
pub fn scale(&mut self, factor: C::Scalar) {
if let Some(g_scalars) = &mut self.g_scalars {
parallelize(g_scalars, |g_scalars, _| {
for g_scalar in g_scalars {
*g_scalar *= &factor;
}
})
2020-10-13 07:16:20 -07:00
}
if !self.other_scalars.is_empty() {
parallelize(&mut self.other_scalars, |other_scalars, _| {
for other_scalar in other_scalars {
*other_scalar *= &factor;
}
})
2020-10-13 07:16:20 -07:00
}
2020-10-13 07:16:20 -07:00
self.h_scalar = self.h_scalar.map(|a| a * &factor);
self.u_scalar = self.u_scalar.map(|a| a * &factor);
2020-10-13 07:16:20 -07:00
}
/// Perform multiexp and check that it results in zero
pub fn eval(self) -> bool {
let len = self.g_scalars.as_ref().map(|v| v.len()).unwrap_or(0)
+ self.h_scalar.map(|_| 1).unwrap_or(0)
+ self.u_scalar.map(|_| 1).unwrap_or(0)
2020-10-13 07:16:20 -07:00
+ self.other_scalars.len();
let mut scalars: Vec<C::Scalar> = Vec::with_capacity(len);
let mut bases: Vec<C> = Vec::with_capacity(len);
scalars.extend(&self.other_scalars);
bases.extend(&self.other_bases);
if let Some(h_scalar) = self.h_scalar {
scalars.push(h_scalar);
bases.push(self.params.h);
}
if let Some(u_scalar) = self.u_scalar {
scalars.push(u_scalar);
bases.push(self.params.u);
}
2020-10-13 07:16:20 -07:00
if let Some(g_scalars) = &self.g_scalars {
scalars.extend(g_scalars);
bases.extend(self.params.g.iter());
}
assert_eq!(scalars.len(), len);
2020-12-22 04:27:36 -08:00
metrics::increment_counter!("multiexp", "size" => format!("{}", len), "fn" => "MSM::eval");
2020-10-13 07:16:20 -07:00
bool::from(best_multiexp(&scalars, &bases).is_zero())
}
}