Refactor bit implementation (small improvements in number of constraints)

This commit is contained in:
Sean Bowe 2016-01-10 03:26:54 -07:00
parent b82a2f60f7
commit b4d4331926
6 changed files with 359 additions and 89 deletions

View File

@ -1,25 +1,228 @@
use tinysnark::FieldT; use tinysnark::FieldT;
use std::rc::Rc;
use std::cell::RefCell;
use super::variable::*; use super::variable::*;
use self::Bit::*; use self::Bit::*;
use self::Op::*;
macro_rules! mirror {
($a:pat, $b:pat) => (($a, $b) | ($b, $a))
}
macro_rules! mirror_match {
(@as_expr $e:expr) => {$e};
(@parse
$e:expr, ($($arms:tt)*);
$(,)*
) => {
mirror_match!(@as_expr match $e { $($arms)* })
};
(@parse
$e:expr, $arms:tt;
, $($tail:tt)*
) => {
mirror_match!(@parse $e, $arms; $($tail)*)
};
(@parse
$e:expr, ($($arms:tt)*);
mirror!($a:pat, $b:pat) => $body:expr,
$($tail:tt)*
) => {
mirror_match!(
@parse
$e,
(
$($arms)*
($a, $b) | ($b, $a) => $body,
);
$($tail)*
)
};
(@parse
$e:expr, ($($arms:tt)*);
$pat:pat => $body:expr,
$($tail:tt)*
) => {
mirror_match!(
@parse
$e,
(
$($arms)*
$pat => $body,
);
$($tail)*
)
};
(@parse
$e:expr, ($($arms:tt)*);
$pat:pat => $body:expr,
$($tail:tt)*
) => {
mirror_match!(
@parse
$e,
(
$($arms)*
$pat => $body,
);
$($tail)*
)
};
($e:expr { $($arms:tt)* }) => {
mirror_match!(@parse $e, (); $($arms)*)
};
}
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
enum Op {
And,
Nand,
Xor,
Xnor,
MaterialNonimplication,
MaterialImplication,
Nor,
Or
}
impl Op {
fn not(&self) -> Op {
match *self {
And => Nand,
Nand => And,
Xor => Xnor,
Xnor => Xor,
Nor => Or,
Or => Nor,
MaterialNonimplication => MaterialImplication,
MaterialImplication => MaterialNonimplication
}
}
fn val(&self, a: FieldT, b: FieldT) -> FieldT {
let a = a == FieldT::one();
let b = b == FieldT::one();
let res = match *self {
And => a && b,
Nand => !(a && b),
Xor => a != b,
Xnor => a == b,
Or => a || b,
Nor => !(a || b),
MaterialNonimplication => a && (!b),
MaterialImplication => !(a && (!b))
};
if res {
FieldT::one()
} else {
FieldT::zero()
}
}
}
#[derive(Clone)]
pub struct BinaryOp {
a: Var,
b: Var,
op: Op,
resolved: Rc<RefCell<Option<Var>>>
}
impl BinaryOp {
fn new(a: Var, b: Var, op: Op) -> BinaryOp {
BinaryOp {
a: a,
b: b,
op: op,
resolved: Rc::new(RefCell::new(None))
}
}
fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap)
{
self.a.walk(counter, constraints, witness_map);
self.b.walk(counter, constraints, witness_map);
}
fn val(&self, map: &[FieldT], inverted: bool) -> FieldT {
let v = self.op.val(self.a.val(map), self.b.val(map));
if inverted {
if v == FieldT::one() {
FieldT::zero()
} else {
FieldT::one()
}
} else {
v
}
}
fn resolve(&self, inverted: bool) -> Bit {
let res = { self.resolved.borrow_mut().clone() };
match res {
Some(v) => {
if inverted {
Not(v)
} else {
Is(v)
}
},
None => {
let v = resolve(&self.a, &self.b, self.op);
*self.resolved.borrow_mut() = Some(v.clone());
if inverted {
Not(v)
} else {
Is(v)
}
}
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub enum Bit { pub enum Bit {
Constant(bool), Constant(bool),
Is(Var), Is(Var),
Not(Var) Not(Var),
Bin(BinaryOp, bool)
} }
fn resolve_not(v: &Var) -> Var { fn resolve(a: &Var, b: &Var, op: Op) -> Var {
gadget(&[v], 1, |vars| { gadget(&[a, b], 1, move |vals| {
if vars.get_input(0) == FieldT::zero() { let a = vals.get_input(0);
vars.set_output(0, FieldT::one()); let b = vals.get_input(1);
} else {
vars.set_output(0, FieldT::zero()); vals.set_output(0, op.val(a, b));
}
}, |i, o, cs| { }, |i, o, cs| {
// (1 - a) * 1 = b cs.push(match op {
cs.push(Constraint); And => Constraint::And(i[0].index(), i[1].index(), o[0].index()),
Nand => Constraint::Nand(i[0].index(), i[1].index(), o[0].index()),
Xor => Constraint::Xor(i[0].index(), i[1].index(), o[0].index()),
Xnor => Constraint::Xnor(i[0].index(), i[1].index(), o[0].index()),
Nor => Constraint::Nor(i[0].index(), i[1].index(), o[0].index()),
Or => Constraint::Or(i[0].index(), i[1].index(), o[0].index()),
MaterialNonimplication => Constraint::MaterialNonimplication(i[0].index(), i[1].index(), o[0].index()),
MaterialImplication => Constraint::MaterialImplication(i[0].index(), i[1].index(), o[0].index())
});
vec![o[0]] vec![o[0]]
}).remove(0) }).remove(0)
@ -30,7 +233,16 @@ impl Bit {
match *self { match *self {
Constant(c) => c, Constant(c) => c,
Not(ref v) => v.val(map) == FieldT::zero(), Not(ref v) => v.val(map) == FieldT::zero(),
Is(ref v) => v.val(map) == FieldT::one() Is(ref v) => v.val(map) == FieldT::one(),
Bin(ref bin, inverted) => bin.val(map, inverted) == FieldT::one()
}
}
// probably could remove this
pub fn resolve(&self) -> Bit {
match *self {
Bin(ref bin, inverted) => bin.resolve(inverted),
_ => self.clone()
} }
} }
@ -42,13 +254,16 @@ impl Bit {
}, },
Is(ref v) => { Is(ref v) => {
v.walk(counter, constraints, witness_map); v.walk(counter, constraints, witness_map);
},
Bin(ref bin, _) => {
bin.walk(counter, constraints, witness_map);
} }
} }
} }
pub fn new(v: &Var) -> Bit { pub fn new(v: &Var) -> Bit {
Is(gadget(&[v], 0, |_| {}, |i, o, cs| { Is(gadget(&[v], 0, |_| {}, |i, o, cs| {
cs.push(Constraint); cs.push(Constraint::Bitness(i[0].index()));
vec![i[0]] vec![i[0]]
}).remove(0)) }).remove(0))
@ -60,11 +275,11 @@ impl Bit {
// self xor other // self xor other
pub fn xor(&self, other: &Bit) -> Bit { pub fn xor(&self, other: &Bit) -> Bit {
match (self, other) { mirror_match!((self, other) {
(&Constant(a), &Constant(b)) => { (&Constant(a), &Constant(b)) => {
Constant(a != b) Constant(a != b)
}, },
(&Is(ref v), &Constant(a)) | (&Constant(a), &Is(ref v)) => { mirror!(&Is(ref v), &Constant(a)) => {
if a { if a {
// Anything XOR 1 is the NOT of that thing. // Anything XOR 1 is the NOT of that thing.
Not(v.clone()) Not(v.clone())
@ -73,104 +288,94 @@ impl Bit {
Is(v.clone()) Is(v.clone())
} }
}, },
(&Is(ref a), &Is(ref b)) => { mirror!(&Not(ref v), &Constant(a)) => {
Is(gadget(&[a, b], 1, |vars| {
if vars.get_input(0) != vars.get_input(1) {
vars.set_output(0, FieldT::one());
} else {
vars.set_output(0, FieldT::zero());
}
}, |i, o, cs| {
// (2*b) * c = b+c - a
cs.push(Constraint);
vec![o[0]]
}).remove(0))
},
(&Not(ref v), &Constant(a)) | (&Constant(a), &Not(ref v)) => {
if a { if a {
// Anything XOR 1 is the NOT of that thing. // Anything XOR 1 is the NOT of that thing.
// !A XOR 1 = A
Is(v.clone()) Is(v.clone())
} else { } else {
Not(v.clone()) Not(v.clone())
} }
}, },
(&Not(ref a), &Not(ref b)) => { mirror!(&Bin(ref bin, inverted), &Constant(c)) => {
// !A xor !B is equivalent to A xor B if c {
Is(a.clone()).xor(&Is(b.clone())) // Anything XOR 1 is the NOT of that thing.
Bin(bin.clone(), !inverted)
} else {
Bin(bin.clone(), inverted)
}
}, },
(&Is(ref i), &Not(ref n)) | (&Not(ref n), &Is(ref i)) => { mirror!(&Bin(ref bin, inverted), &Is(ref i)) => {
Is(i.clone()).xor(&Is(resolve_not(n))) bin.resolve(inverted).xor(&Is(i.clone()))
} },
} (&Bin(ref bin1, inverted1), &Bin(ref bin2, inverted2)) => {
bin1.resolve(inverted1).xor(&bin2.resolve(inverted2))
},
mirror!(&Bin(ref bin, inverted), &Not(ref n)) => {
bin.resolve(inverted).xor(&Not(n.clone()))
},
(&Not(ref a), &Not(ref b)) => {
Bin(BinaryOp::new(a.clone(), b.clone(), Xor), false)
},
mirror!(&Is(ref i), &Not(ref n)) => {
Bin(BinaryOp::new(i.clone(), n.clone(), Xnor), false)
},
(&Is(ref a), &Is(ref b)) => {
Bin(BinaryOp::new(a.clone(), b.clone(), Xor), false)
},
})
} }
fn and(&self, other: &Bit) -> Bit { pub fn and(&self, other: &Bit) -> Bit {
match (self, other) { mirror_match!((self, other) {
(&Constant(a), &Constant(b)) => { (&Constant(a), &Constant(b)) => {
Constant(a && b) Constant(a && b)
}, },
(&Constant(a), &Is(ref v)) | (&Is(ref v), &Constant(a)) => { mirror!(&Is(ref v), &Constant(a)) => {
if a { if a {
// Anything AND 1 is the identity of that thing
Is(v.clone()) Is(v.clone())
} else { } else {
// Anything AND 0 is false
Constant(false) Constant(false)
} }
}, },
(&Is(ref a), &Is(ref b)) => { mirror!(&Not(ref v), &Constant(a)) => {
Is(gadget(&[a, b], 1, |vars| { if a {
if vars.get_input(0) == FieldT::one() && vars.get_input(1) == FieldT::one() { // Anything AND 1 is the identity of that thing
vars.set_output(0, FieldT::one()); Not(v.clone())
} else { } else {
vars.set_output(0, FieldT::zero()); // Anything AND 0 is false
Constant(false)
} }
}, |i, o, cs| {
// a * b = c
cs.push(Constraint);
vec![o[0]]
}).remove(0))
}, },
(&Not(ref a), &Constant(c)) | (&Constant(c), &Not(ref a)) => { mirror!(&Bin(ref bin, inverted), &Constant(c)) => {
if c { if c {
// X and 1 is the identity of X // Anything AND 1 is the identity of that thing
Not(a.clone()) Bin(bin.clone(), inverted)
} else { } else {
// Anything AND 0 is false
Constant(false) Constant(false)
} }
}, },
(&Not(ref n), &Is(ref i)) | (&Is(ref i), &Not(ref n)) => { mirror!(&Bin(ref bin, inverted), &Is(ref i)) => {
//Is(i.clone()).and(&Is(resolve_not(n))) bin.resolve(inverted).and(&Is(i.clone()))
Is(gadget(&[n, i], 1, |vars| { },
if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::one() { (&Bin(ref bin1, inverted1), &Bin(ref bin2, inverted2)) => {
vars.set_output(0, FieldT::one()); bin1.resolve(inverted1).and(&bin2.resolve(inverted2))
} else { },
vars.set_output(0, FieldT::zero()); mirror!(&Bin(ref bin, inverted), &Not(ref n)) => {
} bin.resolve(inverted).and(&Not(n.clone()))
}, |i, o, cs| {
// (1-a) * b = c
cs.push(Constraint);
vec![o[0]]
}).remove(0))
}, },
(&Not(ref a), &Not(ref b)) => { (&Not(ref a), &Not(ref b)) => {
//Is(resolve_not(a)).and(&Is(resolve_not(b))) Bin(BinaryOp::new(a.clone(), b.clone(), Nor), false)
Is(gadget(&[a, b], 1, |vars| { },
if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::zero() { mirror!(&Is(ref i), &Not(ref n)) => {
vars.set_output(0, FieldT::one()); Bin(BinaryOp::new(i.clone(), n.clone(), MaterialNonimplication), false)
} else { },
vars.set_output(0, FieldT::zero()); (&Is(ref a), &Is(ref b)) => {
} Bin(BinaryOp::new(a.clone(), b.clone(), And), false)
}, |i, o, cs| { },
// (1 - a) * (1 - b) = c })
cs.push(Constraint);
vec![o[0]]
}).remove(0))
}
}
} }
// (not self) and other // (not self) and other
@ -191,6 +396,7 @@ fn test_binary_op<F: Fn(&Bit, &Bit) -> Bit>(op: F, a_in: i64, b_in: i64, c_out:
let mut constraints = vec![]; let mut constraints = vec![];
let c = op(&a, &b); let c = op(&a, &b);
let c = c.resolve();
c.walk(&mut counter, &mut constraints, &mut witness_map); c.walk(&mut counter, &mut constraints, &mut witness_map);
assert_eq!(counter, 4); assert_eq!(counter, 4);
assert_eq!(constraints.len(), 3); assert_eq!(constraints.len(), 3);

View File

@ -282,6 +282,9 @@ fn test_sha3_256() {
((0..512).map(|_| 0x30).collect::<Vec<_>>(), ((0..512).map(|_| 0x30).collect::<Vec<_>>(),
[0x1c,0x80,0x1b,0x16,0x3a,0x2a,0xbe,0xd0,0xe8,0x07,0x1e,0x7f,0xf2,0x60,0x4e,0x98,0x11,0x22,0x80,0x54,0x14,0xf3,0xc8,0xfd,0x96,0x59,0x5d,0x7e,0xe1,0xd6,0x54,0xe2] [0x1c,0x80,0x1b,0x16,0x3a,0x2a,0xbe,0xd0,0xe8,0x07,0x1e,0x7f,0xf2,0x60,0x4e,0x98,0x11,0x22,0x80,0x54,0x14,0xf3,0xc8,0xfd,0x96,0x59,0x5d,0x7e,0xe1,0xd6,0x54,0xe2]
), ),
((0..64).map(|_| 0x00).collect::<Vec<_>>(),
[0x07,0x0f,0xa1,0xab,0x6f,0xcc,0x55,0x7e,0xd1,0x4d,0x42,0x94,0x1f,0x19,0x67,0x69,0x30,0x48,0x55,0x1e,0xb9,0x04,0x2a,0x8d,0x0a,0x05,0x7a,0xfb,0xd7,0x5e,0x81,0xe0]
),
]; ];
for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() { for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() {

View File

@ -15,7 +15,7 @@ mod bit;
fn main() { fn main() {
tinysnark::init(); tinysnark::init();
let inbytes = 1; let inbytes = 64;
//for inbits in 0..1024 { //for inbits in 0..1024 {
let inbits = inbytes * 8; let inbits = inbytes * 8;
let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect(); let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect();

View File

@ -41,7 +41,33 @@ pub fn witness_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) {
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Constraint; pub enum Constraint {
Bitness(Rc<Cell<usize>>),
And(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>),
Nand(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>),
Xor(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>),
Xnor(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>),
MaterialNonimplication(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>),
MaterialImplication(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>),
Nor(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>),
Or(Rc<Cell<usize>>, Rc<Cell<usize>>, Rc<Cell<usize>>)
}
impl fmt::Debug for Constraint {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", match *self {
Constraint::Bitness(ref b) => format!("bitness: {}", b.get()),
Constraint::And(ref a, ref b, ref c) => format!("{} = {} AND {}", c.get(), a.get(), b.get()),
Constraint::Nand(ref a, ref b, ref c) => format!("{} = {} NAND {}", c.get(), a.get(), b.get()),
Constraint::Xor(ref a, ref b, ref c) => format!("{} = {} XOR {}", c.get(), a.get(), b.get()),
Constraint::Xnor(ref a, ref b, ref c) => format!("{} = {} XNOR {}", c.get(), a.get(), b.get()),
Constraint::MaterialNonimplication(ref a, ref b, ref c) => format!("{} = {}{}", c.get(), a.get(), b.get()),
Constraint::MaterialImplication(ref a, ref b, ref c) => format!("{} = {} <-> {}", c.get(), a.get(), b.get()),
Constraint::Nor(ref a, ref b, ref c) => format!("{} = {} NOR {}", c.get(), a.get(), b.get()),
Constraint::Or(ref a, ref b, ref c) => format!("{} = {} OR {}", c.get(), a.get(), b.get())
})
}
}
struct Gadget { struct Gadget {
inputs: Vec<Var>, inputs: Vec<Var>,
@ -96,6 +122,11 @@ impl Var {
} }
} }
// make this not public or unsafe too
pub fn index(&self) -> Rc<Cell<usize>> {
self.index.clone()
}
pub fn val(&self, map: &[FieldT]) -> FieldT { pub fn val(&self, map: &[FieldT]) -> FieldT {
let index = self.index.get(); let index = self.index.get();
assert!(index != 0); assert!(index != 0);

View File

@ -16,6 +16,7 @@ static mut INITIALIZED: bool = false;
extern "C" { extern "C" {
fn tinysnark_init_public_params(); fn tinysnark_init_public_params();
pub fn tinysnark_test();
} }
pub fn init() { pub fn init() {

View File

@ -7,6 +7,7 @@ zk-SNARK support using the ALT_BN128 curve.
#include "zk_proof_systems/ppzksnark/r1cs_ppzksnark/r1cs_ppzksnark.hpp" #include "zk_proof_systems/ppzksnark/r1cs_ppzksnark/r1cs_ppzksnark.hpp"
#include "common/default_types/r1cs_ppzksnark_pp.hpp" #include "common/default_types/r1cs_ppzksnark_pp.hpp"
#include "common/utils.hpp" #include "common/utils.hpp"
#include "gadgetlib1/gadgets/hashes/sha256/sha256_gadget.hpp"
using namespace libsnark; using namespace libsnark;
using namespace std; using namespace std;
@ -170,3 +171,31 @@ extern "C" void tinysnark_init_public_params() {
assert(sizeof(p) == 32); assert(sizeof(p) == 32);
} }
} }
extern "C" void tinysnark_test() {
typedef Fr<default_r1cs_ppzksnark_pp> FieldT;
protoboard<FieldT> pb;
auto input_bits = new digest_variable<FieldT>(pb, 512, "input_bits");
auto output_bits = new digest_variable<FieldT>(pb, 256, "output_bits");
auto input_block = new block_variable<FieldT>(pb, {
input_bits->bits
}, "input_block");
auto IV = SHA256_default_IV(pb);
auto sha256 = new sha256_compression_function_gadget<FieldT>(pb,
IV,
input_block->bits,
*output_bits,
"sha256");
input_bits->generate_r1cs_constraints();
output_bits->generate_r1cs_constraints();
sha256->generate_r1cs_constraints();
const r1cs_constraint_system<FieldT> constraint_system = pb.get_constraint_system();
cout << "Number of R1CS constraints: " << constraint_system.num_constraints() << endl;
}