Optimize UInt32::addmany/BLAKE2s to combine equality constraints. (Closes #5)

This commit is contained in:
Sean Bowe 2018-03-15 12:31:10 -06:00
parent 827e85547e
commit 52eb59766b
No known key found for this signature in database
GPG Key ID: 95684257D8F8B031
5 changed files with 198 additions and 50 deletions

View File

@ -15,6 +15,8 @@ use super::uint32::{
UInt32
};
use super::multieq::MultiEq;
/*
2.1. Parameters
The following table summarizes various parameters and their ranges:
@ -88,8 +90,8 @@ const SIGMA: [[usize; 16]; 10] = [
END FUNCTION.
*/
fn mixing_g<E: Engine, CS: ConstraintSystem<E>>(
mut cs: CS,
fn mixing_g<E: Engine, CS: ConstraintSystem<E>, M>(
mut cs: M,
v: &mut [UInt32],
a: usize,
b: usize,
@ -98,6 +100,7 @@ fn mixing_g<E: Engine, CS: ConstraintSystem<E>>(
x: &UInt32,
y: &UInt32
) -> Result<(), SynthesisError>
where M: ConstraintSystem<E, Root=MultiEq<E, CS>>
{
v[a] = UInt32::addmany(cs.namespace(|| "mixing step 1"), &[v[a].clone(), v[b].clone(), x.clone()])?;
v[d] = v[d].xor(cs.namespace(|| "mixing step 2"), &v[a])?.rotr(R1);
@ -199,20 +202,24 @@ fn blake2s_compression<E: Engine, CS: ConstraintSystem<E>>(
v[14] = v[14].xor(cs.namespace(|| "third xor"), &UInt32::constant(u32::max_value()))?;
}
for i in 0..10 {
let mut cs = cs.namespace(|| format!("round {}", i));
{
let mut cs = MultiEq::new(&mut cs);
let s = SIGMA[i % 10];
for i in 0..10 {
let mut cs = cs.namespace(|| format!("round {}", i));
mixing_g(cs.namespace(|| "mixing invocation 1"), &mut v, 0, 4, 8, 12, &m[s[ 0]], &m[s[ 1]])?;
mixing_g(cs.namespace(|| "mixing invocation 2"), &mut v, 1, 5, 9, 13, &m[s[ 2]], &m[s[ 3]])?;
mixing_g(cs.namespace(|| "mixing invocation 3"), &mut v, 2, 6, 10, 14, &m[s[ 4]], &m[s[ 5]])?;
mixing_g(cs.namespace(|| "mixing invocation 4"), &mut v, 3, 7, 11, 15, &m[s[ 6]], &m[s[ 7]])?;
let s = SIGMA[i % 10];
mixing_g(cs.namespace(|| "mixing invocation 5"), &mut v, 0, 5, 10, 15, &m[s[ 8]], &m[s[ 9]])?;
mixing_g(cs.namespace(|| "mixing invocation 6"), &mut v, 1, 6, 11, 12, &m[s[10]], &m[s[11]])?;
mixing_g(cs.namespace(|| "mixing invocation 7"), &mut v, 2, 7, 8, 13, &m[s[12]], &m[s[13]])?;
mixing_g(cs.namespace(|| "mixing invocation 8"), &mut v, 3, 4, 9, 14, &m[s[14]], &m[s[15]])?;
mixing_g(cs.namespace(|| "mixing invocation 1"), &mut v, 0, 4, 8, 12, &m[s[ 0]], &m[s[ 1]])?;
mixing_g(cs.namespace(|| "mixing invocation 2"), &mut v, 1, 5, 9, 13, &m[s[ 2]], &m[s[ 3]])?;
mixing_g(cs.namespace(|| "mixing invocation 3"), &mut v, 2, 6, 10, 14, &m[s[ 4]], &m[s[ 5]])?;
mixing_g(cs.namespace(|| "mixing invocation 4"), &mut v, 3, 7, 11, 15, &m[s[ 6]], &m[s[ 7]])?;
mixing_g(cs.namespace(|| "mixing invocation 5"), &mut v, 0, 5, 10, 15, &m[s[ 8]], &m[s[ 9]])?;
mixing_g(cs.namespace(|| "mixing invocation 6"), &mut v, 1, 6, 11, 12, &m[s[10]], &m[s[11]])?;
mixing_g(cs.namespace(|| "mixing invocation 7"), &mut v, 2, 7, 8, 13, &m[s[12]], &m[s[13]])?;
mixing_g(cs.namespace(|| "mixing invocation 8"), &mut v, 3, 4, 9, 14, &m[s[14]], &m[s[15]])?;
}
}
for i in 0..8 {
@ -350,7 +357,7 @@ mod test {
let input_bits: Vec<_> = (0..512).map(|i| AllocatedBit::alloc(cs.namespace(|| format!("input bit {}", i)), Some(true)).unwrap().into()).collect();
blake2s(&mut cs, &input_bits, b"12345678").unwrap();
assert!(cs.is_satisfied());
assert_eq!(cs.num_constraints(), 21792);
assert_eq!(cs.num_constraints(), 21518);
}
#[test]
@ -367,7 +374,7 @@ mod test {
.collect();
blake2s(&mut cs, &input_bits, b"12345678").unwrap();
assert!(cs.is_satisfied());
assert_eq!(cs.num_constraints(), 21792);
assert_eq!(cs.num_constraints(), 21518);
}
#[test]

View File

@ -367,6 +367,13 @@ pub enum Boolean {
}
impl Boolean {
pub fn is_constant(&self) -> bool {
match *self {
Boolean::Constant(_) => true,
_ => false
}
}
pub fn enforce_equal<E, CS>(
mut cs: CS,
a: &Self,

View File

@ -2,6 +2,7 @@
pub mod test;
pub mod boolean;
pub mod multieq;
pub mod uint32;
pub mod blake2s;
pub mod num;
@ -627,8 +628,8 @@ fn test_input_circuit_with_bls12_381() {
instance.synthesize(&mut cs).unwrap();
assert!(cs.is_satisfied());
assert_eq!(cs.num_constraints(), 101566);
assert_eq!(cs.hash(), "e3d226975c99e17ef30f5a4b7e87d355ef3dbd80eed0c8de43780f3028946d82");
assert_eq!(cs.num_constraints(), 101018);
assert_eq!(cs.hash(), "eedcef5fd638e0168ae4d53ac58df66f0acdabea46749cc5f4b39459c8377804");
let expected_value_cm = value_commitment.cm(params).into_xy();

137
src/circuit/multieq.rs Normal file
View File

@ -0,0 +1,137 @@
use pairing::{
Engine,
Field,
PrimeField
};
use bellman::{
SynthesisError,
ConstraintSystem,
LinearCombination,
Variable
};
pub struct MultiEq<E: Engine, CS: ConstraintSystem<E>>{
cs: CS,
ops: usize,
bits_used: usize,
lhs: LinearCombination<E>,
rhs: LinearCombination<E>,
}
impl<E: Engine, CS: ConstraintSystem<E>> MultiEq<E, CS> {
pub fn new(cs: CS) -> Self {
MultiEq {
cs: cs,
ops: 0,
bits_used: 0,
lhs: LinearCombination::zero(),
rhs: LinearCombination::zero()
}
}
fn accumulate(&mut self)
{
let ops = self.ops;
let lhs = self.lhs.clone();
let rhs = self.rhs.clone();
self.cs.enforce(
|| format!("multieq {}", ops),
|_| lhs,
|lc| lc + CS::one(),
|_| rhs
);
self.lhs = LinearCombination::zero();
self.rhs = LinearCombination::zero();
self.bits_used = 0;
self.ops += 1;
}
pub fn enforce_equal(
&mut self,
num_bits: usize,
lhs: &LinearCombination<E>,
rhs: &LinearCombination<E>
)
{
// Check if we will exceed the capacity
if (E::Fr::CAPACITY as usize) <= (self.bits_used + num_bits) {
self.accumulate();
}
assert!((E::Fr::CAPACITY as usize) > (self.bits_used + num_bits));
let coeff = E::Fr::from_str("2").unwrap().pow(&[self.bits_used as u64]);
self.lhs = self.lhs.clone() + (coeff, lhs);
self.rhs = self.rhs.clone() + (coeff, rhs);
self.bits_used += num_bits;
}
}
impl<E: Engine, CS: ConstraintSystem<E>> Drop for MultiEq<E, CS> {
fn drop(&mut self) {
if self.bits_used > 0 {
self.accumulate();
}
}
}
impl<E: Engine, CS: ConstraintSystem<E>> ConstraintSystem<E> for MultiEq<E, CS>
{
type Root = Self;
fn one() -> Variable {
CS::one()
}
fn alloc<F, A, AR>(
&mut self,
annotation: A,
f: F
) -> Result<Variable, SynthesisError>
where F: FnOnce() -> Result<E::Fr, SynthesisError>, A: FnOnce() -> AR, AR: Into<String>
{
self.cs.alloc(annotation, f)
}
fn alloc_input<F, A, AR>(
&mut self,
annotation: A,
f: F
) -> Result<Variable, SynthesisError>
where F: FnOnce() -> Result<E::Fr, SynthesisError>, A: FnOnce() -> AR, AR: Into<String>
{
self.cs.alloc_input(annotation, f)
}
fn enforce<A, AR, LA, LB, LC>(
&mut self,
annotation: A,
a: LA,
b: LB,
c: LC
)
where A: FnOnce() -> AR, AR: Into<String>,
LA: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
LB: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
LC: FnOnce(LinearCombination<E>) -> LinearCombination<E>
{
self.cs.enforce(annotation, a, b, c)
}
fn push_namespace<NR, N>(&mut self, name_fn: N)
where NR: Into<String>, N: FnOnce() -> NR
{
self.cs.get_root().push_namespace(name_fn)
}
fn pop_namespace(&mut self)
{
self.cs.get_root().pop_namespace()
}
fn get_root(&mut self) -> &mut Self::Root
{
self
}
}

View File

@ -15,6 +15,8 @@ use super::boolean::{
AllocatedBit
};
use super::multieq::MultiEq;
/// Represents an interpretation of 32 `Boolean` objects as an
/// unsigned integer.
#[derive(Clone)]
@ -188,12 +190,13 @@ impl UInt32 {
}
/// Perform modular addition of several `UInt32` objects.
pub fn addmany<E, CS>(
mut cs: CS,
pub fn addmany<E, CS, M>(
mut cs: M,
operands: &[Self]
) -> Result<Self, SynthesisError>
where E: Engine,
CS: ConstraintSystem<E>
CS: ConstraintSystem<E>,
M: ConstraintSystem<E, Root=MultiEq<E, CS>>
{
// Make some arbitrary bounds for ourselves to avoid overflows
// in the scalar field
@ -208,7 +211,8 @@ impl UInt32 {
// Keep track of the resulting value
let mut result_value = Some(0u64);
// This is a linear combination that we will enforce to be "zero"
// This is a linear combination that we will enforce to equal the
// output
let mut lc = LinearCombination::zero();
let mut all_constants = true;
@ -231,25 +235,9 @@ impl UInt32 {
// the linear combination
let mut coeff = E::Fr::one();
for bit in &op.bits {
match bit {
&Boolean::Is(ref bit) => {
all_constants = false;
lc = lc + &bit.lc(CS::one(), coeff);
// Add coeff * bit
lc = lc + (coeff, bit.get_variable());
},
&Boolean::Not(ref bit) => {
all_constants = false;
// Add coeff * (1 - bit) = coeff * ONE - coeff * bit
lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable());
},
&Boolean::Constant(bit) => {
if bit {
lc = lc + (coeff, CS::one());
}
}
}
all_constants &= bit.is_constant();
coeff.double();
}
@ -268,6 +256,10 @@ impl UInt32 {
// Storage area for the resulting bits
let mut result_bits = vec![];
// Linear combination representing the output,
// for comparison with the sum of the operands
let mut result_lc = LinearCombination::zero();
// Allocate each bit of the result
let mut coeff = E::Fr::one();
let mut i = 0;
@ -278,8 +270,8 @@ impl UInt32 {
result_value.map(|v| (v >> i) & 1 == 1)
)?;
// Subtract this bit from the linear combination to ensure the sums balance out
lc = lc - (coeff, b.get_variable());
// Add this bit to the result combination
result_lc = result_lc + (coeff, b.get_variable());
result_bits.push(b.into());
@ -288,13 +280,8 @@ impl UInt32 {
coeff.double();
}
// Enforce that the linear combination equals zero
cs.enforce(
|| "modular addition",
|lc| lc,
|lc| lc,
|_| lc
);
// Enforce equality between the sum and result
cs.get_root().enforce_equal(i, &lc, &result_lc);
// Discard carry bits that we don't care about
result_bits.truncate(32);
@ -315,6 +302,7 @@ mod test {
use pairing::{Field};
use ::circuit::test::*;
use bellman::{ConstraintSystem};
use circuit::multieq::MultiEq;
#[test]
fn test_uint32_from_bits() {
@ -406,7 +394,11 @@ mod test {
let mut expected = a.wrapping_add(b).wrapping_add(c);
let r = UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap();
let r = {
let mut cs = MultiEq::new(&mut cs);
let r = UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap();
r
};
assert!(r.value == Some(expected));
@ -444,7 +436,11 @@ mod test {
let d_bit = UInt32::alloc(cs.namespace(|| "d_bit"), Some(d)).unwrap();
let r = a_bit.xor(cs.namespace(|| "xor"), &b_bit).unwrap();
let r = UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap();
let r = {
let mut cs = MultiEq::new(&mut cs);
let r = UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap();
r
};
assert!(cs.is_satisfied());