From b82a2f60f71e1914c3c566b03a13c91c9f1229f7 Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Sun, 3 Jan 2016 03:45:20 -0700 Subject: [PATCH] Reorganize and remove (temporary) unsafe witnessing --- src/bit.rs | 44 ++++++++++++++++++++++---------------------- src/main.rs | 4 ++-- src/variable.rs | 39 +++++++++++++++++++++++++++------------ 3 files changed, 51 insertions(+), 36 deletions(-) diff --git a/src/bit.rs b/src/bit.rs index eb2b262..eaebde9 100644 --- a/src/bit.rs +++ b/src/bit.rs @@ -11,11 +11,11 @@ pub enum Bit { } fn resolve_not(v: &Var) -> Var { - gadget(&[v], 1, |i, o| { - if *i[0] == FieldT::zero() { - *o[0] = FieldT::one(); + gadget(&[v], 1, |vars| { + if vars.get_input(0) == FieldT::zero() { + vars.set_output(0, FieldT::one()); } else { - *o[0] = FieldT::zero(); + vars.set_output(0, FieldT::zero()); } }, |i, o, cs| { // (1 - a) * 1 = b @@ -47,7 +47,7 @@ impl Bit { } pub fn new(v: &Var) -> Bit { - Is(gadget(&[v], 0, |_, _| {}, |i, o, cs| { + Is(gadget(&[v], 0, |_| {}, |i, o, cs| { cs.push(Constraint); vec![i[0]] @@ -74,11 +74,11 @@ impl Bit { } }, (&Is(ref a), &Is(ref b)) => { - Is(gadget(&[a, b], 1, |i, o| { - if *i[0] != *i[1] { - *o[0] = FieldT::one(); + Is(gadget(&[a, b], 1, |vars| { + if vars.get_input(0) != vars.get_input(1) { + vars.set_output(0, FieldT::one()); } else { - *o[0] = FieldT::zero(); + vars.set_output(0, FieldT::zero()); } }, |i, o, cs| { // (2*b) * c = b+c - a @@ -119,11 +119,11 @@ impl Bit { } }, (&Is(ref a), &Is(ref b)) => { - Is(gadget(&[a, b], 1, |i, o| { - if *i[0] == FieldT::one() && *i[1] == FieldT::one() { - *o[0] = FieldT::one(); + Is(gadget(&[a, b], 1, |vars| { + if vars.get_input(0) == FieldT::one() && vars.get_input(1) == FieldT::one() { + vars.set_output(0, FieldT::one()); } else { - *o[0] = FieldT::zero(); + vars.set_output(0, FieldT::zero()); } }, |i, o, cs| { // a * b = c @@ -142,11 +142,11 @@ impl Bit { }, (&Not(ref n), &Is(ref i)) | (&Is(ref i), &Not(ref n)) => { //Is(i.clone()).and(&Is(resolve_not(n))) - Is(gadget(&[n, i], 1, |i, o| { - if *i[0] == FieldT::zero() && *i[1] == FieldT::one() { - *o[0] = FieldT::one(); + Is(gadget(&[n, i], 1, |vars| { + if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::one() { + vars.set_output(0, FieldT::one()); } else { - *o[0] = FieldT::zero(); + vars.set_output(0, FieldT::zero()); } }, |i, o, cs| { // (1-a) * b = c @@ -157,11 +157,11 @@ impl Bit { }, (&Not(ref a), &Not(ref b)) => { //Is(resolve_not(a)).and(&Is(resolve_not(b))) - Is(gadget(&[a, b], 1, |i, o| { - if *i[0] == FieldT::zero() && *i[1] == FieldT::zero() { - *o[0] = FieldT::one(); + Is(gadget(&[a, b], 1, |vars| { + if vars.get_input(0) == FieldT::zero() && vars.get_input(1) == FieldT::zero() { + vars.set_output(0, FieldT::one()); } else { - *o[0] = FieldT::zero(); + vars.set_output(0, FieldT::zero()); } }, |i, o, cs| { // (1 - a) * (1 - b) = c @@ -203,7 +203,7 @@ fn test_binary_op Bit>(op: F, a_in: i64, b_in: i64, c_out: f[1] = FieldT::from(a_in); f[2] = FieldT::from(b_in); - satisfy_field_elements(&mut f, &witness_map); + witness_field_elements(&mut f, &witness_map); assert_eq!(f[3], FieldT::from(c_out)); } diff --git a/src/main.rs b/src/main.rs index f29fb8f..d855640 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,7 +15,7 @@ mod bit; fn main() { tinysnark::init(); - let inbytes = 64; + let inbytes = 1; //for inbits in 0..1024 { let inbits = inbytes * 8; let input: Vec = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect(); @@ -34,7 +34,7 @@ fn main() { let mut vars: Vec = (0..counter).map(|_| FieldT::zero()).collect(); vars[0] = FieldT::one(); - satisfy_field_elements(&mut vars, &witness_map); + witness_field_elements(&mut vars, &witness_map); for b in output.iter().flat_map(|e| e.bits()) { print!("{}", if b.val(&vars) { 1 } else { 0 }); diff --git a/src/variable.rs b/src/variable.rs index 52ec332..63275c4 100644 --- a/src/variable.rs +++ b/src/variable.rs @@ -4,23 +4,38 @@ use std::rc::Rc; use std::fmt; use std::collections::BTreeMap; -pub type WitnessMap = BTreeMap, Vec, Rc)>>; +pub type WitnessMap = BTreeMap, Vec, Rc)>>; + +struct VariableView<'a> { + vars: &'a mut [FieldT], + inputs: &'a [usize], + outputs: &'a [usize] +} + +impl<'a> VariableView<'a> { + /// Sets an output variable at `index` to value `to`. + pub fn set_output(&mut self, index: usize, to: FieldT) { + self.vars[self.outputs[index]] = to; + } + + /// Gets the value of an input variable at `index`. + pub fn get_input(&self, index: usize) -> FieldT { + self.vars[self.inputs[index]] + } +} use std::collections::Bound::Unbounded; -pub fn satisfy_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) { +pub fn witness_field_elements(vars: &mut [FieldT], witness_map: &WitnessMap) { for (n, group) in witness_map.range(Unbounded, Unbounded) { for &(ref i, ref o, ref f) in group.iter() { - let i: Vec<&FieldT> = i.iter().map(|i| &vars[*i]).collect(); - let o: Vec<&FieldT> = o.iter().map(|o| &vars[*o]).collect(); - - let mut o: Vec<&mut FieldT> = unsafe { - use std::mem::transmute; - - transmute(o) + let mut vars = VariableView { + vars: vars, + inputs: &*i, + outputs: &*o }; - f(&i, &mut o); + f(&mut vars); } } } @@ -31,7 +46,7 @@ pub struct Constraint; struct Gadget { inputs: Vec, aux: Vec, - witness: Rc, + witness: Rc, constraints: Vec, group: usize, visited: Cell @@ -109,7 +124,7 @@ pub fn gadget( constrain: C ) -> Vec where C: for<'a> Fn(&[&'a Var], &[&'a Var], &mut Vec) -> Vec<&'a Var>, - W: Fn(&[&FieldT], &mut [&mut FieldT]) + 'static + W: Fn(&mut VariableView) + 'static { let this_group = inputs.iter().map(|i| i.group()).max().map(|a| a+1).unwrap_or(0);