Add parallelism in various locations in the prover.

This commit is contained in:
Sean Bowe 2020-09-06 13:40:06 -06:00
parent 3157fdd7d0
commit e37d0c946b
No known key found for this signature in database
GPG Key ID: 95684257D8F8B031
1 changed files with 51 additions and 42 deletions

View File

@ -254,10 +254,11 @@ impl<C: CurveAffine> Proof<C> {
let mut h_poly = vec![C::Scalar::zero(); domain.coset_len()]; let mut h_poly = vec![C::Scalar::zero(); domain.coset_len()];
for (i, poly) in meta.gates.iter().enumerate() { for (i, poly) in meta.gates.iter().enumerate() {
if i != 0 { if i != 0 {
// TODO: parallelize parallelize(&mut h_poly, |a, _| {
for h in h_poly.iter_mut() { for a in a.iter_mut() {
*h *= &x_2; *a *= &x_2;
} }
});
} }
let evaluation: Vec<C::Scalar> = poly.evaluate( let evaluation: Vec<C::Scalar> = poly.evaluate(
@ -294,35 +295,36 @@ impl<C: CurveAffine> Proof<C> {
if i == 0 { if i == 0 {
h_poly = evaluation; h_poly = evaluation;
} else { } else {
// TODO: parallelize parallelize(&mut h_poly, |a, start| {
for (h, e) in h_poly.iter_mut().zip(evaluation.into_iter()) { for (a, b) in a.iter_mut().zip(evaluation[start..].iter()) {
*h += &e; *a += b;
} }
});
} }
} }
// l_0(X) * (1 - z(X)) = 0 // l_0(X) * (1 - z(X)) = 0
// TODO: parallelize // TODO: parallelize
for coset in permutation_product_cosets.iter() { for coset in permutation_product_cosets.iter() {
for h in h_poly.iter_mut() { parallelize(&mut h_poly, |h, start| {
for ((h, c), l0) in h
.iter_mut()
.zip(coset[start..].iter())
.zip(srs.l0[start..].iter())
{
*h *= &x_2; *h *= &x_2;
*h += &(*l0 * &(C::Scalar::one() - c));
} }
});
let mut tmp = srs.l0.clone();
for (t, c) in tmp.iter_mut().zip(coset.iter()) {
*t *= &(C::Scalar::one() - c);
}
for (h, e) in h_poly.iter_mut().zip(tmp.into_iter()) {
*h += &e;
}
} }
// z(X) \prod (p(X) + \beta s_i(X) + \gamma) - z(omega^{-1} X) \prod (p(X) + \delta^i \beta X + \gamma) // z(X) \prod (p(X) + \beta s_i(X) + \gamma) - z(omega^{-1} X) \prod (p(X) + \delta^i \beta X + \gamma)
for (permutation_index, wires) in srs.meta.permutations.iter().enumerate() { for (permutation_index, wires) in srs.meta.permutations.iter().enumerate() {
for h in h_poly.iter_mut() { parallelize(&mut h_poly, |a, _| {
*h *= &x_2; for a in a.iter_mut() {
*a *= &x_2;
} }
});
let mut left = permutation_product_cosets[permutation_index].clone(); let mut left = permutation_product_cosets[permutation_index].clone();
for (advice, permutation) in wires for (advice, permutation) in wires
@ -330,34 +332,41 @@ impl<C: CurveAffine> Proof<C> {
.map(|&wire_index| &advice_cosets[wire_index.0]) .map(|&wire_index| &advice_cosets[wire_index.0])
.zip(srs.permutation_cosets[permutation_index].iter()) .zip(srs.permutation_cosets[permutation_index].iter())
{ {
// TODO: parallelize parallelize(&mut left, |left, start| {
for ((left, advice), permutation) in for ((left, advice), permutation) in left
left.iter_mut().zip(advice.iter()).zip(permutation.iter()) .iter_mut()
.zip(advice[start..].iter())
.zip(permutation[start..].iter())
{ {
*left *= &(*advice + &(x_0 * permutation) + &x_1); *left *= &(*advice + &(x_0 * permutation) + &x_1);
} }
});
} }
let mut right = permutation_product_cosets_inv[permutation_index].clone(); let mut right = permutation_product_cosets_inv[permutation_index].clone();
let mut current_delta = x_0 * &C::Scalar::ZETA; let mut current_delta = x_0 * &C::Scalar::ZETA;
let step = domain.get_extended_omega(); let step = domain.get_extended_omega();
for advice in wires.iter().map(|&wire_index| &advice_cosets[wire_index.0]) { for advice in wires.iter().map(|&wire_index| &advice_cosets[wire_index.0]) {
// TODO: parallelize parallelize(&mut right, move |right, start| {
let mut beta_term = current_delta; let mut beta_term = current_delta * &step.pow_vartime(&[start as u64, 0, 0, 0]);
for (right, advice) in right.iter_mut().zip(advice.iter()) { for (right, advice) in right.iter_mut().zip(advice[start..].iter()) {
*right *= &(*advice + &beta_term + &x_1); *right *= &(*advice + &beta_term + &x_1);
beta_term *= &step; beta_term *= &step;
} }
});
current_delta *= &C::Scalar::DELTA; current_delta *= &C::Scalar::DELTA;
} }
for (h, e) in h_poly.iter_mut().zip(left.into_iter()) { parallelize(&mut h_poly, |a, start| {
*h += &e; for ((h, left), right) in a
} .iter_mut()
.zip(left[start..].iter())
for (h, e) in h_poly.iter_mut().zip(right.into_iter()) { .zip(right[start..].iter())
*h -= &e; {
*h += &left;
*h -= &right;
} }
});
} }
// Divide by t(X) = X^{params.n} - 1. // Divide by t(X) = X^{params.n} - 1.