Add implementation of daira's algorithm for copy constraint enforcement.

This commit is contained in:
Sean Bowe 2020-09-05 12:56:45 -06:00
parent d7132404ba
commit 937861c0b8
No known key found for this signature in database
GPG Key ID: 95684257D8F8B031
2 changed files with 53 additions and 24 deletions

View File

@ -282,7 +282,7 @@ fn test_proving() {
for _ in 0..10 { for _ in 0..10 {
let mut a_squared = None; let mut a_squared = None;
let (_, _, c0) = cs.raw_multiply(|| { let (a0, _, c0) = cs.raw_multiply(|| {
a_squared = self.a.map(|a| a.square()); a_squared = self.a.map(|a| a.square());
Ok(( Ok((
self.a.ok_or(Error::SynthesisError)?, self.a.ok_or(Error::SynthesisError)?,
@ -290,7 +290,7 @@ fn test_proving() {
a_squared.ok_or(Error::SynthesisError)?, a_squared.ok_or(Error::SynthesisError)?,
)) ))
})?; })?;
let (_, b1, _) = cs.raw_add(|| { let (a1, b1, _) = cs.raw_add(|| {
let fin = a_squared.and_then(|a2| self.a.map(|a| a + a2)); let fin = a_squared.and_then(|a2| self.a.map(|a| a + a2));
Ok(( Ok((
self.a.ok_or(Error::SynthesisError)?, self.a.ok_or(Error::SynthesisError)?,
@ -298,6 +298,7 @@ fn test_proving() {
fin.ok_or(Error::SynthesisError)?, fin.ok_or(Error::SynthesisError)?,
)) ))
})?; })?;
cs.copy(a0, a1)?;
cs.copy(b1, c0)?; cs.copy(b1, c0)?;
} }
@ -306,7 +307,7 @@ fn test_proving() {
} }
let circuit: MyCircuit<Fp> = MyCircuit { let circuit: MyCircuit<Fp> = MyCircuit {
a: Some((-Fp::from_u64(2) + Fp::ROOT_OF_UNITY).pow(&[100, 0, 0, 0])), a: Some(Fp::random()),
}; };
let empty_circuit: MyCircuit<Fp> = MyCircuit { a: None }; let empty_circuit: MyCircuit<Fp> = MyCircuit { a: None };

View File

@ -15,7 +15,9 @@ impl<C: CurveAffine> SRS<C> {
) -> Result<Self, Error> { ) -> Result<Self, Error> {
struct Assembly<F: Field> { struct Assembly<F: Field> {
fixed: Vec<Vec<F>>, fixed: Vec<Vec<F>>,
copy: Vec<Vec<Vec<(usize, usize)>>>, mapping: Vec<Vec<Vec<(usize, usize)>>>,
aux: Vec<Vec<Vec<(usize, usize)>>>,
sizes: Vec<Vec<Vec<usize>>>,
} }
impl<F: Field> ConstraintSystem<F> for Assembly<F> { impl<F: Field> ConstraintSystem<F> for Assembly<F> {
@ -52,25 +54,45 @@ impl<C: CurveAffine> SRS<C> {
right_wire: usize, right_wire: usize,
right_row: usize, right_row: usize,
) -> Result<(), Error> { ) -> Result<(), Error> {
let left: (usize, usize) = *self.copy[permutation] // Check bounds first
.get_mut(left_wire) if permutation >= self.mapping.len()
.and_then(|wire| wire.get_mut(left_row)) || left_wire >= self.mapping[permutation].len()
.ok_or(Error::BoundsFailure)?; || left_row >= self.mapping[permutation][left_wire].len()
|| right_wire >= self.mapping[permutation].len()
let right: (usize, usize) = *self.copy[permutation] || right_row >= self.mapping[permutation][right_wire].len()
.get_mut(right_wire) {
.and_then(|wire| wire.get_mut(right_row)) return Err(Error::BoundsFailure);
.ok_or(Error::BoundsFailure)?;
if left == (left_wire, left_row) || right == (right_wire, right_row) {
// Don't perform the copy constraint because it will undo
// the effect of the permutation.
} else {
self.copy[permutation][left_wire][left_row] = right;
self.copy[permutation][right_wire][right_row] = left;
} }
let mut left_cycle = self.aux[permutation][left_wire][left_row];
let mut right_cycle = self.aux[permutation][right_wire][right_row];
if left_cycle == right_cycle {
return Ok(());
}
if self.sizes[permutation][left_cycle.0][left_cycle.1]
< self.sizes[permutation][right_cycle.0][right_cycle.1]
{
std::mem::swap(&mut left_cycle, &mut right_cycle);
}
self.sizes[permutation][left_cycle.0][left_cycle.1] +=
self.sizes[permutation][right_cycle.0][right_cycle.1];
let mut i = right_cycle;
loop {
self.aux[permutation][i.0][i.1] = left_cycle;
i = self.mapping[permutation][i.0][i.1];
if i == right_cycle {
break;
}
}
let tmp = self.mapping[permutation][left_wire][left_row];
self.mapping[permutation][left_wire][left_row] =
self.mapping[permutation][right_wire][right_row];
self.mapping[permutation][right_wire][right_row] = tmp;
Ok(()) Ok(())
} }
} }
@ -126,7 +148,9 @@ impl<C: CurveAffine> SRS<C> {
let mut assembly: Assembly<C::Scalar> = Assembly { let mut assembly: Assembly<C::Scalar> = Assembly {
fixed: vec![vec![C::Scalar::zero(); params.n as usize]; meta.num_fixed_wires], fixed: vec![vec![C::Scalar::zero(); params.n as usize]; meta.num_fixed_wires],
copy: vec![], mapping: vec![],
aux: vec![],
sizes: vec![],
}; };
// Initialize the copy vector to keep track of copy constraints in all // Initialize the copy vector to keep track of copy constraints in all
@ -137,7 +161,11 @@ impl<C: CurveAffine> SRS<C> {
// Computes [(i, 0), (i, 1), ..., (i, n - 1)] // Computes [(i, 0), (i, 1), ..., (i, n - 1)]
wires.push((0..params.n).map(|j| (i, j as usize)).collect()); wires.push((0..params.n).map(|j| (i, j as usize)).collect());
} }
assembly.copy.push(wires); assembly.mapping.push(wires.clone());
assembly.aux.push(wires);
assembly
.sizes
.push(vec![vec![1usize; params.n as usize]; permutation.len()]);
} }
// Synthesize the circuit to obtain SRS // Synthesize the circuit to obtain SRS
@ -162,7 +190,7 @@ impl<C: CurveAffine> SRS<C> {
// assembly.copy[permutation_index] is indexed by wire // assembly.copy[permutation_index] is indexed by wire
// i, and then indexed by row j, obtaining the index of // i, and then indexed by row j, obtaining the index of
// the permuted value in deltaomega. // the permuted value in deltaomega.
let (permuted_i, permuted_j) = assembly.copy[permutation_index][i][j]; let (permuted_i, permuted_j) = assembly.mapping[permutation_index][i][j];
deltaomega[permuted_i][permuted_j] deltaomega[permuted_i][permuted_j]
}) })
.collect(); .collect();