diff --git a/src/plonk.rs b/src/plonk.rs index 369c30e5..d890d3fb 100644 --- a/src/plonk.rs +++ b/src/plonk.rs @@ -282,7 +282,7 @@ fn test_proving() { for _ in 0..10 { let mut a_squared = None; - let (_, _, c0) = cs.raw_multiply(|| { + let (a0, _, c0) = cs.raw_multiply(|| { a_squared = self.a.map(|a| a.square()); Ok(( self.a.ok_or(Error::SynthesisError)?, @@ -290,7 +290,7 @@ fn test_proving() { a_squared.ok_or(Error::SynthesisError)?, )) })?; - let (_, b1, _) = cs.raw_add(|| { + let (a1, b1, _) = cs.raw_add(|| { let fin = a_squared.and_then(|a2| self.a.map(|a| a + a2)); Ok(( self.a.ok_or(Error::SynthesisError)?, @@ -298,6 +298,7 @@ fn test_proving() { fin.ok_or(Error::SynthesisError)?, )) })?; + cs.copy(a0, a1)?; cs.copy(b1, c0)?; } @@ -306,7 +307,7 @@ fn test_proving() { } let circuit: MyCircuit = MyCircuit { - a: Some((-Fp::from_u64(2) + Fp::ROOT_OF_UNITY).pow(&[100, 0, 0, 0])), + a: Some(Fp::random()), }; let empty_circuit: MyCircuit = MyCircuit { a: None }; diff --git a/src/plonk/srs.rs b/src/plonk/srs.rs index d6c69bae..9b2b9975 100644 --- a/src/plonk/srs.rs +++ b/src/plonk/srs.rs @@ -15,7 +15,9 @@ impl SRS { ) -> Result { struct Assembly { fixed: Vec>, - copy: Vec>>, + mapping: Vec>>, + aux: Vec>>, + sizes: Vec>>, } impl ConstraintSystem for Assembly { @@ -52,25 +54,45 @@ impl SRS { right_wire: usize, right_row: usize, ) -> Result<(), Error> { - let left: (usize, usize) = *self.copy[permutation] - .get_mut(left_wire) - .and_then(|wire| wire.get_mut(left_row)) - .ok_or(Error::BoundsFailure)?; - - let right: (usize, usize) = *self.copy[permutation] - .get_mut(right_wire) - .and_then(|wire| wire.get_mut(right_row)) - .ok_or(Error::BoundsFailure)?; - - if left == (left_wire, left_row) || right == (right_wire, right_row) { - // Don't perform the copy constraint because it will undo - // the effect of the permutation. - } else { - self.copy[permutation][left_wire][left_row] = right; - - self.copy[permutation][right_wire][right_row] = left; + // Check bounds first + if permutation >= self.mapping.len() + || left_wire >= self.mapping[permutation].len() + || left_row >= self.mapping[permutation][left_wire].len() + || right_wire >= self.mapping[permutation].len() + || right_row >= self.mapping[permutation][right_wire].len() + { + return Err(Error::BoundsFailure); } + let mut left_cycle = self.aux[permutation][left_wire][left_row]; + let mut right_cycle = self.aux[permutation][right_wire][right_row]; + + if left_cycle == right_cycle { + return Ok(()); + } + + if self.sizes[permutation][left_cycle.0][left_cycle.1] + < self.sizes[permutation][right_cycle.0][right_cycle.1] + { + std::mem::swap(&mut left_cycle, &mut right_cycle); + } + + self.sizes[permutation][left_cycle.0][left_cycle.1] += + self.sizes[permutation][right_cycle.0][right_cycle.1]; + let mut i = right_cycle; + loop { + self.aux[permutation][i.0][i.1] = left_cycle; + i = self.mapping[permutation][i.0][i.1]; + if i == right_cycle { + break; + } + } + + let tmp = self.mapping[permutation][left_wire][left_row]; + self.mapping[permutation][left_wire][left_row] = + self.mapping[permutation][right_wire][right_row]; + self.mapping[permutation][right_wire][right_row] = tmp; + Ok(()) } } @@ -126,7 +148,9 @@ impl SRS { let mut assembly: Assembly = Assembly { fixed: vec![vec![C::Scalar::zero(); params.n as usize]; meta.num_fixed_wires], - copy: vec![], + mapping: vec![], + aux: vec![], + sizes: vec![], }; // Initialize the copy vector to keep track of copy constraints in all @@ -137,7 +161,11 @@ impl SRS { // Computes [(i, 0), (i, 1), ..., (i, n - 1)] wires.push((0..params.n).map(|j| (i, j as usize)).collect()); } - assembly.copy.push(wires); + assembly.mapping.push(wires.clone()); + assembly.aux.push(wires); + assembly + .sizes + .push(vec![vec![1usize; params.n as usize]; permutation.len()]); } // Synthesize the circuit to obtain SRS @@ -162,7 +190,7 @@ impl SRS { // assembly.copy[permutation_index] is indexed by wire // i, and then indexed by row j, obtaining the index of // the permuted value in deltaomega. - let (permuted_i, permuted_j) = assembly.copy[permutation_index][i][j]; + let (permuted_i, permuted_j) = assembly.mapping[permutation_index][i][j]; deltaomega[permuted_i][permuted_j] }) .collect();