Optimize UInt32::addmany/BLAKE2s to combine equality constraints. (Closes #5)
This commit is contained in:
parent
827e85547e
commit
52eb59766b
|
@ -15,6 +15,8 @@ use super::uint32::{
|
||||||
UInt32
|
UInt32
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::multieq::MultiEq;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
2.1. Parameters
|
2.1. Parameters
|
||||||
The following table summarizes various parameters and their ranges:
|
The following table summarizes various parameters and their ranges:
|
||||||
|
@ -88,8 +90,8 @@ const SIGMA: [[usize; 16]; 10] = [
|
||||||
END FUNCTION.
|
END FUNCTION.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
fn mixing_g<E: Engine, CS: ConstraintSystem<E>>(
|
fn mixing_g<E: Engine, CS: ConstraintSystem<E>, M>(
|
||||||
mut cs: CS,
|
mut cs: M,
|
||||||
v: &mut [UInt32],
|
v: &mut [UInt32],
|
||||||
a: usize,
|
a: usize,
|
||||||
b: usize,
|
b: usize,
|
||||||
|
@ -98,6 +100,7 @@ fn mixing_g<E: Engine, CS: ConstraintSystem<E>>(
|
||||||
x: &UInt32,
|
x: &UInt32,
|
||||||
y: &UInt32
|
y: &UInt32
|
||||||
) -> Result<(), SynthesisError>
|
) -> Result<(), SynthesisError>
|
||||||
|
where M: ConstraintSystem<E, Root=MultiEq<E, CS>>
|
||||||
{
|
{
|
||||||
v[a] = UInt32::addmany(cs.namespace(|| "mixing step 1"), &[v[a].clone(), v[b].clone(), x.clone()])?;
|
v[a] = UInt32::addmany(cs.namespace(|| "mixing step 1"), &[v[a].clone(), v[b].clone(), x.clone()])?;
|
||||||
v[d] = v[d].xor(cs.namespace(|| "mixing step 2"), &v[a])?.rotr(R1);
|
v[d] = v[d].xor(cs.namespace(|| "mixing step 2"), &v[a])?.rotr(R1);
|
||||||
|
@ -199,20 +202,24 @@ fn blake2s_compression<E: Engine, CS: ConstraintSystem<E>>(
|
||||||
v[14] = v[14].xor(cs.namespace(|| "third xor"), &UInt32::constant(u32::max_value()))?;
|
v[14] = v[14].xor(cs.namespace(|| "third xor"), &UInt32::constant(u32::max_value()))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..10 {
|
{
|
||||||
let mut cs = cs.namespace(|| format!("round {}", i));
|
let mut cs = MultiEq::new(&mut cs);
|
||||||
|
|
||||||
let s = SIGMA[i % 10];
|
for i in 0..10 {
|
||||||
|
let mut cs = cs.namespace(|| format!("round {}", i));
|
||||||
|
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 1"), &mut v, 0, 4, 8, 12, &m[s[ 0]], &m[s[ 1]])?;
|
let s = SIGMA[i % 10];
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 2"), &mut v, 1, 5, 9, 13, &m[s[ 2]], &m[s[ 3]])?;
|
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 3"), &mut v, 2, 6, 10, 14, &m[s[ 4]], &m[s[ 5]])?;
|
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 4"), &mut v, 3, 7, 11, 15, &m[s[ 6]], &m[s[ 7]])?;
|
|
||||||
|
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 5"), &mut v, 0, 5, 10, 15, &m[s[ 8]], &m[s[ 9]])?;
|
mixing_g(cs.namespace(|| "mixing invocation 1"), &mut v, 0, 4, 8, 12, &m[s[ 0]], &m[s[ 1]])?;
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 6"), &mut v, 1, 6, 11, 12, &m[s[10]], &m[s[11]])?;
|
mixing_g(cs.namespace(|| "mixing invocation 2"), &mut v, 1, 5, 9, 13, &m[s[ 2]], &m[s[ 3]])?;
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 7"), &mut v, 2, 7, 8, 13, &m[s[12]], &m[s[13]])?;
|
mixing_g(cs.namespace(|| "mixing invocation 3"), &mut v, 2, 6, 10, 14, &m[s[ 4]], &m[s[ 5]])?;
|
||||||
mixing_g(cs.namespace(|| "mixing invocation 8"), &mut v, 3, 4, 9, 14, &m[s[14]], &m[s[15]])?;
|
mixing_g(cs.namespace(|| "mixing invocation 4"), &mut v, 3, 7, 11, 15, &m[s[ 6]], &m[s[ 7]])?;
|
||||||
|
|
||||||
|
mixing_g(cs.namespace(|| "mixing invocation 5"), &mut v, 0, 5, 10, 15, &m[s[ 8]], &m[s[ 9]])?;
|
||||||
|
mixing_g(cs.namespace(|| "mixing invocation 6"), &mut v, 1, 6, 11, 12, &m[s[10]], &m[s[11]])?;
|
||||||
|
mixing_g(cs.namespace(|| "mixing invocation 7"), &mut v, 2, 7, 8, 13, &m[s[12]], &m[s[13]])?;
|
||||||
|
mixing_g(cs.namespace(|| "mixing invocation 8"), &mut v, 3, 4, 9, 14, &m[s[14]], &m[s[15]])?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..8 {
|
for i in 0..8 {
|
||||||
|
@ -350,7 +357,7 @@ mod test {
|
||||||
let input_bits: Vec<_> = (0..512).map(|i| AllocatedBit::alloc(cs.namespace(|| format!("input bit {}", i)), Some(true)).unwrap().into()).collect();
|
let input_bits: Vec<_> = (0..512).map(|i| AllocatedBit::alloc(cs.namespace(|| format!("input bit {}", i)), Some(true)).unwrap().into()).collect();
|
||||||
blake2s(&mut cs, &input_bits, b"12345678").unwrap();
|
blake2s(&mut cs, &input_bits, b"12345678").unwrap();
|
||||||
assert!(cs.is_satisfied());
|
assert!(cs.is_satisfied());
|
||||||
assert_eq!(cs.num_constraints(), 21792);
|
assert_eq!(cs.num_constraints(), 21518);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -367,7 +374,7 @@ mod test {
|
||||||
.collect();
|
.collect();
|
||||||
blake2s(&mut cs, &input_bits, b"12345678").unwrap();
|
blake2s(&mut cs, &input_bits, b"12345678").unwrap();
|
||||||
assert!(cs.is_satisfied());
|
assert!(cs.is_satisfied());
|
||||||
assert_eq!(cs.num_constraints(), 21792);
|
assert_eq!(cs.num_constraints(), 21518);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
@ -367,6 +367,13 @@ pub enum Boolean {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Boolean {
|
impl Boolean {
|
||||||
|
pub fn is_constant(&self) -> bool {
|
||||||
|
match *self {
|
||||||
|
Boolean::Constant(_) => true,
|
||||||
|
_ => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn enforce_equal<E, CS>(
|
pub fn enforce_equal<E, CS>(
|
||||||
mut cs: CS,
|
mut cs: CS,
|
||||||
a: &Self,
|
a: &Self,
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
pub mod test;
|
pub mod test;
|
||||||
|
|
||||||
pub mod boolean;
|
pub mod boolean;
|
||||||
|
pub mod multieq;
|
||||||
pub mod uint32;
|
pub mod uint32;
|
||||||
pub mod blake2s;
|
pub mod blake2s;
|
||||||
pub mod num;
|
pub mod num;
|
||||||
|
@ -627,8 +628,8 @@ fn test_input_circuit_with_bls12_381() {
|
||||||
instance.synthesize(&mut cs).unwrap();
|
instance.synthesize(&mut cs).unwrap();
|
||||||
|
|
||||||
assert!(cs.is_satisfied());
|
assert!(cs.is_satisfied());
|
||||||
assert_eq!(cs.num_constraints(), 101566);
|
assert_eq!(cs.num_constraints(), 101018);
|
||||||
assert_eq!(cs.hash(), "e3d226975c99e17ef30f5a4b7e87d355ef3dbd80eed0c8de43780f3028946d82");
|
assert_eq!(cs.hash(), "eedcef5fd638e0168ae4d53ac58df66f0acdabea46749cc5f4b39459c8377804");
|
||||||
|
|
||||||
let expected_value_cm = value_commitment.cm(params).into_xy();
|
let expected_value_cm = value_commitment.cm(params).into_xy();
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
use pairing::{
|
||||||
|
Engine,
|
||||||
|
Field,
|
||||||
|
PrimeField
|
||||||
|
};
|
||||||
|
|
||||||
|
use bellman::{
|
||||||
|
SynthesisError,
|
||||||
|
ConstraintSystem,
|
||||||
|
LinearCombination,
|
||||||
|
Variable
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct MultiEq<E: Engine, CS: ConstraintSystem<E>>{
|
||||||
|
cs: CS,
|
||||||
|
ops: usize,
|
||||||
|
bits_used: usize,
|
||||||
|
lhs: LinearCombination<E>,
|
||||||
|
rhs: LinearCombination<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: Engine, CS: ConstraintSystem<E>> MultiEq<E, CS> {
|
||||||
|
pub fn new(cs: CS) -> Self {
|
||||||
|
MultiEq {
|
||||||
|
cs: cs,
|
||||||
|
ops: 0,
|
||||||
|
bits_used: 0,
|
||||||
|
lhs: LinearCombination::zero(),
|
||||||
|
rhs: LinearCombination::zero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn accumulate(&mut self)
|
||||||
|
{
|
||||||
|
let ops = self.ops;
|
||||||
|
let lhs = self.lhs.clone();
|
||||||
|
let rhs = self.rhs.clone();
|
||||||
|
self.cs.enforce(
|
||||||
|
|| format!("multieq {}", ops),
|
||||||
|
|_| lhs,
|
||||||
|
|lc| lc + CS::one(),
|
||||||
|
|_| rhs
|
||||||
|
);
|
||||||
|
self.lhs = LinearCombination::zero();
|
||||||
|
self.rhs = LinearCombination::zero();
|
||||||
|
self.bits_used = 0;
|
||||||
|
self.ops += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn enforce_equal(
|
||||||
|
&mut self,
|
||||||
|
num_bits: usize,
|
||||||
|
lhs: &LinearCombination<E>,
|
||||||
|
rhs: &LinearCombination<E>
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// Check if we will exceed the capacity
|
||||||
|
if (E::Fr::CAPACITY as usize) <= (self.bits_used + num_bits) {
|
||||||
|
self.accumulate();
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!((E::Fr::CAPACITY as usize) > (self.bits_used + num_bits));
|
||||||
|
|
||||||
|
let coeff = E::Fr::from_str("2").unwrap().pow(&[self.bits_used as u64]);
|
||||||
|
self.lhs = self.lhs.clone() + (coeff, lhs);
|
||||||
|
self.rhs = self.rhs.clone() + (coeff, rhs);
|
||||||
|
self.bits_used += num_bits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: Engine, CS: ConstraintSystem<E>> Drop for MultiEq<E, CS> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if self.bits_used > 0 {
|
||||||
|
self.accumulate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: Engine, CS: ConstraintSystem<E>> ConstraintSystem<E> for MultiEq<E, CS>
|
||||||
|
{
|
||||||
|
type Root = Self;
|
||||||
|
|
||||||
|
fn one() -> Variable {
|
||||||
|
CS::one()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alloc<F, A, AR>(
|
||||||
|
&mut self,
|
||||||
|
annotation: A,
|
||||||
|
f: F
|
||||||
|
) -> Result<Variable, SynthesisError>
|
||||||
|
where F: FnOnce() -> Result<E::Fr, SynthesisError>, A: FnOnce() -> AR, AR: Into<String>
|
||||||
|
{
|
||||||
|
self.cs.alloc(annotation, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alloc_input<F, A, AR>(
|
||||||
|
&mut self,
|
||||||
|
annotation: A,
|
||||||
|
f: F
|
||||||
|
) -> Result<Variable, SynthesisError>
|
||||||
|
where F: FnOnce() -> Result<E::Fr, SynthesisError>, A: FnOnce() -> AR, AR: Into<String>
|
||||||
|
{
|
||||||
|
self.cs.alloc_input(annotation, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn enforce<A, AR, LA, LB, LC>(
|
||||||
|
&mut self,
|
||||||
|
annotation: A,
|
||||||
|
a: LA,
|
||||||
|
b: LB,
|
||||||
|
c: LC
|
||||||
|
)
|
||||||
|
where A: FnOnce() -> AR, AR: Into<String>,
|
||||||
|
LA: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
|
||||||
|
LB: FnOnce(LinearCombination<E>) -> LinearCombination<E>,
|
||||||
|
LC: FnOnce(LinearCombination<E>) -> LinearCombination<E>
|
||||||
|
{
|
||||||
|
self.cs.enforce(annotation, a, b, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push_namespace<NR, N>(&mut self, name_fn: N)
|
||||||
|
where NR: Into<String>, N: FnOnce() -> NR
|
||||||
|
{
|
||||||
|
self.cs.get_root().push_namespace(name_fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pop_namespace(&mut self)
|
||||||
|
{
|
||||||
|
self.cs.get_root().pop_namespace()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_root(&mut self) -> &mut Self::Root
|
||||||
|
{
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
|
@ -15,6 +15,8 @@ use super::boolean::{
|
||||||
AllocatedBit
|
AllocatedBit
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::multieq::MultiEq;
|
||||||
|
|
||||||
/// Represents an interpretation of 32 `Boolean` objects as an
|
/// Represents an interpretation of 32 `Boolean` objects as an
|
||||||
/// unsigned integer.
|
/// unsigned integer.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -188,12 +190,13 @@ impl UInt32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Perform modular addition of several `UInt32` objects.
|
/// Perform modular addition of several `UInt32` objects.
|
||||||
pub fn addmany<E, CS>(
|
pub fn addmany<E, CS, M>(
|
||||||
mut cs: CS,
|
mut cs: M,
|
||||||
operands: &[Self]
|
operands: &[Self]
|
||||||
) -> Result<Self, SynthesisError>
|
) -> Result<Self, SynthesisError>
|
||||||
where E: Engine,
|
where E: Engine,
|
||||||
CS: ConstraintSystem<E>
|
CS: ConstraintSystem<E>,
|
||||||
|
M: ConstraintSystem<E, Root=MultiEq<E, CS>>
|
||||||
{
|
{
|
||||||
// Make some arbitrary bounds for ourselves to avoid overflows
|
// Make some arbitrary bounds for ourselves to avoid overflows
|
||||||
// in the scalar field
|
// in the scalar field
|
||||||
|
@ -208,7 +211,8 @@ impl UInt32 {
|
||||||
// Keep track of the resulting value
|
// Keep track of the resulting value
|
||||||
let mut result_value = Some(0u64);
|
let mut result_value = Some(0u64);
|
||||||
|
|
||||||
// This is a linear combination that we will enforce to be "zero"
|
// This is a linear combination that we will enforce to equal the
|
||||||
|
// output
|
||||||
let mut lc = LinearCombination::zero();
|
let mut lc = LinearCombination::zero();
|
||||||
|
|
||||||
let mut all_constants = true;
|
let mut all_constants = true;
|
||||||
|
@ -231,25 +235,9 @@ impl UInt32 {
|
||||||
// the linear combination
|
// the linear combination
|
||||||
let mut coeff = E::Fr::one();
|
let mut coeff = E::Fr::one();
|
||||||
for bit in &op.bits {
|
for bit in &op.bits {
|
||||||
match bit {
|
lc = lc + &bit.lc(CS::one(), coeff);
|
||||||
&Boolean::Is(ref bit) => {
|
|
||||||
all_constants = false;
|
|
||||||
|
|
||||||
// Add coeff * bit
|
all_constants &= bit.is_constant();
|
||||||
lc = lc + (coeff, bit.get_variable());
|
|
||||||
},
|
|
||||||
&Boolean::Not(ref bit) => {
|
|
||||||
all_constants = false;
|
|
||||||
|
|
||||||
// Add coeff * (1 - bit) = coeff * ONE - coeff * bit
|
|
||||||
lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable());
|
|
||||||
},
|
|
||||||
&Boolean::Constant(bit) => {
|
|
||||||
if bit {
|
|
||||||
lc = lc + (coeff, CS::one());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
coeff.double();
|
coeff.double();
|
||||||
}
|
}
|
||||||
|
@ -268,6 +256,10 @@ impl UInt32 {
|
||||||
// Storage area for the resulting bits
|
// Storage area for the resulting bits
|
||||||
let mut result_bits = vec![];
|
let mut result_bits = vec![];
|
||||||
|
|
||||||
|
// Linear combination representing the output,
|
||||||
|
// for comparison with the sum of the operands
|
||||||
|
let mut result_lc = LinearCombination::zero();
|
||||||
|
|
||||||
// Allocate each bit of the result
|
// Allocate each bit of the result
|
||||||
let mut coeff = E::Fr::one();
|
let mut coeff = E::Fr::one();
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
|
@ -278,8 +270,8 @@ impl UInt32 {
|
||||||
result_value.map(|v| (v >> i) & 1 == 1)
|
result_value.map(|v| (v >> i) & 1 == 1)
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// Subtract this bit from the linear combination to ensure the sums balance out
|
// Add this bit to the result combination
|
||||||
lc = lc - (coeff, b.get_variable());
|
result_lc = result_lc + (coeff, b.get_variable());
|
||||||
|
|
||||||
result_bits.push(b.into());
|
result_bits.push(b.into());
|
||||||
|
|
||||||
|
@ -288,13 +280,8 @@ impl UInt32 {
|
||||||
coeff.double();
|
coeff.double();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enforce that the linear combination equals zero
|
// Enforce equality between the sum and result
|
||||||
cs.enforce(
|
cs.get_root().enforce_equal(i, &lc, &result_lc);
|
||||||
|| "modular addition",
|
|
||||||
|lc| lc,
|
|
||||||
|lc| lc,
|
|
||||||
|_| lc
|
|
||||||
);
|
|
||||||
|
|
||||||
// Discard carry bits that we don't care about
|
// Discard carry bits that we don't care about
|
||||||
result_bits.truncate(32);
|
result_bits.truncate(32);
|
||||||
|
@ -315,6 +302,7 @@ mod test {
|
||||||
use pairing::{Field};
|
use pairing::{Field};
|
||||||
use ::circuit::test::*;
|
use ::circuit::test::*;
|
||||||
use bellman::{ConstraintSystem};
|
use bellman::{ConstraintSystem};
|
||||||
|
use circuit::multieq::MultiEq;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_uint32_from_bits() {
|
fn test_uint32_from_bits() {
|
||||||
|
@ -406,7 +394,11 @@ mod test {
|
||||||
|
|
||||||
let mut expected = a.wrapping_add(b).wrapping_add(c);
|
let mut expected = a.wrapping_add(b).wrapping_add(c);
|
||||||
|
|
||||||
let r = UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap();
|
let r = {
|
||||||
|
let mut cs = MultiEq::new(&mut cs);
|
||||||
|
let r = UInt32::addmany(cs.namespace(|| "addition"), &[a_bit, b_bit, c_bit]).unwrap();
|
||||||
|
r
|
||||||
|
};
|
||||||
|
|
||||||
assert!(r.value == Some(expected));
|
assert!(r.value == Some(expected));
|
||||||
|
|
||||||
|
@ -444,7 +436,11 @@ mod test {
|
||||||
let d_bit = UInt32::alloc(cs.namespace(|| "d_bit"), Some(d)).unwrap();
|
let d_bit = UInt32::alloc(cs.namespace(|| "d_bit"), Some(d)).unwrap();
|
||||||
|
|
||||||
let r = a_bit.xor(cs.namespace(|| "xor"), &b_bit).unwrap();
|
let r = a_bit.xor(cs.namespace(|| "xor"), &b_bit).unwrap();
|
||||||
let r = UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap();
|
let r = {
|
||||||
|
let mut cs = MultiEq::new(&mut cs);
|
||||||
|
let r = UInt32::addmany(cs.namespace(|| "addition"), &[r, c_bit, d_bit]).unwrap();
|
||||||
|
r
|
||||||
|
};
|
||||||
|
|
||||||
assert!(cs.is_satisfied());
|
assert!(cs.is_satisfied());
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue