Don't mutate the witness during permutation argument. Also, adds parallelism and reduces state/multiplications.

This commit is contained in:
Sean Bowe 2020-09-07 09:37:49 -06:00
parent b65e75921b
commit 21f02a73c2
No known key found for this signature in database
GPG Key ID: 95684257D8F8B031
2 changed files with 37 additions and 46 deletions

View File

@ -116,6 +116,7 @@ fn test_proving() {
sm: FixedWire,
perm: usize,
perm2: usize,
}
trait StandardCS<FF: Field> {
@ -242,7 +243,9 @@ fn test_proving() {
};
self.cs
.copy(self.config.perm, left_wire, left.1, right_wire, right.1)
.copy(self.config.perm, left_wire, left.1, right_wire, right.1)?;
self.cs
.copy(self.config.perm2, left_wire, left.1, right_wire, right.1)
}
}
@ -258,6 +261,7 @@ fn test_proving() {
let d = meta.advice_wire();
let perm = meta.permutation(&[a, b, c]);
let perm2 = meta.permutation(&[a, b, c]);
let sm = meta.fixed_wire();
let sa = meta.fixed_wire();
@ -291,6 +295,7 @@ fn test_proving() {
sc,
sm,
perm,
perm2,
}
}

View File

@ -78,6 +78,8 @@ impl<C: CurveAffine> Proof<C> {
// Synthesize the circuit to obtain the witness and other information.
circuit.synthesize(&mut witness, config)?;
let witness = witness;
// Create a transcript for obtaining Fiat-Shamir challenges.
let mut transcript = HBase::init(C::Base::one());
@ -132,31 +134,26 @@ impl<C: CurveAffine> Proof<C> {
// Iterate over each permutation
let mut permutation_modified_advice = vec![];
for (wires, permuted_values) in srs.meta.permutations.iter().zip(srs.permutations.iter()) {
// Goal is to compute the fraction
// Goal is to compute the products of fractions
//
// (p_j(\omega^i) + \delta^j \omega^i \beta + \gamma) /
// (p_j(\omega^i) + \beta s_j(\omega^i) + \gamma)
//
// where p_j(X) is the jth advice wire in this permutation,
// and i is the ith row of the wire.
let mut modified_advice = Vec::with_capacity(wires.len());
let mut modified_advice = vec![C::Scalar::one(); params.n as usize];
// Iterate over each wire of the permutation
for (&(wire, _), permuted_wire_values) in wires.iter().zip(permuted_values.iter()) {
// Grab the advice wire's values from the witness
let mut tmp_advice_values = witness.advice[wire.0].clone();
// For each row i, compute
// p_j(\omega^i) + \beta s_j(\omega^i) + \gamma
// where p_j(omega^i) = tmp[i]
for (tmp_advice_value, permuted_advice_value) in tmp_advice_values
.iter_mut()
.zip(permuted_wire_values.iter())
{
*tmp_advice_value += &(x_0 * permuted_advice_value); // p_j(\omega^i) + \beta s_j(\omega^i)
*tmp_advice_value += &x_1; // p_j(\omega^i) + \beta s_j(\omega^i) + \gamma
}
modified_advice.push(tmp_advice_values);
parallelize(&mut modified_advice, |modified_advice, start| {
for ((modified_advice, advice_value), permuted_advice_value) in modified_advice
.iter_mut()
.zip(witness.advice[wire.0][start..].iter())
.zip(permuted_wire_values[start..].iter())
{
*modified_advice *= &(x_0 * permuted_advice_value + &x_1 + advice_value);
}
});
}
permutation_modified_advice.push(modified_advice);
@ -167,7 +164,6 @@ impl<C: CurveAffine> Proof<C> {
permutation_modified_advice
.iter_mut()
.flat_map(|v| v.iter_mut())
.flat_map(|v| v.iter_mut())
.batch_invert();
for (wires, mut modified_advice) in srs
@ -179,30 +175,30 @@ impl<C: CurveAffine> Proof<C> {
// Iterate over each wire again, this time finishing the computation
// of the entire fraction by computing the numerators
let mut deltaomega = C::Scalar::one();
for (&(wire, _), modified_advice) in wires.iter().zip(modified_advice.iter_mut()) {
// For each row i, we compute
// p_j(\omega^i) + \delta^j \omega^i \beta + \gamma
// for the jth wire of the permutation
for (tmp_advice_value, modified_advice) in witness.advice[wire.0]
.iter_mut()
.zip(modified_advice.iter_mut())
{
*tmp_advice_value += &(deltaomega * &x_0); // p_j(\omega^i) + \delta^j \omega^i \beta
*tmp_advice_value += &x_1; // p_j(\omega^i) + \delta^j \omega^i \beta + \gamma
*modified_advice *= tmp_advice_value;
deltaomega *= &domain.get_omega();
}
for &(wire, _) in wires.iter() {
let omega = domain.get_omega();
parallelize(&mut modified_advice, |modified_advice, start| {
let mut deltaomega = deltaomega * &omega.pow_vartime(&[start as u64, 0, 0, 0]);
for (modified_advice, advice_value) in modified_advice
.iter_mut()
.zip(witness.advice[wire.0][start..].iter())
{
// Multiply by p_j(\omega^i) + \delta^j \omega^i \beta
*modified_advice *= &(deltaomega * &x_0 + &x_1 + advice_value);
deltaomega *= &omega;
}
});
deltaomega *= &C::Scalar::DELTA;
}
// The modified_advice vector is a vector of vectors of fractions of
// the form
// The modified_advice vector is a vector of products of fractions
// of the form
//
// (p_j(\omega^i) + \delta^j \omega^i \beta + \gamma) /
// (p_j(\omega^i) + \beta s_j(\omega^i) + \gamma)
//
// where j is the index into modified_advice, and i is the index
// into modified_advice[j], for the jth wire in the permutation
// where i is the index into modified_advice, for the jth wire in
// the permutation
// Compute the evaluations of the permutation product polynomial
// over our domain, starting with z[0] = 1
@ -210,17 +206,7 @@ impl<C: CurveAffine> Proof<C> {
for row in 1..(params.n as usize) {
let mut tmp = z[row - 1];
// Iterate over each wire's modified advice, where for the jth
// wire we obtain the fraction
//
// (p_j(\omega^i) + \delta^j \omega^i \beta + \gamma) /
// (p_j(\omega^i) + \beta s_j(\omega^i) + \gamma)
//
// where i is the row of the permutation product polynomial
// evaluation vector that we are currently evaluating.
for wire_modified_advice in modified_advice.iter() {
tmp *= &wire_modified_advice[row];
}
tmp *= &modified_advice[row];
z.push(tmp);
}