Added unconstrained zk-SNARK implementation for SHA3

This commit is contained in:
Sean Bowe 2016-01-03 02:13:25 -07:00
parent bc77a837df
commit 7415d5ff3c
4 changed files with 481 additions and 91 deletions

233
src/bit.rs Normal file
View File

@ -0,0 +1,233 @@
use tinysnark::FieldT;
use super::variable::*;
use self::Bit::*;
#[derive(Clone)]
pub enum Bit {
Constant(bool),
Is(Var),
Not(Var)
}
fn resolve_not(v: &Var) -> Var {
gadget(&[v], 1, |i, o| {
if *i[0] == FieldT::zero() {
*o[0] = FieldT::one();
} else {
*o[0] = FieldT::zero();
}
}, |i, o, cs| {
// (1 - a) * 1 = b
cs.push(Constraint);
vec![o[0]]
}).remove(0)
}
impl Bit {
pub fn val(&self, map: &[FieldT]) -> bool {
match *self {
Constant(c) => c,
Not(ref v) => v.val(map) == FieldT::zero(),
Is(ref v) => v.val(map) == FieldT::one()
}
}
pub fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
match *self {
Constant(_) => {},
Not(ref v) => {
v.walk(counter, constraints, witness_map);
},
Is(ref v) => {
v.walk(counter, constraints, witness_map);
}
}
}
pub fn new(v: &Var) -> Bit {
Is(gadget(&[v], 0, |_, _| {}, |i, o, cs| {
cs.push(Constraint);
vec![i[0]]
}).remove(0))
}
pub fn constant(num: bool) -> Bit {
Constant(num)
}
// self xor other
pub fn xor(&self, other: &Bit) -> Bit {
match (self, other) {
(&Constant(a), &Constant(b)) => {
Constant(a != b)
},
(&Is(ref v), &Constant(a)) | (&Constant(a), &Is(ref v)) => {
if a {
// Anything XOR 1 is the NOT of that thing.
Not(v.clone())
} else {
// Anything XOR 0 equals that thing.
Is(v.clone())
}
},
(&Is(ref a), &Is(ref b)) => {
Is(gadget(&[a, b], 1, |i, o| {
if *i[0] != *i[1] {
*o[0] = FieldT::one();
} else {
*o[0] = FieldT::zero();
}
}, |i, o, cs| {
// (2*b) * c = b+c - a
cs.push(Constraint);
vec![o[0]]
}).remove(0))
},
(&Not(ref v), &Constant(a)) | (&Constant(a), &Not(ref v)) => {
if a {
// Anything XOR 1 is the NOT of that thing.
// !A XOR 1 = A
Is(v.clone())
} else {
Not(v.clone())
}
},
(&Not(ref a), &Not(ref b)) => {
// !A xor !B is equivalent to A xor B
Is(a.clone()).xor(&Is(b.clone()))
},
(&Is(ref i), &Not(ref n)) | (&Not(ref n), &Is(ref i)) => {
Is(i.clone()).xor(&Is(resolve_not(n)))
}
}
}
fn and(&self, other: &Bit) -> Bit {
match (self, other) {
(&Constant(a), &Constant(b)) => {
Constant(a && b)
},
(&Constant(a), &Is(ref v)) | (&Is(ref v), &Constant(a)) => {
if a {
Is(v.clone())
} else {
Constant(false)
}
},
(&Is(ref a), &Is(ref b)) => {
Is(gadget(&[a, b], 1, |i, o| {
if *i[0] == FieldT::one() && *i[1] == FieldT::one() {
*o[0] = FieldT::one();
} else {
*o[0] = FieldT::zero();
}
}, |i, o, cs| {
// a * b = c
cs.push(Constraint);
vec![o[0]]
}).remove(0))
},
(&Not(ref a), &Constant(c)) | (&Constant(c), &Not(ref a)) => {
if c {
// X and 1 is the identity of X
Not(a.clone())
} else {
Constant(false)
}
},
(&Not(ref n), &Is(ref i)) | (&Is(ref i), &Not(ref n)) => {
//Is(i.clone()).and(&Is(resolve_not(n)))
Is(gadget(&[n, i], 1, |i, o| {
if *i[0] == FieldT::zero() && *i[1] == FieldT::one() {
*o[0] = FieldT::one();
} else {
*o[0] = FieldT::zero();
}
}, |i, o, cs| {
// (1-a) * b = c
cs.push(Constraint);
vec![o[0]]
}).remove(0))
},
(&Not(ref a), &Not(ref b)) => {
//Is(resolve_not(a)).and(&Is(resolve_not(b)))
Is(gadget(&[a, b], 1, |i, o| {
if *i[0] == FieldT::zero() && *i[1] == FieldT::zero() {
*o[0] = FieldT::one();
} else {
*o[0] = FieldT::zero();
}
}, |i, o, cs| {
// (1 - a) * (1 - b) = c
cs.push(Constraint);
vec![o[0]]
}).remove(0))
}
}
}
// (not self) and other
pub fn notand(&self, other: &Bit) -> Bit {
self.xor(&Constant(true)).and(other)
}
}
#[cfg(test)]
fn test_binary_op<F: Fn(&Bit, &Bit) -> Bit>(op: F, a_in: i64, b_in: i64, c_out: i64)
{
let a = Var::new(1);
let b = Var::new(2);
let a = Bit::new(&a);
let b = Bit::new(&b);
let mut counter = 3;
let mut witness_map = WitnessMap::new();
let mut constraints = vec![];
let c = op(&a, &b);
c.walk(&mut counter, &mut constraints, &mut witness_map);
assert_eq!(counter, 4);
assert_eq!(constraints.len(), 3);
assert_eq!(witness_map.len(), 2);
assert_eq!(witness_map[&1].len(), 2);
assert_eq!(witness_map[&2].len(), 1);
let mut f: Vec<FieldT> = (0..counter).map(|_| FieldT::zero()).collect();
f[0] = FieldT::one();
f[1] = FieldT::from(a_in);
f[2] = FieldT::from(b_in);
satisfy_field_elements(&mut f, &witness_map);
assert_eq!(f[3], FieldT::from(c_out));
}
#[test]
fn test_xor() {
use tinysnark;
tinysnark::init();
test_binary_op(Bit::xor, 0, 0, 0);
test_binary_op(Bit::xor, 0, 1, 1);
test_binary_op(Bit::xor, 1, 0, 1);
test_binary_op(Bit::xor, 1, 1, 0);
}
#[test]
fn test_and() {
use tinysnark;
tinysnark::init();
test_binary_op(Bit::and, 0, 0, 0);
test_binary_op(Bit::and, 0, 1, 0);
test_binary_op(Bit::and, 1, 0, 0);
test_binary_op(Bit::and, 1, 1, 1);
}

View File

@ -1,3 +1,6 @@
use super::bit::Bit;
use std::slice::IterMut;
const KECCAKF_RNDC: [u64; 24] =
[
0x0000000000000001, 0x0000000000008082, 0x800000000000808a,
@ -35,7 +38,7 @@ fn keccakf(st: &mut [Byte], rounds: usize)
State {
bits: bytes.iter_mut()
.rev() // Endianness
.flat_map(|b| b.bits.iter_mut())
.flat_map(|b| b.iter_mut())
.collect()
}
}
@ -188,7 +191,7 @@ fn keccakf(st: &mut [Byte], rounds: usize)
}
}
fn sha3_256(message: &[Byte]) -> Vec<Byte> {
pub fn sha3_256(message: &[Byte]) -> Vec<Byte> {
// As defined by FIPS202
keccak(1088, 512, message, 0x06, 32, 24)
}
@ -249,82 +252,15 @@ fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8
output
}
#[derive(Debug, PartialEq, Clone)]
enum Bit {
Constant(bool)
}
#[derive(Clone, Debug, PartialEq)]
struct Byte {
bits: Vec<Bit>
}
impl Byte {
fn new(byte: u8) -> Byte {
Byte {
bits: (0..8).map(|i| Bit::constant(byte & (1 << i) != 0))
.rev()
.collect()
}
}
fn unwrap_constant(&self) -> u8 {
let mut cur = 7;
let mut acc = 0;
for bit in &self.bits {
match bit {
&Bit::Constant(true) => {
acc |= 1 << cur;
},
&Bit::Constant(false) => {},
//_ => panic!("Tried to unwrap a constant from a non-constant")
}
cur -= 1;
}
acc
}
fn xor(&self, other: &Byte) -> Byte {
Byte {
bits: self.bits.iter()
.zip(other.bits.iter())
.map(|(a, b)| a.xor(b))
.collect()
}
}
}
impl Bit {
fn constant(num: bool) -> Bit {
Bit::Constant(num)
}
// self xor other
fn xor(&self, other: &Bit) -> Bit {
match (self, other) {
(&Bit::Constant(a), &Bit::Constant(b)) => {
Bit::constant(a != b)
},
//_ => unimplemented!()
}
}
// (not self) and other
fn notand(&self, other: &Bit) -> Bit {
match (self, other) {
(&Bit::Constant(a), &Bit::Constant(b)) => {
Bit::constant((!a) && b)
},
//_ => unimplemented!()
}
}
}
#[test]
fn test_sha3_256() {
let test_vector: Vec<(Vec<u8>, [u8; 32])> = vec![
(vec![0xff],
[0x44,0x4b,0x89,0xec,0xce,0x39,0x5a,0xec,0x5d,0xc9,0x8f,0x19,0xde,0xfd,0x3a,0x23,0xbc,0xa0,0x82,0x2f,0xc7,0x22,0x26,0xf5,0x8c,0xa4,0x6a,0x17,0xee,0xec,0xa4,0x42]
),
(vec![0x00],
[0x5d,0x53,0x46,0x9f,0x20,0xfe,0xf4,0xf8,0xea,0xb5,0x2b,0x88,0x04,0x4e,0xde,0x69,0xc7,0x7a,0x6a,0x68,0xa6,0x07,0x28,0x60,0x9f,0xc4,0xa6,0x5f,0xf5,0x31,0xe7,0xd0]
),
(vec![0x30, 0x31, 0x30, 0x31],
[0xe5,0xbf,0x4a,0xd7,0xda,0x2b,0x4d,0x64,0x0d,0x2b,0x8d,0xd3,0xae,0x9b,0x6e,0x71,0xb3,0x6e,0x0f,0x3d,0xb7,0x6a,0x1e,0xc0,0xad,0x6b,0x87,0x2f,0x3e,0xcc,0x2e,0xbc]
),
@ -368,3 +304,63 @@ fn test_sha3_256() {
}
}
}
#[derive(Clone)]
pub struct Byte {
bits: Vec<Bit>
}
impl From<Vec<Bit>> for Byte {
fn from(a: Vec<Bit>) -> Byte {
assert_eq!(8, a.len());
Byte {
bits: a
}
}
}
impl Byte {
pub fn bits(&self) -> Vec<Bit> {
self.bits.clone()
}
pub fn new(byte: u8) -> Byte {
Byte {
bits: (0..8).map(|i| Bit::constant(byte & (1 << i) != 0))
.rev()
.collect()
}
}
pub fn iter_mut(&mut self) -> IterMut<Bit> {
self.bits.iter_mut()
}
pub fn unwrap_constant(&self) -> u8 {
let mut cur = 7;
let mut acc = 0;
for bit in &self.bits {
match bit {
&Bit::Constant(true) => {
acc |= 1 << cur;
},
&Bit::Constant(false) => {},
_ => panic!("Tried to unwrap a constant from a non-constant")
}
cur -= 1;
}
acc
}
pub fn xor(&self, other: &Byte) -> Byte {
Byte {
bits: self.bits.iter()
.zip(other.bits.iter())
.map(|(a, b)| a.xor(b))
.collect()
}
}
}

View File

@ -1,27 +1,46 @@
#![feature(iter_arith, btree_range, collections_bound)]
extern crate tinysnark;
extern crate rand;
use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem};
use variable::*;
use keccak::*;
use bit::*;
mod variable;
mod keccak;
mod bit;
fn main() {
tinysnark::init();
let mut cs = ConstraintSystem::new(2, 1);
// xor
// (2*b) * c = b+c - a
cs.add_constraint(
&[LinearTerm{coeff: FieldT::from(2), index: 2}],
&[LinearTerm{coeff: FieldT::one(), index: 3}],
&[LinearTerm{coeff: FieldT::one(), index: 2},
LinearTerm{coeff: FieldT::one(), index: 3},
LinearTerm{coeff: -FieldT::one(), index: 1}]
);
let prompt = [0.into(), 1.into()];
let solution = [1.into()];
assert!(cs.test(&prompt, &solution));
let kp = Keypair::new(&cs);
let proof = Proof::new(&kp, &prompt, &solution);
assert!(proof.verify(&kp, &prompt));
let inbytes = 64;
//for inbits in 0..1024 {
let inbits = inbytes * 8;
let input: Vec<Bit> = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect();
let input: Vec<Byte> = input.chunks(8).map(|c| Byte::from(c.to_owned())).collect();
let output = sha3_256(&input);
let mut counter = 1 + (8*input.len());
let mut constraints = vec![];
let mut witness_map = WitnessMap::new();
for o in output.iter().flat_map(|e| e.bits().into_iter()) {
o.walk(&mut counter, &mut constraints, &mut witness_map);
}
let mut vars: Vec<FieldT> = (0..counter).map(|_| FieldT::zero()).collect();
vars[0] = FieldT::one();
satisfy_field_elements(&mut vars, &witness_map);
for b in output.iter().flat_map(|e| e.bits()) {
print!("{}", if b.val(&vars) { 1 } else { 0 });
}
println!("");
println!("{}: {} constraints", inbits, constraints.len());
//}
}

142
src/variable.rs Normal file
View File

@ -0,0 +1,142 @@
use tinysnark::FieldT;
use std::cell::Cell;
use std::rc::Rc;
use std::fmt;
use std::collections::BTreeMap;
pub type WitnessMap = BTreeMap<usize, Vec<(Vec<usize>, Vec<usize>, Rc<Fn(&[&FieldT], &mut [&mut FieldT]) + 'static>)>>;
use std::collections::Bound::Unbounded;
pub fn satisfy_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) {
for (n, group) in witness_map.range(Unbounded, Unbounded) {
for &(ref i, ref o, ref f) in group.iter() {
let i: Vec<&FieldT> = i.iter().map(|i| &vars[*i]).collect();
let o: Vec<&FieldT> = o.iter().map(|o| &vars[*o]).collect();
let mut o: Vec<&mut FieldT> = unsafe {
use std::mem::transmute;
transmute(o)
};
f(&i, &mut o);
}
}
}
#[derive(Clone)]
pub struct Constraint;
struct Gadget {
inputs: Vec<Var>,
aux: Vec<Var>,
witness: Rc<Fn(&[&FieldT], &mut [&mut FieldT]) + 'static>,
constraints: Vec<Constraint>,
group: usize,
visited: Cell<bool>
}
impl Gadget {
pub fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
if self.visited.get() {
return;
}
self.visited.set(true);
for a in &self.aux {
assert!(a.index.get() == 0);
a.index.set(*counter);
*counter += 1;
}
constraints.extend_from_slice(&self.constraints);
for i in &self.inputs {
i.walk(counter, constraints, witness_map);
}
let input_indexes = self.inputs.iter().map(|i| i.index.get()).collect();
let output_indexes = self.aux.iter().map(|i| i.index.get()).collect();
witness_map.entry(self.group)
.or_insert_with(|| Vec::new())
.push((input_indexes, output_indexes, self.witness.clone()));
}
}
#[derive(Clone)]
pub struct Var {
index: Rc<Cell<usize>>,
gadget: Option<Rc<Gadget>>
}
impl Var {
// todo: make this not public
pub fn new(i: usize) -> Var {
Var {
index: Rc::new(Cell::new(i)),
gadget: None
}
}
pub fn val(&self, map: &[FieldT]) -> FieldT {
let index = self.index.get();
assert!(index != 0);
map[index]
}
fn group(&self) -> usize {
match self.gadget {
None => 0,
Some(ref g) => g.group
}
}
pub fn walk(&self, counter: &mut usize, constraints: &mut Vec<Constraint>, witness_map: &mut WitnessMap) {
match self.gadget {
None => {},
Some(ref g) => g.walk(counter, constraints, witness_map)
}
}
}
pub fn gadget<W, C>(
inputs: &[&Var],
aux: usize,
witness: W,
constrain: C
) -> Vec<Var>
where C: for<'a> Fn(&[&'a Var], &[&'a Var], &mut Vec<Constraint>) -> Vec<&'a Var>,
W: Fn(&[&FieldT], &mut [&mut FieldT]) + 'static
{
let this_group = inputs.iter().map(|i| i.group()).max().map(|a| a+1).unwrap_or(0);
let aux: Vec<_> = (0..aux).map(|_| Var::new(0)).collect();
let aux: Vec<_> = aux.iter().collect();
let mut constraints = vec![];
let outputs = constrain(inputs, &*aux, &mut constraints);
let gadget = Rc::new(Gadget {
inputs: inputs.iter().map(|a| (*a).clone()).collect(),
aux: aux.iter().map(|a| (*a).clone()).collect(),
witness: Rc::new(witness),
constraints: constraints,
group: this_group,
visited: Cell::new(false)
});
outputs.into_iter().map(|a| {
let mut a = (*a).clone();
// TODO: we should augment the gadget instead
// of replacing it
debug_assert!(a.gadget.is_none());
a.gadget = Some(gadget.clone());
a
}).collect()
}