diff --git a/src/bit.rs b/src/bit.rs new file mode 100644 index 0000000..eb2b262 --- /dev/null +++ b/src/bit.rs @@ -0,0 +1,233 @@ +use tinysnark::FieldT; + +use super::variable::*; +use self::Bit::*; + +#[derive(Clone)] +pub enum Bit { + Constant(bool), + Is(Var), + Not(Var) +} + +fn resolve_not(v: &Var) -> Var { + gadget(&[v], 1, |i, o| { + if *i[0] == FieldT::zero() { + *o[0] = FieldT::one(); + } else { + *o[0] = FieldT::zero(); + } + }, |i, o, cs| { + // (1 - a) * 1 = b + cs.push(Constraint); + + vec![o[0]] + }).remove(0) +} + +impl Bit { + pub fn val(&self, map: &[FieldT]) -> bool { + match *self { + Constant(c) => c, + Not(ref v) => v.val(map) == FieldT::zero(), + Is(ref v) => v.val(map) == FieldT::one() + } + } + + pub fn walk(&self, counter: &mut usize, constraints: &mut Vec, witness_map: &mut WitnessMap) { + match *self { + Constant(_) => {}, + Not(ref v) => { + v.walk(counter, constraints, witness_map); + }, + Is(ref v) => { + v.walk(counter, constraints, witness_map); + } + } + } + + pub fn new(v: &Var) -> Bit { + Is(gadget(&[v], 0, |_, _| {}, |i, o, cs| { + cs.push(Constraint); + + vec![i[0]] + }).remove(0)) + } + + pub fn constant(num: bool) -> Bit { + Constant(num) + } + + // self xor other + pub fn xor(&self, other: &Bit) -> Bit { + match (self, other) { + (&Constant(a), &Constant(b)) => { + Constant(a != b) + }, + (&Is(ref v), &Constant(a)) | (&Constant(a), &Is(ref v)) => { + if a { + // Anything XOR 1 is the NOT of that thing. + Not(v.clone()) + } else { + // Anything XOR 0 equals that thing. + Is(v.clone()) + } + }, + (&Is(ref a), &Is(ref b)) => { + Is(gadget(&[a, b], 1, |i, o| { + if *i[0] != *i[1] { + *o[0] = FieldT::one(); + } else { + *o[0] = FieldT::zero(); + } + }, |i, o, cs| { + // (2*b) * c = b+c - a + cs.push(Constraint); + + vec![o[0]] + }).remove(0)) + }, + (&Not(ref v), &Constant(a)) | (&Constant(a), &Not(ref v)) => { + if a { + // Anything XOR 1 is the NOT of that thing. + // !A XOR 1 = A + Is(v.clone()) + } else { + Not(v.clone()) + } + }, + (&Not(ref a), &Not(ref b)) => { + // !A xor !B is equivalent to A xor B + Is(a.clone()).xor(&Is(b.clone())) + }, + (&Is(ref i), &Not(ref n)) | (&Not(ref n), &Is(ref i)) => { + Is(i.clone()).xor(&Is(resolve_not(n))) + } + } + } + + fn and(&self, other: &Bit) -> Bit { + match (self, other) { + (&Constant(a), &Constant(b)) => { + Constant(a && b) + }, + (&Constant(a), &Is(ref v)) | (&Is(ref v), &Constant(a)) => { + if a { + Is(v.clone()) + } else { + Constant(false) + } + }, + (&Is(ref a), &Is(ref b)) => { + Is(gadget(&[a, b], 1, |i, o| { + if *i[0] == FieldT::one() && *i[1] == FieldT::one() { + *o[0] = FieldT::one(); + } else { + *o[0] = FieldT::zero(); + } + }, |i, o, cs| { + // a * b = c + cs.push(Constraint); + + vec![o[0]] + }).remove(0)) + }, + (&Not(ref a), &Constant(c)) | (&Constant(c), &Not(ref a)) => { + if c { + // X and 1 is the identity of X + Not(a.clone()) + } else { + Constant(false) + } + }, + (&Not(ref n), &Is(ref i)) | (&Is(ref i), &Not(ref n)) => { + //Is(i.clone()).and(&Is(resolve_not(n))) + Is(gadget(&[n, i], 1, |i, o| { + if *i[0] == FieldT::zero() && *i[1] == FieldT::one() { + *o[0] = FieldT::one(); + } else { + *o[0] = FieldT::zero(); + } + }, |i, o, cs| { + // (1-a) * b = c + cs.push(Constraint); + + vec![o[0]] + }).remove(0)) + }, + (&Not(ref a), &Not(ref b)) => { + //Is(resolve_not(a)).and(&Is(resolve_not(b))) + Is(gadget(&[a, b], 1, |i, o| { + if *i[0] == FieldT::zero() && *i[1] == FieldT::zero() { + *o[0] = FieldT::one(); + } else { + *o[0] = FieldT::zero(); + } + }, |i, o, cs| { + // (1 - a) * (1 - b) = c + cs.push(Constraint); + + vec![o[0]] + }).remove(0)) + } + } + } + + // (not self) and other + pub fn notand(&self, other: &Bit) -> Bit { + self.xor(&Constant(true)).and(other) + } +} + +#[cfg(test)] +fn test_binary_op Bit>(op: F, a_in: i64, b_in: i64, c_out: i64) +{ + let a = Var::new(1); + let b = Var::new(2); + let a = Bit::new(&a); + let b = Bit::new(&b); + let mut counter = 3; + let mut witness_map = WitnessMap::new(); + let mut constraints = vec![]; + + let c = op(&a, &b); + c.walk(&mut counter, &mut constraints, &mut witness_map); + assert_eq!(counter, 4); + assert_eq!(constraints.len(), 3); + assert_eq!(witness_map.len(), 2); + assert_eq!(witness_map[&1].len(), 2); + assert_eq!(witness_map[&2].len(), 1); + + let mut f: Vec = (0..counter).map(|_| FieldT::zero()).collect(); + f[0] = FieldT::one(); + f[1] = FieldT::from(a_in); + f[2] = FieldT::from(b_in); + + satisfy_field_elements(&mut f, &witness_map); + + assert_eq!(f[3], FieldT::from(c_out)); +} + +#[test] +fn test_xor() { + use tinysnark; + + tinysnark::init(); + + test_binary_op(Bit::xor, 0, 0, 0); + test_binary_op(Bit::xor, 0, 1, 1); + test_binary_op(Bit::xor, 1, 0, 1); + test_binary_op(Bit::xor, 1, 1, 0); +} + +#[test] +fn test_and() { + use tinysnark; + + tinysnark::init(); + + test_binary_op(Bit::and, 0, 0, 0); + test_binary_op(Bit::and, 0, 1, 0); + test_binary_op(Bit::and, 1, 0, 0); + test_binary_op(Bit::and, 1, 1, 1); +} diff --git a/src/keccak.rs b/src/keccak.rs index 0574809..83927d9 100644 --- a/src/keccak.rs +++ b/src/keccak.rs @@ -1,3 +1,6 @@ +use super::bit::Bit; +use std::slice::IterMut; + const KECCAKF_RNDC: [u64; 24] = [ 0x0000000000000001, 0x0000000000008082, 0x800000000000808a, @@ -35,7 +38,7 @@ fn keccakf(st: &mut [Byte], rounds: usize) State { bits: bytes.iter_mut() .rev() // Endianness - .flat_map(|b| b.bits.iter_mut()) + .flat_map(|b| b.iter_mut()) .collect() } } @@ -188,7 +191,7 @@ fn keccakf(st: &mut [Byte], rounds: usize) } } -fn sha3_256(message: &[Byte]) -> Vec { +pub fn sha3_256(message: &[Byte]) -> Vec { // As defined by FIPS202 keccak(1088, 512, message, 0x06, 32, 24) } @@ -249,82 +252,15 @@ fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8 output } -#[derive(Debug, PartialEq, Clone)] -enum Bit { - Constant(bool) -} - -#[derive(Clone, Debug, PartialEq)] -struct Byte { - bits: Vec -} - -impl Byte { - fn new(byte: u8) -> Byte { - Byte { - bits: (0..8).map(|i| Bit::constant(byte & (1 << i) != 0)) - .rev() - .collect() - } - } - - fn unwrap_constant(&self) -> u8 { - let mut cur = 7; - let mut acc = 0; - - for bit in &self.bits { - match bit { - &Bit::Constant(true) => { - acc |= 1 << cur; - }, - &Bit::Constant(false) => {}, - //_ => panic!("Tried to unwrap a constant from a non-constant") - } - cur -= 1; - } - - acc - } - - fn xor(&self, other: &Byte) -> Byte { - Byte { - bits: self.bits.iter() - .zip(other.bits.iter()) - .map(|(a, b)| a.xor(b)) - .collect() - } - } -} - -impl Bit { - fn constant(num: bool) -> Bit { - Bit::Constant(num) - } - - // self xor other - fn xor(&self, other: &Bit) -> Bit { - match (self, other) { - (&Bit::Constant(a), &Bit::Constant(b)) => { - Bit::constant(a != b) - }, - //_ => unimplemented!() - } - } - - // (not self) and other - fn notand(&self, other: &Bit) -> Bit { - match (self, other) { - (&Bit::Constant(a), &Bit::Constant(b)) => { - Bit::constant((!a) && b) - }, - //_ => unimplemented!() - } - } -} - #[test] fn test_sha3_256() { let test_vector: Vec<(Vec, [u8; 32])> = vec![ + (vec![0xff], + [0x44,0x4b,0x89,0xec,0xce,0x39,0x5a,0xec,0x5d,0xc9,0x8f,0x19,0xde,0xfd,0x3a,0x23,0xbc,0xa0,0x82,0x2f,0xc7,0x22,0x26,0xf5,0x8c,0xa4,0x6a,0x17,0xee,0xec,0xa4,0x42] + ), + (vec![0x00], + [0x5d,0x53,0x46,0x9f,0x20,0xfe,0xf4,0xf8,0xea,0xb5,0x2b,0x88,0x04,0x4e,0xde,0x69,0xc7,0x7a,0x6a,0x68,0xa6,0x07,0x28,0x60,0x9f,0xc4,0xa6,0x5f,0xf5,0x31,0xe7,0xd0] + ), (vec![0x30, 0x31, 0x30, 0x31], [0xe5,0xbf,0x4a,0xd7,0xda,0x2b,0x4d,0x64,0x0d,0x2b,0x8d,0xd3,0xae,0x9b,0x6e,0x71,0xb3,0x6e,0x0f,0x3d,0xb7,0x6a,0x1e,0xc0,0xad,0x6b,0x87,0x2f,0x3e,0xcc,0x2e,0xbc] ), @@ -368,3 +304,63 @@ fn test_sha3_256() { } } } + +#[derive(Clone)] +pub struct Byte { + bits: Vec +} + +impl From> for Byte { + fn from(a: Vec) -> Byte { + assert_eq!(8, a.len()); + + Byte { + bits: a + } + } +} + +impl Byte { + pub fn bits(&self) -> Vec { + self.bits.clone() + } + + pub fn new(byte: u8) -> Byte { + Byte { + bits: (0..8).map(|i| Bit::constant(byte & (1 << i) != 0)) + .rev() + .collect() + } + } + + pub fn iter_mut(&mut self) -> IterMut { + self.bits.iter_mut() + } + + pub fn unwrap_constant(&self) -> u8 { + let mut cur = 7; + let mut acc = 0; + + for bit in &self.bits { + match bit { + &Bit::Constant(true) => { + acc |= 1 << cur; + }, + &Bit::Constant(false) => {}, + _ => panic!("Tried to unwrap a constant from a non-constant") + } + cur -= 1; + } + + acc + } + + pub fn xor(&self, other: &Byte) -> Byte { + Byte { + bits: self.bits.iter() + .zip(other.bits.iter()) + .map(|(a, b)| a.xor(b)) + .collect() + } + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 8c0a793..f29fb8f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,27 +1,46 @@ +#![feature(iter_arith, btree_range, collections_bound)] + extern crate tinysnark; extern crate rand; use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem}; +use variable::*; +use keccak::*; +use bit::*; +mod variable; mod keccak; +mod bit; fn main() { tinysnark::init(); - let mut cs = ConstraintSystem::new(2, 1); - // xor - // (2*b) * c = b+c - a - cs.add_constraint( - &[LinearTerm{coeff: FieldT::from(2), index: 2}], - &[LinearTerm{coeff: FieldT::one(), index: 3}], - &[LinearTerm{coeff: FieldT::one(), index: 2}, - LinearTerm{coeff: FieldT::one(), index: 3}, - LinearTerm{coeff: -FieldT::one(), index: 1}] - ); - let prompt = [0.into(), 1.into()]; - let solution = [1.into()]; - assert!(cs.test(&prompt, &solution)); - let kp = Keypair::new(&cs); - let proof = Proof::new(&kp, &prompt, &solution); - assert!(proof.verify(&kp, &prompt)); + let inbytes = 64; + //for inbits in 0..1024 { + let inbits = inbytes * 8; + let input: Vec = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect(); + let input: Vec = input.chunks(8).map(|c| Byte::from(c.to_owned())).collect(); + + let output = sha3_256(&input); + + let mut counter = 1 + (8*input.len()); + let mut constraints = vec![]; + let mut witness_map = WitnessMap::new(); + + for o in output.iter().flat_map(|e| e.bits().into_iter()) { + o.walk(&mut counter, &mut constraints, &mut witness_map); + } + + let mut vars: Vec = (0..counter).map(|_| FieldT::zero()).collect(); + vars[0] = FieldT::one(); + + satisfy_field_elements(&mut vars, &witness_map); + + for b in output.iter().flat_map(|e| e.bits()) { + print!("{}", if b.val(&vars) { 1 } else { 0 }); + } + println!(""); + + println!("{}: {} constraints", inbits, constraints.len()); + //} } diff --git a/src/variable.rs b/src/variable.rs new file mode 100644 index 0000000..52ec332 --- /dev/null +++ b/src/variable.rs @@ -0,0 +1,142 @@ +use tinysnark::FieldT; +use std::cell::Cell; +use std::rc::Rc; +use std::fmt; +use std::collections::BTreeMap; + +pub type WitnessMap = BTreeMap, Vec, Rc)>>; + +use std::collections::Bound::Unbounded; + +pub fn satisfy_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) { + for (n, group) in witness_map.range(Unbounded, Unbounded) { + for &(ref i, ref o, ref f) in group.iter() { + let i: Vec<&FieldT> = i.iter().map(|i| &vars[*i]).collect(); + let o: Vec<&FieldT> = o.iter().map(|o| &vars[*o]).collect(); + + let mut o: Vec<&mut FieldT> = unsafe { + use std::mem::transmute; + + transmute(o) + }; + + f(&i, &mut o); + } + } +} + +#[derive(Clone)] +pub struct Constraint; + +struct Gadget { + inputs: Vec, + aux: Vec, + witness: Rc, + constraints: Vec, + group: usize, + visited: Cell +} + +impl Gadget { + pub fn walk(&self, counter: &mut usize, constraints: &mut Vec, witness_map: &mut WitnessMap) { + if self.visited.get() { + return; + } + + self.visited.set(true); + + for a in &self.aux { + assert!(a.index.get() == 0); + a.index.set(*counter); + *counter += 1; + } + + constraints.extend_from_slice(&self.constraints); + + for i in &self.inputs { + i.walk(counter, constraints, witness_map); + } + + let input_indexes = self.inputs.iter().map(|i| i.index.get()).collect(); + let output_indexes = self.aux.iter().map(|i| i.index.get()).collect(); + + witness_map.entry(self.group) + .or_insert_with(|| Vec::new()) + .push((input_indexes, output_indexes, self.witness.clone())); + } +} + +#[derive(Clone)] +pub struct Var { + index: Rc>, + gadget: Option> +} + +impl Var { + // todo: make this not public + pub fn new(i: usize) -> Var { + Var { + index: Rc::new(Cell::new(i)), + gadget: None + } + } + + pub fn val(&self, map: &[FieldT]) -> FieldT { + let index = self.index.get(); + assert!(index != 0); + map[index] + } + + fn group(&self) -> usize { + match self.gadget { + None => 0, + Some(ref g) => g.group + } + } + + pub fn walk(&self, counter: &mut usize, constraints: &mut Vec, witness_map: &mut WitnessMap) { + match self.gadget { + None => {}, + Some(ref g) => g.walk(counter, constraints, witness_map) + } + } +} + +pub fn gadget( + inputs: &[&Var], + aux: usize, + witness: W, + constrain: C +) -> Vec + where C: for<'a> Fn(&[&'a Var], &[&'a Var], &mut Vec) -> Vec<&'a Var>, + W: Fn(&[&FieldT], &mut [&mut FieldT]) + 'static +{ + let this_group = inputs.iter().map(|i| i.group()).max().map(|a| a+1).unwrap_or(0); + + let aux: Vec<_> = (0..aux).map(|_| Var::new(0)).collect(); + let aux: Vec<_> = aux.iter().collect(); + + let mut constraints = vec![]; + + let outputs = constrain(inputs, &*aux, &mut constraints); + + let gadget = Rc::new(Gadget { + inputs: inputs.iter().map(|a| (*a).clone()).collect(), + aux: aux.iter().map(|a| (*a).clone()).collect(), + witness: Rc::new(witness), + constraints: constraints, + group: this_group, + visited: Cell::new(false) + }); + + outputs.into_iter().map(|a| { + let mut a = (*a).clone(); + + // TODO: we should augment the gadget instead + // of replacing it + debug_assert!(a.gadget.is_none()); + + a.gadget = Some(gadget.clone()); + a + }).collect() +}