Add implementation of daira's algorithm for copy constraint enforcement.

This commit is contained in:
Sean Bowe 2020-09-05 12:56:45 -06:00
parent d7132404ba
commit 937861c0b8
No known key found for this signature in database
GPG Key ID: 95684257D8F8B031
2 changed files with 53 additions and 24 deletions

View File

@ -282,7 +282,7 @@ fn test_proving() {
for _ in 0..10 {
let mut a_squared = None;
let (_, _, c0) = cs.raw_multiply(|| {
let (a0, _, c0) = cs.raw_multiply(|| {
a_squared = self.a.map(|a| a.square());
Ok((
self.a.ok_or(Error::SynthesisError)?,
@ -290,7 +290,7 @@ fn test_proving() {
a_squared.ok_or(Error::SynthesisError)?,
))
})?;
let (_, b1, _) = cs.raw_add(|| {
let (a1, b1, _) = cs.raw_add(|| {
let fin = a_squared.and_then(|a2| self.a.map(|a| a + a2));
Ok((
self.a.ok_or(Error::SynthesisError)?,
@ -298,6 +298,7 @@ fn test_proving() {
fin.ok_or(Error::SynthesisError)?,
))
})?;
cs.copy(a0, a1)?;
cs.copy(b1, c0)?;
}
@ -306,7 +307,7 @@ fn test_proving() {
}
let circuit: MyCircuit<Fp> = MyCircuit {
a: Some((-Fp::from_u64(2) + Fp::ROOT_OF_UNITY).pow(&[100, 0, 0, 0])),
a: Some(Fp::random()),
};
let empty_circuit: MyCircuit<Fp> = MyCircuit { a: None };

View File

@ -15,7 +15,9 @@ impl<C: CurveAffine> SRS<C> {
) -> Result<Self, Error> {
struct Assembly<F: Field> {
fixed: Vec<Vec<F>>,
copy: Vec<Vec<Vec<(usize, usize)>>>,
mapping: Vec<Vec<Vec<(usize, usize)>>>,
aux: Vec<Vec<Vec<(usize, usize)>>>,
sizes: Vec<Vec<Vec<usize>>>,
}
impl<F: Field> ConstraintSystem<F> for Assembly<F> {
@ -52,25 +54,45 @@ impl<C: CurveAffine> SRS<C> {
right_wire: usize,
right_row: usize,
) -> Result<(), Error> {
let left: (usize, usize) = *self.copy[permutation]
.get_mut(left_wire)
.and_then(|wire| wire.get_mut(left_row))
.ok_or(Error::BoundsFailure)?;
let right: (usize, usize) = *self.copy[permutation]
.get_mut(right_wire)
.and_then(|wire| wire.get_mut(right_row))
.ok_or(Error::BoundsFailure)?;
if left == (left_wire, left_row) || right == (right_wire, right_row) {
// Don't perform the copy constraint because it will undo
// the effect of the permutation.
} else {
self.copy[permutation][left_wire][left_row] = right;
self.copy[permutation][right_wire][right_row] = left;
// Check bounds first
if permutation >= self.mapping.len()
|| left_wire >= self.mapping[permutation].len()
|| left_row >= self.mapping[permutation][left_wire].len()
|| right_wire >= self.mapping[permutation].len()
|| right_row >= self.mapping[permutation][right_wire].len()
{
return Err(Error::BoundsFailure);
}
let mut left_cycle = self.aux[permutation][left_wire][left_row];
let mut right_cycle = self.aux[permutation][right_wire][right_row];
if left_cycle == right_cycle {
return Ok(());
}
if self.sizes[permutation][left_cycle.0][left_cycle.1]
< self.sizes[permutation][right_cycle.0][right_cycle.1]
{
std::mem::swap(&mut left_cycle, &mut right_cycle);
}
self.sizes[permutation][left_cycle.0][left_cycle.1] +=
self.sizes[permutation][right_cycle.0][right_cycle.1];
let mut i = right_cycle;
loop {
self.aux[permutation][i.0][i.1] = left_cycle;
i = self.mapping[permutation][i.0][i.1];
if i == right_cycle {
break;
}
}
let tmp = self.mapping[permutation][left_wire][left_row];
self.mapping[permutation][left_wire][left_row] =
self.mapping[permutation][right_wire][right_row];
self.mapping[permutation][right_wire][right_row] = tmp;
Ok(())
}
}
@ -126,7 +148,9 @@ impl<C: CurveAffine> SRS<C> {
let mut assembly: Assembly<C::Scalar> = Assembly {
fixed: vec![vec![C::Scalar::zero(); params.n as usize]; meta.num_fixed_wires],
copy: vec![],
mapping: vec![],
aux: vec![],
sizes: vec![],
};
// Initialize the copy vector to keep track of copy constraints in all
@ -137,7 +161,11 @@ impl<C: CurveAffine> SRS<C> {
// Computes [(i, 0), (i, 1), ..., (i, n - 1)]
wires.push((0..params.n).map(|j| (i, j as usize)).collect());
}
assembly.copy.push(wires);
assembly.mapping.push(wires.clone());
assembly.aux.push(wires);
assembly
.sizes
.push(vec![vec![1usize; params.n as usize]; permutation.len()]);
}
// Synthesize the circuit to obtain SRS
@ -162,7 +190,7 @@ impl<C: CurveAffine> SRS<C> {
// assembly.copy[permutation_index] is indexed by wire
// i, and then indexed by row j, obtaining the index of
// the permuted value in deltaomega.
let (permuted_i, permuted_j) = assembly.copy[permutation_index][i][j];
let (permuted_i, permuted_j) = assembly.mapping[permutation_index][i][j];
deltaomega[permuted_i][permuted_j]
})
.collect();