From 52eb59766b4ec4205830d4d2619c5988d3a3985a Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Thu, 15 Mar 2018 12:31:10 -0600 Subject: [PATCH] Optimize UInt32::addmany/BLAKE2s to combine equality constraints. (Closes #5) --- src/circuit/blake2s.rs | 37 ++++++----- src/circuit/boolean.rs | 7 +++ src/circuit/mod.rs | 5 +- src/circuit/multieq.rs | 137 +++++++++++++++++++++++++++++++++++++++++ src/circuit/uint32.rs | 62 +++++++++---------- 5 files changed, 198 insertions(+), 50 deletions(-) create mode 100644 src/circuit/multieq.rs diff --git a/src/circuit/blake2s.rs b/src/circuit/blake2s.rs index f75f5c7..a78f3cf 100644 --- a/src/circuit/blake2s.rs +++ b/src/circuit/blake2s.rs @@ -15,6 +15,8 @@ use super::uint32::{ UInt32 }; +use super::multieq::MultiEq; + /* 2.1. Parameters The following table summarizes various parameters and their ranges: @@ -88,8 +90,8 @@ const SIGMA: [[usize; 16]; 10] = [ END FUNCTION. */ -fn mixing_g>( - mut cs: CS, +fn mixing_g, M>( + mut cs: M, v: &mut [UInt32], a: usize, b: usize, @@ -98,6 +100,7 @@ fn mixing_g>( x: &UInt32, y: &UInt32 ) -> Result<(), SynthesisError> + where M: ConstraintSystem> { v[a] = UInt32::addmany(cs.namespace(|| "mixing step 1"), &[v[a].clone(), v[b].clone(), x.clone()])?; v[d] = v[d].xor(cs.namespace(|| "mixing step 2"), &v[a])?.rotr(R1); @@ -199,20 +202,24 @@ fn blake2s_compression>( v[14] = v[14].xor(cs.namespace(|| "third xor"), &UInt32::constant(u32::max_value()))?; } - for i in 0..10 { - let mut cs = cs.namespace(|| format!("round {}", i)); + { + let mut cs = MultiEq::new(&mut cs); - let s = SIGMA[i % 10]; + for i in 0..10 { + let mut cs = cs.namespace(|| format!("round {}", i)); - mixing_g(cs.namespace(|| "mixing invocation 1"), &mut v, 0, 4, 8, 12, &m[s[ 0]], &m[s[ 1]])?; - mixing_g(cs.namespace(|| "mixing invocation 2"), &mut v, 1, 5, 9, 13, &m[s[ 2]], &m[s[ 3]])?; - mixing_g(cs.namespace(|| "mixing invocation 3"), &mut v, 2, 6, 10, 14, &m[s[ 4]], &m[s[ 5]])?; - mixing_g(cs.namespace(|| "mixing invocation 4"), &mut v, 3, 7, 11, 15, &m[s[ 6]], &m[s[ 7]])?; + let s = SIGMA[i % 10]; - mixing_g(cs.namespace(|| "mixing invocation 5"), &mut v, 0, 5, 10, 15, &m[s[ 8]], &m[s[ 9]])?; - mixing_g(cs.namespace(|| "mixing invocation 6"), &mut v, 1, 6, 11, 12, &m[s[10]], &m[s[11]])?; - mixing_g(cs.namespace(|| "mixing invocation 7"), &mut v, 2, 7, 8, 13, &m[s[12]], &m[s[13]])?; - mixing_g(cs.namespace(|| "mixing invocation 8"), &mut v, 3, 4, 9, 14, &m[s[14]], &m[s[15]])?; + mixing_g(cs.namespace(|| "mixing invocation 1"), &mut v, 0, 4, 8, 12, &m[s[ 0]], &m[s[ 1]])?; + mixing_g(cs.namespace(|| "mixing invocation 2"), &mut v, 1, 5, 9, 13, &m[s[ 2]], &m[s[ 3]])?; + mixing_g(cs.namespace(|| "mixing invocation 3"), &mut v, 2, 6, 10, 14, &m[s[ 4]], &m[s[ 5]])?; + mixing_g(cs.namespace(|| "mixing invocation 4"), &mut v, 3, 7, 11, 15, &m[s[ 6]], &m[s[ 7]])?; + + mixing_g(cs.namespace(|| "mixing invocation 5"), &mut v, 0, 5, 10, 15, &m[s[ 8]], &m[s[ 9]])?; + mixing_g(cs.namespace(|| "mixing invocation 6"), &mut v, 1, 6, 11, 12, &m[s[10]], &m[s[11]])?; + mixing_g(cs.namespace(|| "mixing invocation 7"), &mut v, 2, 7, 8, 13, &m[s[12]], &m[s[13]])?; + mixing_g(cs.namespace(|| "mixing invocation 8"), &mut v, 3, 4, 9, 14, &m[s[14]], &m[s[15]])?; + } } for i in 0..8 { @@ -350,7 +357,7 @@ mod test { let input_bits: Vec<_> = (0..512).map(|i| AllocatedBit::alloc(cs.namespace(|| format!("input bit {}", i)), Some(true)).unwrap().into()).collect(); blake2s(&mut cs, &input_bits, b"12345678").unwrap(); assert!(cs.is_satisfied()); - assert_eq!(cs.num_constraints(), 21792); + assert_eq!(cs.num_constraints(), 21518); } #[test] @@ -367,7 +374,7 @@ mod test { .collect(); blake2s(&mut cs, &input_bits, b"12345678").unwrap(); assert!(cs.is_satisfied()); - assert_eq!(cs.num_constraints(), 21792); + assert_eq!(cs.num_constraints(), 21518); } #[test] diff --git a/src/circuit/boolean.rs b/src/circuit/boolean.rs index 239d404..e5fa435 100644 --- a/src/circuit/boolean.rs +++ b/src/circuit/boolean.rs @@ -367,6 +367,13 @@ pub enum Boolean { } impl Boolean { + pub fn is_constant(&self) -> bool { + match *self { + Boolean::Constant(_) => true, + _ => false + } + } + pub fn enforce_equal( mut cs: CS, a: &Self, diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs index f12ad9e..f928820 100644 --- a/src/circuit/mod.rs +++ b/src/circuit/mod.rs @@ -2,6 +2,7 @@ pub mod test; pub mod boolean; +pub mod multieq; pub mod uint32; pub mod blake2s; pub mod num; @@ -627,8 +628,8 @@ fn test_input_circuit_with_bls12_381() { instance.synthesize(&mut cs).unwrap(); assert!(cs.is_satisfied()); - assert_eq!(cs.num_constraints(), 101566); - assert_eq!(cs.hash(), "e3d226975c99e17ef30f5a4b7e87d355ef3dbd80eed0c8de43780f3028946d82"); + assert_eq!(cs.num_constraints(), 101018); + assert_eq!(cs.hash(), "eedcef5fd638e0168ae4d53ac58df66f0acdabea46749cc5f4b39459c8377804"); let expected_value_cm = value_commitment.cm(params).into_xy(); diff --git a/src/circuit/multieq.rs b/src/circuit/multieq.rs new file mode 100644 index 0000000..0f9c755 --- /dev/null +++ b/src/circuit/multieq.rs @@ -0,0 +1,137 @@ +use pairing::{ + Engine, + Field, + PrimeField +}; + +use bellman::{ + SynthesisError, + ConstraintSystem, + LinearCombination, + Variable +}; + +pub struct MultiEq>{ + cs: CS, + ops: usize, + bits_used: usize, + lhs: LinearCombination, + rhs: LinearCombination, +} + +impl> MultiEq { + pub fn new(cs: CS) -> Self { + MultiEq { + cs: cs, + ops: 0, + bits_used: 0, + lhs: LinearCombination::zero(), + rhs: LinearCombination::zero() + } + } + + fn accumulate(&mut self) + { + let ops = self.ops; + let lhs = self.lhs.clone(); + let rhs = self.rhs.clone(); + self.cs.enforce( + || format!("multieq {}", ops), + |_| lhs, + |lc| lc + CS::one(), + |_| rhs + ); + self.lhs = LinearCombination::zero(); + self.rhs = LinearCombination::zero(); + self.bits_used = 0; + self.ops += 1; + } + + pub fn enforce_equal( + &mut self, + num_bits: usize, + lhs: &LinearCombination, + rhs: &LinearCombination + ) + { + // Check if we will exceed the capacity + if (E::Fr::CAPACITY as usize) <= (self.bits_used + num_bits) { + self.accumulate(); + } + + assert!((E::Fr::CAPACITY as usize) > (self.bits_used + num_bits)); + + let coeff = E::Fr::from_str("2").unwrap().pow(&[self.bits_used as u64]); + self.lhs = self.lhs.clone() + (coeff, lhs); + self.rhs = self.rhs.clone() + (coeff, rhs); + self.bits_used += num_bits; + } +} + +impl> Drop for MultiEq { + fn drop(&mut self) { + if self.bits_used > 0 { + self.accumulate(); + } + } +} + +impl> ConstraintSystem for MultiEq +{ + type Root = Self; + + fn one() -> Variable { + CS::one() + } + + fn alloc( + &mut self, + annotation: A, + f: F + ) -> Result + where F: FnOnce() -> Result, A: FnOnce() -> AR, AR: Into + { + self.cs.alloc(annotation, f) + } + + fn alloc_input( + &mut self, + annotation: A, + f: F + ) -> Result + where F: FnOnce() -> Result, A: FnOnce() -> AR, AR: Into + { + self.cs.alloc_input(annotation, f) + } + + fn enforce( + &mut self, + annotation: A, + a: LA, + b: LB, + c: LC + ) + where A: FnOnce() -> AR, AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination + { + self.cs.enforce(annotation, a, b, c) + } + + fn push_namespace(&mut self, name_fn: N) + where NR: Into, N: FnOnce() -> NR + { + self.cs.get_root().push_namespace(name_fn) + } + + fn pop_namespace(&mut self) + { + self.cs.get_root().pop_namespace() + } + + fn get_root(&mut self) -> &mut Self::Root + { + self + } +} diff --git a/src/circuit/uint32.rs b/src/circuit/uint32.rs index ff39433..4714724 100644 --- a/src/circuit/uint32.rs +++ b/src/circuit/uint32.rs @@ -15,6 +15,8 @@ use super::boolean::{ AllocatedBit }; +use super::multieq::MultiEq; + /// Represents an interpretation of 32 `Boolean` objects as an /// unsigned integer. #[derive(Clone)] @@ -188,12 +190,13 @@ impl UInt32 { } /// Perform modular addition of several `UInt32` objects. - pub fn addmany( - mut cs: CS, + pub fn addmany( + mut cs: M, operands: &[Self] ) -> Result where E: Engine, - CS: ConstraintSystem + CS: ConstraintSystem, + M: ConstraintSystem> { // Make some arbitrary bounds for ourselves to avoid overflows // in the scalar field @@ -208,7 +211,8 @@ impl UInt32 { // Keep track of the resulting value let mut result_value = Some(0u64); - // This is a linear combination that we will enforce to be "zero" + // This is a linear combination that we will enforce to equal the + // output let mut lc = LinearCombination::zero(); let mut all_constants = true; @@ -231,25 +235,9 @@ impl UInt32 { // the linear combination let mut coeff = E::Fr::one(); for bit in &op.bits { - match bit { - &Boolean::Is(ref bit) => { - all_constants = false; + lc = lc + &bit.lc(CS::one(), coeff); - // Add coeff * bit - lc = lc + (coeff, bit.get_variable()); - }, - &Boolean::Not(ref bit) => { - all_constants = false; - - // Add coeff * (1 - bit) = coeff * ONE - coeff * bit - lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable()); - }, - &Boolean::Constant(bit) => { - if bit { - lc = lc + (coeff, CS::one()); - } - } - } + all_constants &= bit.is_constant(); coeff.double(); } @@ -268,6 +256,10 @@ impl UInt32 { // Storage area for the resulting bits let mut result_bits = vec![]; + // Linear combination representing the output, + // for comparison with the sum of the operands + let mut result_lc = LinearCombination::zero(); + // Allocate each bit of the result let mut coeff = E::Fr::one(); let mut i = 0; @@ -278,8 +270,8 @@ impl UInt32 { result_value.map(|v| (v >> i) & 1 == 1) )?; - // Subtract this bit from the linear combination to ensure the sums balance out - lc = lc - (coeff, b.get_variable()); + // Add this bit to the result combination + result_lc = result_lc + (coeff, b.get_variable()); result_bits.push(b.into()); @@ -288,13 +280,8 @@ impl UInt32 { coeff.double(); } - // Enforce that the linear combination equals zero - cs.enforce( - || "modular addition", - |lc| lc, - |lc| lc, - |_| lc - ); + // Enforce equality between the sum and result + cs.get_root().enforce_equal(i, &lc, &result_lc); // Discard carry bits that we don't care about result_bits.truncate(32); @@ -315,6 +302,7 @@ mod test { use pairing::{Field}; use ::circuit::test::*; use bellman::{ConstraintSystem}; + use circuit::multieq::MultiEq; #[test] fn test_uint32_from_bits() { @@ -406,7 +394,11 @@ mod test { let mut expected = a.wrapping_add(b).wrapping_add(c); - let r = UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap(); + let r = { + let mut cs = MultiEq::new(&mut cs); + let r = UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap(); + r + }; assert!(r.value == Some(expected)); @@ -444,7 +436,11 @@ mod test { let d_bit = UInt32::alloc(cs.namespace(|| "d_bit"), Some(d)).unwrap(); let r = a_bit.xor(cs.namespace(|| "xor"), &b_bit).unwrap(); - let r = UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap(); + let r = { + let mut cs = MultiEq::new(&mut cs); + let r = UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap(); + r + }; assert!(cs.is_satisfied());