diff --git a/Cargo.toml b/Cargo.toml index 4a2aefd..b32997d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ rev = "7a5b5fc99ae483a0043db7547fb79a6fa44b88a9" [dev-dependencies] hex-literal = "0.1" +rust-crypto = "0.2" [features] default = ["u128-support"] diff --git a/src/circuit/boolean.rs b/src/circuit/boolean.rs index 22eb9b0..08f407e 100644 --- a/src/circuit/boolean.rs +++ b/src/circuit/boolean.rs @@ -523,6 +523,278 @@ impl Boolean { } } } + + /// Computes (a and b) xor ((not a) and c) + pub fn sha256_ch<'a, E, CS>( + mut cs: CS, + a: &'a Self, + b: &'a Self, + c: &'a Self + ) -> Result + where E: Engine, + CS: ConstraintSystem + { + let ch_value = match (a.get_value(), b.get_value(), c.get_value()) { + (Some(a), Some(b), Some(c)) => { + // (a and b) xor ((not a) and c) + Some((a & b) ^ ((!a) & c)) + }, + _ => None + }; + + match (a, b, c) { + (&Boolean::Constant(_), + &Boolean::Constant(_), + &Boolean::Constant(_)) => { + // They're all constants, so we can just compute the value. + + return Ok(Boolean::Constant(ch_value.expect("they're all constants"))); + }, + (&Boolean::Constant(false), _, c) => { + // If a is false + // (a and b) xor ((not a) and c) + // equals + // (false) xor (c) + // equals + // c + return Ok(c.clone()); + }, + (a, &Boolean::Constant(false), c) => { + // If b is false + // (a and b) xor ((not a) and c) + // equals + // ((not a) and c) + return Boolean::and( + cs, + &a.not(), + &c + ); + }, + (a, b, &Boolean::Constant(false)) => { + // If c is false + // (a and b) xor ((not a) and c) + // equals + // (a and b) + return Boolean::and( + cs, + &a, + &b + ); + }, + (a, b, &Boolean::Constant(true)) => { + // If c is true + // (a and b) xor ((not a) and c) + // equals + // (a and b) xor (not a) + // equals + // not (a and (not b)) + return Ok(Boolean::and( + cs, + &a, + &b.not() + )?.not()); + }, + (a, &Boolean::Constant(true), c) => { + // If b is true + // (a and b) xor ((not a) and c) + // equals + // a xor ((not a) and c) + // equals + // not ((not a) and (not c)) + return Ok(Boolean::and( + cs, + &a.not(), + &c.not() + )?.not()); + }, + (&Boolean::Constant(true), _, _) => { + // If a is true + // (a and b) xor ((not a) and c) + // equals + // b xor ((not a) and c) + // So we just continue! + }, + (&Boolean::Is(_), &Boolean::Is(_), &Boolean::Is(_)) | + (&Boolean::Is(_), &Boolean::Is(_), &Boolean::Not(_)) | + (&Boolean::Is(_), &Boolean::Not(_), &Boolean::Is(_)) | + (&Boolean::Is(_), &Boolean::Not(_), &Boolean::Not(_)) | + (&Boolean::Not(_), &Boolean::Is(_), &Boolean::Is(_)) | + (&Boolean::Not(_), &Boolean::Is(_), &Boolean::Not(_)) | + (&Boolean::Not(_), &Boolean::Not(_), &Boolean::Is(_)) | + (&Boolean::Not(_), &Boolean::Not(_), &Boolean::Not(_)) + => {} + } + + let ch = cs.alloc(|| "ch", || { + ch_value.get().map(|v| { + if *v { + E::Fr::one() + } else { + E::Fr::zero() + } + }) + })?; + + // a(b - c) = ch - c + cs.enforce( + || "ch computation", + |_| b.lc(CS::one(), E::Fr::one()) + - &c.lc(CS::one(), E::Fr::one()), + |_| a.lc(CS::one(), E::Fr::one()), + |lc| lc + ch - &c.lc(CS::one(), E::Fr::one()) + ); + + Ok(AllocatedBit { + value: ch_value, + variable: ch + }.into()) + } + + /// Computes (a and b) xor (a and c) xor (b and c) + pub fn sha256_maj<'a, E, CS>( + mut cs: CS, + a: &'a Self, + b: &'a Self, + c: &'a Self, + ) -> Result + where E: Engine, + CS: ConstraintSystem + { + let maj_value = match (a.get_value(), b.get_value(), c.get_value()) { + (Some(a), Some(b), Some(c)) => { + // (a and b) xor (a and c) xor (b and c) + Some((a & b) ^ (a & c) ^ (b & c)) + }, + _ => None + }; + + match (a, b, c) { + (&Boolean::Constant(_), + &Boolean::Constant(_), + &Boolean::Constant(_)) => { + // They're all constants, so we can just compute the value. + + return Ok(Boolean::Constant(maj_value.expect("they're all constants"))); + }, + (&Boolean::Constant(false), b, c) => { + // If a is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (b and c) + return Boolean::and( + cs, + b, + c + ); + }, + (a, &Boolean::Constant(false), c) => { + // If b is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and c) + return Boolean::and( + cs, + a, + c + ); + }, + (a, b, &Boolean::Constant(false)) => { + // If c is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and b) + return Boolean::and( + cs, + a, + b + ); + }, + (a, b, &Boolean::Constant(true)) => { + // If c is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and b) xor (a) xor (b) + // equals + // not ((not a) and (not b)) + return Ok(Boolean::and( + cs, + &a.not(), + &b.not() + )?.not()); + }, + (a, &Boolean::Constant(true), c) => { + // If b is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a) xor (a and c) xor (c) + return Ok(Boolean::and( + cs, + &a.not(), + &c.not() + )?.not()); + }, + (&Boolean::Constant(true), b, c) => { + // If a is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (b) xor (c) xor (b and c) + return Ok(Boolean::and( + cs, + &b.not(), + &c.not() + )?.not()); + }, + (&Boolean::Is(_), &Boolean::Is(_), &Boolean::Is(_)) | + (&Boolean::Is(_), &Boolean::Is(_), &Boolean::Not(_)) | + (&Boolean::Is(_), &Boolean::Not(_), &Boolean::Is(_)) | + (&Boolean::Is(_), &Boolean::Not(_), &Boolean::Not(_)) | + (&Boolean::Not(_), &Boolean::Is(_), &Boolean::Is(_)) | + (&Boolean::Not(_), &Boolean::Is(_), &Boolean::Not(_)) | + (&Boolean::Not(_), &Boolean::Not(_), &Boolean::Is(_)) | + (&Boolean::Not(_), &Boolean::Not(_), &Boolean::Not(_)) + => {} + } + + let maj = cs.alloc(|| "maj", || { + maj_value.get().map(|v| { + if *v { + E::Fr::one() + } else { + E::Fr::zero() + } + }) + })?; + + // ¬(¬a ∧ ¬b) ∧ ¬(¬a ∧ ¬c) ∧ ¬(¬b ∧ ¬c) + // (1 - ((1 - a) * (1 - b))) * (1 - ((1 - a) * (1 - c))) * (1 - ((1 - b) * (1 - c))) + // (a + b - ab) * (a + c - ac) * (b + c - bc) + // -2abc + ab + ac + bc + // a (-2bc + b + c) + bc + // + // (b) * (c) = (bc) + // (2bc - b - c) * (a) = bc - maj + + let bc = Self::and( + cs.namespace(|| "b and c"), + b, + c + )?; + + cs.enforce( + || "maj computation", + |_| bc.lc(CS::one(), E::Fr::one()) + + &bc.lc(CS::one(), E::Fr::one()) + - &b.lc(CS::one(), E::Fr::one()) + - &c.lc(CS::one(), E::Fr::one()), + |_| a.lc(CS::one(), E::Fr::one()), + |_| bc.lc(CS::one(), E::Fr::one()) - maj + ); + + Ok(AllocatedBit { + value: maj_value, + variable: maj + }.into()) + } } impl From for Boolean { @@ -797,6 +1069,31 @@ mod test { NegatedAllocatedFalse } + impl OperandType { + fn is_constant(&self) -> bool { + match *self { + OperandType::True => true, + OperandType::False => true, + OperandType::AllocatedTrue => false, + OperandType::AllocatedFalse => false, + OperandType::NegatedAllocatedTrue => false, + OperandType::NegatedAllocatedFalse => false + } + } + + fn val(&self) -> bool { + match *self { + OperandType::True => true, + OperandType::False => false, + OperandType::AllocatedTrue => true, + OperandType::AllocatedFalse => false, + OperandType::NegatedAllocatedTrue => false, + OperandType::NegatedAllocatedFalse => true + } + } + } + + #[test] fn test_boolean_xor() { let variants = [ @@ -1115,4 +1412,171 @@ mod test { assert_eq!(bits[254 - 20].value.unwrap(), true); assert_eq!(bits[254 - 23].value.unwrap(), true); } + + #[test] + fn test_boolean_sha256_ch() { + let variants = [ + OperandType::True, + OperandType::False, + OperandType::AllocatedTrue, + OperandType::AllocatedFalse, + OperandType::NegatedAllocatedTrue, + OperandType::NegatedAllocatedFalse + ]; + + for first_operand in variants.iter().cloned() { + for second_operand in variants.iter().cloned() { + for third_operand in variants.iter().cloned() { + let mut cs = TestConstraintSystem::::new(); + + let a; + let b; + let c; + + // ch = (a and b) xor ((not a) and c) + let expected = (first_operand.val() & second_operand.val()) ^ + ((!first_operand.val()) & third_operand.val()); + + { + let mut dyn_construct = |operand, name| { + let cs = cs.namespace(|| name); + + match operand { + OperandType::True => Boolean::constant(true), + OperandType::False => Boolean::constant(false), + OperandType::AllocatedTrue => Boolean::from(AllocatedBit::alloc(cs, Some(true)).unwrap()), + OperandType::AllocatedFalse => Boolean::from(AllocatedBit::alloc(cs, Some(false)).unwrap()), + OperandType::NegatedAllocatedTrue => Boolean::from(AllocatedBit::alloc(cs, Some(true)).unwrap()).not(), + OperandType::NegatedAllocatedFalse => Boolean::from(AllocatedBit::alloc(cs, Some(false)).unwrap()).not(), + } + }; + + a = dyn_construct(first_operand, "a"); + b = dyn_construct(second_operand, "b"); + c = dyn_construct(third_operand, "c"); + } + + let maj = Boolean::sha256_ch(&mut cs, &a, &b, &c).unwrap(); + + assert!(cs.is_satisfied()); + + assert_eq!(maj.get_value().unwrap(), expected); + + if first_operand.is_constant() || + second_operand.is_constant() || + third_operand.is_constant() + { + if first_operand.is_constant() && + second_operand.is_constant() && + third_operand.is_constant() + { + assert_eq!(cs.num_constraints(), 0); + } + } + else + { + assert_eq!(cs.get("ch"), { + if expected { + Fr::one() + } else { + Fr::zero() + } + }); + cs.set("ch", { + if expected { + Fr::zero() + } else { + Fr::one() + } + }); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "ch computation"); + } + } + } + } + } + + #[test] + fn test_boolean_sha256_maj() { + let variants = [ + OperandType::True, + OperandType::False, + OperandType::AllocatedTrue, + OperandType::AllocatedFalse, + OperandType::NegatedAllocatedTrue, + OperandType::NegatedAllocatedFalse + ]; + + for first_operand in variants.iter().cloned() { + for second_operand in variants.iter().cloned() { + for third_operand in variants.iter().cloned() { + let mut cs = TestConstraintSystem::::new(); + + let a; + let b; + let c; + + // maj = (a and b) xor (a and c) xor (b and c) + let expected = (first_operand.val() & second_operand.val()) ^ + (first_operand.val() & third_operand.val()) ^ + (second_operand.val() & third_operand.val()); + + { + let mut dyn_construct = |operand, name| { + let cs = cs.namespace(|| name); + + match operand { + OperandType::True => Boolean::constant(true), + OperandType::False => Boolean::constant(false), + OperandType::AllocatedTrue => Boolean::from(AllocatedBit::alloc(cs, Some(true)).unwrap()), + OperandType::AllocatedFalse => Boolean::from(AllocatedBit::alloc(cs, Some(false)).unwrap()), + OperandType::NegatedAllocatedTrue => Boolean::from(AllocatedBit::alloc(cs, Some(true)).unwrap()).not(), + OperandType::NegatedAllocatedFalse => Boolean::from(AllocatedBit::alloc(cs, Some(false)).unwrap()).not(), + } + }; + + a = dyn_construct(first_operand, "a"); + b = dyn_construct(second_operand, "b"); + c = dyn_construct(third_operand, "c"); + } + + let maj = Boolean::sha256_maj(&mut cs, &a, &b, &c).unwrap(); + + assert!(cs.is_satisfied()); + + assert_eq!(maj.get_value().unwrap(), expected); + + if first_operand.is_constant() || + second_operand.is_constant() || + third_operand.is_constant() + { + if first_operand.is_constant() && + second_operand.is_constant() && + third_operand.is_constant() + { + assert_eq!(cs.num_constraints(), 0); + } + } + else + { + assert_eq!(cs.get("maj"), { + if expected { + Fr::one() + } else { + Fr::zero() + } + }); + cs.set("maj", { + if expected { + Fr::zero() + } else { + Fr::one() + } + }); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "maj computation"); + } + } + } + } + } } diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs index 4cde222..fe0fe50 100644 --- a/src/circuit/mod.rs +++ b/src/circuit/mod.rs @@ -10,8 +10,10 @@ pub mod lookup; pub mod ecc; pub mod pedersen_hash; pub mod multipack; +pub mod sha256; pub mod sapling; +pub mod sprout; use bellman::{ SynthesisError diff --git a/src/circuit/sha256.rs b/src/circuit/sha256.rs new file mode 100644 index 0000000..7b55fc8 --- /dev/null +++ b/src/circuit/sha256.rs @@ -0,0 +1,417 @@ +use super::uint32::UInt32; +use super::multieq::MultiEq; +use super::boolean::Boolean; +use bellman::{ConstraintSystem, SynthesisError}; +use pairing::Engine; + +const ROUND_CONSTANTS: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 +]; + +const IV: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19 +]; + +pub fn sha256_block_no_padding( + mut cs: CS, + input: &[Boolean] +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + assert_eq!(input.len(), 512); + + Ok(sha256_compression_function( + &mut cs, + &input, + &get_sha256_iv() + )? + .into_iter() + .flat_map(|e| e.into_bits_be()) + .collect()) +} + +pub fn sha256( + mut cs: CS, + input: &[Boolean] +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + assert!(input.len() % 8 == 0); + + let mut padded = input.to_vec(); + let plen = padded.len() as u64; + // append a single '1' bit + padded.push(Boolean::constant(true)); + // append K '0' bits, where K is the minimum number >= 0 such that L + 1 + K + 64 is a multiple of 512 + while (padded.len() + 64) % 512 != 0 { + padded.push(Boolean::constant(false)); + } + // append L as a 64-bit big-endian integer, making the total post-processed length a multiple of 512 bits + for b in (0..64).rev().map(|i| (plen >> i) & 1 == 1) { + padded.push(Boolean::constant(b)); + } + assert!(padded.len() % 512 == 0); + + let mut cur = get_sha256_iv(); + for (i, block) in padded.chunks(512).enumerate() { + cur = sha256_compression_function( + cs.namespace(|| format!("block {}", i)), + block, + &cur + )?; + } + + Ok(cur.into_iter() + .flat_map(|e| e.into_bits_be()) + .collect()) +} + +fn get_sha256_iv() -> Vec { + IV.iter().map(|&v| UInt32::constant(v)).collect() +} + +fn sha256_compression_function( + cs: CS, + input: &[Boolean], + current_hash_value: &[UInt32] +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + assert_eq!(input.len(), 512); + assert_eq!(current_hash_value.len(), 8); + + let mut w = input.chunks(32) + .map(|e| UInt32::from_bits_be(e)) + .collect::>(); + + // We can save some constraints by combining some of + // the constraints in different u32 additions + let mut cs = MultiEq::new(cs); + + for i in 16..64 { + let cs = &mut cs.namespace(|| format!("w extension {}", i)); + + // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) + let mut s0 = w[i-15].rotr(7); + s0 = s0.xor( + cs.namespace(|| "first xor for s0"), + &w[i-15].rotr(18) + )?; + s0 = s0.xor( + cs.namespace(|| "second xor for s0"), + &w[i-15].shr(3) + )?; + + // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) + let mut s1 = w[i-2].rotr(17); + s1 = s1.xor( + cs.namespace(|| "first xor for s1"), + &w[i-2].rotr(19) + )?; + s1 = s1.xor( + cs.namespace(|| "second xor for s1"), + &w[i-2].shr(10) + )?; + + let tmp = UInt32::addmany( + cs.namespace(|| "computation of w[i]"), + &[w[i-16].clone(), s0, w[i-7].clone(), s1] + )?; + + // w[i] := w[i-16] + s0 + w[i-7] + s1 + w.push(tmp); + } + + assert_eq!(w.len(), 64); + + enum Maybe { + Deferred(Vec), + Concrete(UInt32) + } + + impl Maybe { + fn compute( + self, + cs: M, + others: &[UInt32] + ) -> Result + where E: Engine, + CS: ConstraintSystem, + M: ConstraintSystem> + { + Ok(match self { + Maybe::Concrete(ref v) => { + return Ok(v.clone()) + }, + Maybe::Deferred(mut v) => { + v.extend(others.into_iter().cloned()); + UInt32::addmany( + cs, + &v + )? + } + }) + } + } + + let mut a = Maybe::Concrete(current_hash_value[0].clone()); + let mut b = current_hash_value[1].clone(); + let mut c = current_hash_value[2].clone(); + let mut d = current_hash_value[3].clone(); + let mut e = Maybe::Concrete(current_hash_value[4].clone()); + let mut f = current_hash_value[5].clone(); + let mut g = current_hash_value[6].clone(); + let mut h = current_hash_value[7].clone(); + + for i in 0..64 { + let cs = &mut cs.namespace(|| format!("compression round {}", i)); + + // S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25) + let new_e = e.compute(cs.namespace(|| "deferred e computation"), &[])?; + let mut s1 = new_e.rotr(6); + s1 = s1.xor( + cs.namespace(|| "first xor for s1"), + &new_e.rotr(11) + )?; + s1 = s1.xor( + cs.namespace(|| "second xor for s1"), + &new_e.rotr(25) + )?; + + // ch := (e and f) xor ((not e) and g) + let ch = UInt32::sha256_ch( + cs.namespace(|| "ch"), + &new_e, + &f, + &g + )?; + + // temp1 := h + S1 + ch + k[i] + w[i] + let temp1 = vec![ + h.clone(), + s1, + ch, + UInt32::constant(ROUND_CONSTANTS[i]), + w[i].clone() + ]; + + // S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22) + let new_a = a.compute(cs.namespace(|| "deferred a computation"), &[])?; + let mut s0 = new_a.rotr(2); + s0 = s0.xor( + cs.namespace(|| "first xor for s0"), + &new_a.rotr(13) + )?; + s0 = s0.xor( + cs.namespace(|| "second xor for s0"), + &new_a.rotr(22) + )?; + + // maj := (a and b) xor (a and c) xor (b and c) + let maj = UInt32::sha256_maj( + cs.namespace(|| "maj"), + &new_a, + &b, + &c + )?; + + // temp2 := S0 + maj + let temp2 = vec![s0, maj]; + + /* + h := g + g := f + f := e + e := d + temp1 + d := c + c := b + b := a + a := temp1 + temp2 + */ + + h = g; + g = f; + f = new_e; + e = Maybe::Deferred(temp1.iter().cloned().chain(Some(d)).collect::>()); + d = c; + c = b; + b = new_a; + a = Maybe::Deferred(temp1.iter().cloned().chain(temp2.iter().cloned()).collect::>()); + } + + /* + Add the compressed chunk to the current hash value: + h0 := h0 + a + h1 := h1 + b + h2 := h2 + c + h3 := h3 + d + h4 := h4 + e + h5 := h5 + f + h6 := h6 + g + h7 := h7 + h + */ + + let h0 = a.compute( + cs.namespace(|| "deferred h0 computation"), + &[current_hash_value[0].clone()] + )?; + + let h1 = UInt32::addmany( + cs.namespace(|| "new h1"), + &[current_hash_value[1].clone(), b] + )?; + + let h2 = UInt32::addmany( + cs.namespace(|| "new h2"), + &[current_hash_value[2].clone(), c] + )?; + + let h3 = UInt32::addmany( + cs.namespace(|| "new h3"), + &[current_hash_value[3].clone(), d] + )?; + + let h4 = e.compute( + cs.namespace(|| "deferred h4 computation"), + &[current_hash_value[4].clone()] + )?; + + let h5 = UInt32::addmany( + cs.namespace(|| "new h5"), + &[current_hash_value[5].clone(), f] + )?; + + let h6 = UInt32::addmany( + cs.namespace(|| "new h6"), + &[current_hash_value[6].clone(), g] + )?; + + let h7 = UInt32::addmany( + cs.namespace(|| "new h7"), + &[current_hash_value[7].clone(), h] + )?; + + Ok(vec![h0, h1, h2, h3, h4, h5, h6, h7]) +} + +#[cfg(test)] +mod test { + use super::*; + use circuit::boolean::AllocatedBit; + use pairing::bls12_381::Bls12; + use circuit::test::TestConstraintSystem; + use rand::{XorShiftRng, SeedableRng, Rng}; + + #[test] + fn test_blank_hash() { + let iv = get_sha256_iv(); + + let mut cs = TestConstraintSystem::::new(); + let mut input_bits: Vec<_> = (0..512).map(|_| Boolean::Constant(false)).collect(); + input_bits[0] = Boolean::Constant(true); + let out = sha256_compression_function( + &mut cs, + &input_bits, + &iv + ).unwrap(); + let out_bits: Vec<_> = out.into_iter().flat_map(|e| e.into_bits_be()).collect(); + + assert!(cs.is_satisfied()); + assert_eq!(cs.num_constraints(), 0); + + let expected = hex!("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); + + let mut out = out_bits.into_iter(); + for b in expected.into_iter() { + for i in (0..8).rev() { + let c = out.next().unwrap().get_value().unwrap(); + + assert_eq!(c, (b >> i) & 1u8 == 1u8); + } + } + } + + #[test] + fn test_full_block() { + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + let iv = get_sha256_iv(); + + let mut cs = TestConstraintSystem::::new(); + let input_bits: Vec<_> = (0..512).map(|i| { + Boolean::from( + AllocatedBit::alloc( + cs.namespace(|| format!("input bit {}", i)), + Some(rng.gen()) + ).unwrap() + ) + }).collect(); + + sha256_compression_function( + cs.namespace(|| "sha256"), + &input_bits, + &iv + ).unwrap(); + + assert!(cs.is_satisfied()); + assert_eq!(cs.num_constraints() - 512, 25840); + } + + #[test] + fn test_against_vectors() { + use crypto::sha2::Sha256; + use crypto::digest::Digest; + + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for input_len in (0..32).chain((32..256).filter(|a| a % 8 == 0)) + { + let mut h = Sha256::new(); + let data: Vec = (0..input_len).map(|_| rng.gen()).collect(); + h.input(&data); + let mut hash_result = [0u8; 32]; + h.result(&mut hash_result[..]); + + let mut cs = TestConstraintSystem::::new(); + let mut input_bits = vec![]; + + for (byte_i, input_byte) in data.into_iter().enumerate() { + for bit_i in (0..8).rev() { + let cs = cs.namespace(|| format!("input bit {} {}", byte_i, bit_i)); + + input_bits.push(AllocatedBit::alloc(cs, Some((input_byte >> bit_i) & 1u8 == 1u8)).unwrap().into()); + } + } + + let r = sha256(&mut cs, &input_bits).unwrap(); + + assert!(cs.is_satisfied()); + + let mut s = hash_result.as_ref().iter() + .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1u8 == 1u8)); + + for b in r { + match b { + Boolean::Is(b) => { + assert!(s.next().unwrap() == b.get_value().unwrap()); + }, + Boolean::Not(b) => { + assert!(s.next().unwrap() != b.get_value().unwrap()); + }, + Boolean::Constant(b) => { + assert!(input_len == 0); + assert!(s.next().unwrap() == b); + } + } + } + } + } +} diff --git a/src/circuit/sprout/commitment.rs b/src/circuit/sprout/commitment.rs new file mode 100644 index 0000000..a32f05c --- /dev/null +++ b/src/circuit/sprout/commitment.rs @@ -0,0 +1,42 @@ +use pairing::{Engine}; +use bellman::{ConstraintSystem, SynthesisError}; +use circuit::sha256::{ + sha256 +}; +use circuit::boolean::{ + Boolean +}; + +pub fn note_comm( + cs: CS, + a_pk: &[Boolean], + value: &[Boolean], + rho: &[Boolean], + r: &[Boolean] +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + assert_eq!(a_pk.len(), 256); + assert_eq!(value.len(), 64); + assert_eq!(rho.len(), 256); + assert_eq!(r.len(), 256); + + let mut image = vec![]; + image.push(Boolean::constant(true)); + image.push(Boolean::constant(false)); + image.push(Boolean::constant(true)); + image.push(Boolean::constant(true)); + image.push(Boolean::constant(false)); + image.push(Boolean::constant(false)); + image.push(Boolean::constant(false)); + image.push(Boolean::constant(false)); + image.extend(a_pk.iter().cloned()); + image.extend(value.iter().cloned()); + image.extend(rho.iter().cloned()); + image.extend(r.iter().cloned()); + + sha256( + cs, + &image + ) +} diff --git a/src/circuit/sprout/input.rs b/src/circuit/sprout/input.rs new file mode 100644 index 0000000..ce69bc0 --- /dev/null +++ b/src/circuit/sprout/input.rs @@ -0,0 +1,226 @@ +use pairing::{Engine}; +use bellman::{ConstraintSystem, SynthesisError}; +use circuit::sha256::{ + sha256_block_no_padding +}; +use circuit::boolean::{ + AllocatedBit, + Boolean +}; + +use super::*; +use super::prfs::*; +use super::commitment::note_comm; + +pub struct InputNote { + pub nf: Vec, + pub mac: Vec, +} + +impl InputNote { + pub fn compute( + mut cs: CS, + a_sk: Option, + rho: Option, + r: Option, + value: &NoteValue, + h_sig: &[Boolean], + nonce: bool, + auth_path: [Option<([u8; 32], bool)>; TREE_DEPTH], + rt: &[Boolean] + ) -> Result + where E: Engine, CS: ConstraintSystem + { + let a_sk = witness_u252( + cs.namespace(|| "a_sk"), + a_sk.as_ref().map(|a_sk| &a_sk.0[..]) + )?; + + let rho = witness_u256( + cs.namespace(|| "rho"), + rho.as_ref().map(|rho| &rho.0[..]) + )?; + + let r = witness_u256( + cs.namespace(|| "r"), + r.as_ref().map(|r| &r.0[..]) + )?; + + let a_pk = prf_a_pk( + cs.namespace(|| "a_pk computation"), + &a_sk + )?; + + let nf = prf_nf( + cs.namespace(|| "nf computation"), + &a_sk, + &rho + )?; + + let mac = prf_pk( + cs.namespace(|| "mac computation"), + &a_sk, + h_sig, + nonce + )?; + + let cm = note_comm( + cs.namespace(|| "cm computation"), + &a_pk, + &value.bits_le(), + &rho, + &r + )?; + + // Witness into the merkle tree + let mut cur = cm.clone(); + + for (i, layer) in auth_path.into_iter().enumerate() { + let cs = &mut cs.namespace(|| format!("layer {}", i)); + + let cur_is_right = AllocatedBit::alloc( + cs.namespace(|| "cur is right"), + layer.as_ref().map(|&(_, p)| p) + )?; + + let lhs = cur; + let rhs = witness_u256( + cs.namespace(|| "sibling"), + layer.as_ref().map(|&(ref sibling, _)| &sibling[..]) + )?; + + // Conditionally swap if cur is right + let preimage = conditionally_swap_u256( + cs.namespace(|| "conditional swap"), + &lhs[..], + &rhs[..], + &cur_is_right + )?; + + cur = sha256_block_no_padding( + cs.namespace(|| "hash of this layer"), + &preimage + )?; + } + + // enforce must be true if the value is nonzero + let enforce = AllocatedBit::alloc( + cs.namespace(|| "enforce"), + value.get_value().map(|n| n != 0) + )?; + + // value * (1 - enforce) = 0 + // If `value` is zero, `enforce` _can_ be zero. + // If `value` is nonzero, `enforce` _must_ be one. + cs.enforce( + || "enforce validity", + |_| value.lc(), + |lc| lc + CS::one() - enforce.get_variable(), + |lc| lc + ); + + assert_eq!(cur.len(), rt.len()); + + // Check that the anchor (exposed as a public input) + // is equal to the merkle tree root that we calculated + // for this note + for (i, (cur, rt)) in cur.into_iter().zip(rt.iter()).enumerate() { + // (cur - rt) * enforce = 0 + // if enforce is zero, cur and rt can be different + // if enforce is one, they must be equal + cs.enforce( + || format!("conditionally enforce correct root for bit {}", i), + |_| cur.lc(CS::one(), E::Fr::one()) - &rt.lc(CS::one(), E::Fr::one()), + |lc| lc + enforce.get_variable(), + |lc| lc + ); + } + + Ok(InputNote { + mac: mac, + nf: nf + }) + } +} + +/// Swaps two 256-bit blobs conditionally, returning the +/// 512-bit concatenation. +pub fn conditionally_swap_u256( + mut cs: CS, + lhs: &[Boolean], + rhs: &[Boolean], + condition: &AllocatedBit +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem, +{ + assert_eq!(lhs.len(), 256); + assert_eq!(rhs.len(), 256); + + let mut new_lhs = vec![]; + let mut new_rhs = vec![]; + + for (i, (lhs, rhs)) in lhs.iter().zip(rhs.iter()).enumerate() { + let cs = &mut cs.namespace(|| format!("bit {}", i)); + + let x = Boolean::from(AllocatedBit::alloc( + cs.namespace(|| "x"), + condition.get_value().and_then(|v| { + if v { + rhs.get_value() + } else { + lhs.get_value() + } + }) + )?); + + // x = (1-condition)lhs + (condition)rhs + // x = lhs - lhs(condition) + rhs(condition) + // x - lhs = condition (rhs - lhs) + // if condition is zero, we don't swap, so + // x - lhs = 0 + // x = lhs + // if condition is one, we do swap, so + // x - lhs = rhs - lhs + // x = rhs + cs.enforce( + || "conditional swap for x", + |lc| lc + &rhs.lc(CS::one(), E::Fr::one()) + - &lhs.lc(CS::one(), E::Fr::one()), + |lc| lc + condition.get_variable(), + |lc| lc + &x.lc(CS::one(), E::Fr::one()) + - &lhs.lc(CS::one(), E::Fr::one()) + ); + + let y = Boolean::from(AllocatedBit::alloc( + cs.namespace(|| "y"), + condition.get_value().and_then(|v| { + if v { + lhs.get_value() + } else { + rhs.get_value() + } + }) + )?); + + // y = (1-condition)rhs + (condition)lhs + // y - rhs = condition (lhs - rhs) + cs.enforce( + || "conditional swap for y", + |lc| lc + &lhs.lc(CS::one(), E::Fr::one()) + - &rhs.lc(CS::one(), E::Fr::one()), + |lc| lc + condition.get_variable(), + |lc| lc + &y.lc(CS::one(), E::Fr::one()) + - &rhs.lc(CS::one(), E::Fr::one()) + ); + + new_lhs.push(x); + new_rhs.push(y); + } + + let mut f = new_lhs; + f.extend(new_rhs); + + assert_eq!(f.len(), 512); + + Ok(f) +} diff --git a/src/circuit/sprout/mod.rs b/src/circuit/sprout/mod.rs new file mode 100644 index 0000000..586de8c --- /dev/null +++ b/src/circuit/sprout/mod.rs @@ -0,0 +1,488 @@ +use pairing::{Engine, Field}; +use bellman::{ConstraintSystem, SynthesisError, Circuit, LinearCombination}; +use circuit::boolean::{ + AllocatedBit, + Boolean +}; +use circuit::multipack::pack_into_inputs; + +mod prfs; +mod commitment; +mod input; +mod output; + +use self::input::*; +use self::output::*; + +pub const TREE_DEPTH: usize = 29; + +pub struct SpendingKey(pub [u8; 32]); +pub struct PayingKey(pub [u8; 32]); +pub struct UniqueRandomness(pub [u8; 32]); +pub struct CommitmentRandomness(pub [u8; 32]); + +pub struct JoinSplit { + pub vpub_old: Option, + pub vpub_new: Option, + pub h_sig: Option<[u8; 32]>, + pub phi: Option<[u8; 32]>, + pub inputs: Vec, + pub outputs: Vec, + pub rt: Option<[u8; 32]>, +} + +pub struct JSInput { + pub value: Option, + pub a_sk: Option, + pub rho: Option, + pub r: Option, + pub auth_path: [Option<([u8; 32], bool)>; TREE_DEPTH] +} + +pub struct JSOutput { + pub value: Option, + pub a_pk: Option, + pub r: Option +} + +impl Circuit for JoinSplit { + fn synthesize>( + self, + cs: &mut CS + ) -> Result<(), SynthesisError> + { + assert_eq!(self.inputs.len(), 2); + assert_eq!(self.outputs.len(), 2); + + // vpub_old is the value entering the + // JoinSplit from the "outside" value + // pool + let vpub_old = NoteValue::new( + cs.namespace(|| "vpub_old"), + self.vpub_old + )?; + + // vpub_new is the value leaving the + // JoinSplit into the "outside" value + // pool + let vpub_new = NoteValue::new( + cs.namespace(|| "vpub_new"), + self.vpub_new + )?; + + // The left hand side of the balance equation + // vpub_old + inputs[0].value + inputs[1].value + let mut lhs = vpub_old.lc(); + + // The right hand side of the balance equation + // vpub_old + inputs[0].value + inputs[1].value + let mut rhs = vpub_new.lc(); + + // Witness rt (merkle tree root) + let rt = witness_u256( + cs.namespace(|| "rt"), + self.rt.as_ref().map(|v| &v[..]) + ).unwrap(); + + // Witness h_sig + let h_sig = witness_u256( + cs.namespace(|| "h_sig"), + self.h_sig.as_ref().map(|v| &v[..]) + ).unwrap(); + + // Witness phi + let phi = witness_u252( + cs.namespace(|| "phi"), + self.phi.as_ref().map(|v| &v[..]) + ).unwrap(); + + let mut input_notes = vec![]; + let mut lhs_total = self.vpub_old; + + // Iterate over the JoinSplit inputs + for (i, input) in self.inputs.into_iter().enumerate() { + let cs = &mut cs.namespace(|| format!("input {}", i)); + + // Accumulate the value of the left hand side + if let Some(value) = input.value { + lhs_total = lhs_total.map(|v| v.wrapping_add(value)); + } + + // Allocate the value of the note + let value = NoteValue::new( + cs.namespace(|| "value"), + input.value + )?; + + // Compute the nonce (for PRF inputs) which is false + // for the first input, and true for the second input. + let nonce = match i { + 0 => false, + 1 => true, + _ => unreachable!() + }; + + // Perform input note computations + input_notes.push(InputNote::compute( + cs.namespace(|| "note"), + input.a_sk, + input.rho, + input.r, + &value, + &h_sig, + nonce, + input.auth_path, + &rt + )?); + + // Add the note value to the left hand side of + // the balance equation + lhs = lhs + &value.lc(); + } + + // Rebind lhs so that it isn't mutable anymore + let lhs = lhs; + + // See zcash/zcash/issues/854 + { + // Expected sum of the left hand side of the balance + // equation, expressed as a 64-bit unsigned integer + let lhs_total = NoteValue::new( + cs.namespace(|| "total value of left hand side"), + lhs_total + )?; + + // Enforce that the left hand side can be expressed as a 64-bit + // integer + cs.enforce( + || "left hand side can be expressed as a 64-bit unsigned integer", + |_| lhs.clone(), + |lc| lc + CS::one(), + |_| lhs_total.lc() + ); + } + + let mut output_notes = vec![]; + + // Iterate over the JoinSplit outputs + for (i, output) in self.outputs.into_iter().enumerate() { + let cs = &mut cs.namespace(|| format!("output {}", i)); + + let value = NoteValue::new( + cs.namespace(|| "value"), + output.value + )?; + + // Compute the nonce (for PRF inputs) which is false + // for the first output, and true for the second output. + let nonce = match i { + 0 => false, + 1 => true, + _ => unreachable!() + }; + + // Perform output note computations + output_notes.push(OutputNote::compute( + cs.namespace(|| "note"), + output.a_pk, + &value, + output.r, + &phi, + &h_sig, + nonce + )?); + + // Add the note value to the right hand side of + // the balance equation + rhs = rhs + &value.lc(); + } + + // Enforce that balance is equal + cs.enforce( + || "balance equation", + |_| lhs.clone(), + |lc| lc + CS::one(), + |_| rhs + ); + + let mut public_inputs = vec![]; + public_inputs.extend(rt); + public_inputs.extend(h_sig); + + for note in input_notes { + public_inputs.extend(note.nf); + public_inputs.extend(note.mac); + } + + for note in output_notes { + public_inputs.extend(note.cm); + } + + public_inputs.extend(vpub_old.bits_le()); + public_inputs.extend(vpub_new.bits_le()); + + pack_into_inputs(cs.namespace(|| "input packing"), &public_inputs) + } +} + +pub struct NoteValue { + value: Option, + // Least significant digit first + bits: Vec +} + +impl NoteValue { + fn new( + mut cs: CS, + value: Option + ) -> Result + where E: Engine, CS: ConstraintSystem, + { + let mut values; + match value { + Some(mut val) => { + values = vec![]; + for _ in 0..64 { + values.push(Some(val & 1 == 1)); + val >>= 1; + } + }, + None => { + values = vec![None; 64]; + } + } + + let mut bits = vec![]; + for (i, value) in values.into_iter().enumerate() { + bits.push( + AllocatedBit::alloc( + cs.namespace(|| format!("bit {}", i)), + value + )? + ); + } + + Ok(NoteValue { + value: value, + bits: bits + }) + } + + /// Encodes the bits of the value into little-endian + /// byte order. + fn bits_le(&self) -> Vec { + self.bits.chunks(8) + .flat_map(|v| v.iter().rev()) + .cloned() + .map(|e| Boolean::from(e)) + .collect() + } + + /// Computes this value as a linear combination of + /// its bits. + fn lc(&self) -> LinearCombination { + let mut tmp = LinearCombination::zero(); + + let mut coeff = E::Fr::one(); + for b in &self.bits { + tmp = tmp + (coeff, b.get_variable()); + coeff.double(); + } + + tmp + } + + fn get_value(&self) -> Option { + self.value + } +} + +/// Witnesses some bytes in the constraint system, +/// skipping the first `skip_bits`. +fn witness_bits( + mut cs: CS, + value: Option<&[u8]>, + num_bits: usize, + skip_bits: usize +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem, +{ + let bit_values = if let Some(value) = value { + let mut tmp = vec![]; + for b in value.iter() + .flat_map(|&m| (0..8).rev().map(move |i| m >> i & 1 == 1)) + .skip(skip_bits) + { + tmp.push(Some(b)); + } + tmp + } else { + vec![None; num_bits] + }; + assert_eq!(bit_values.len(), num_bits); + + let mut bits = vec![]; + + for (i, value) in bit_values.into_iter().enumerate() { + bits.push(Boolean::from(AllocatedBit::alloc( + cs.namespace(|| format!("bit {}", i)), + value + )?)); + } + + Ok(bits) +} + +fn witness_u256( + cs: CS, + value: Option<&[u8]>, +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem, +{ + witness_bits(cs, value, 256, 0) +} + +fn witness_u252( + cs: CS, + value: Option<&[u8]>, +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem, +{ + witness_bits(cs, value, 252, 4) +} + +#[test] +fn test_sprout_constraints() { + use pairing::bls12_381::{Bls12}; + use ::circuit::test::*; + + use byteorder::{WriteBytesExt, ReadBytesExt, LittleEndian}; + + let test_vector = include_bytes!("test_vectors.dat"); + let mut test_vector = &test_vector[..]; + + fn get_u256(mut reader: R) -> [u8; 32] { + let mut result = [0u8; 32]; + + for i in 0..32 { + result[i] = reader.read_u8().unwrap(); + } + + result + } + + while test_vector.len() != 0 { + let mut cs = TestConstraintSystem::::new(); + + let phi = Some(get_u256(&mut test_vector)); + let rt = Some(get_u256(&mut test_vector)); + let h_sig = Some(get_u256(&mut test_vector)); + + let mut inputs = vec![]; + for _ in 0..2 { + test_vector.read_u8().unwrap(); + + let mut auth_path = [None; TREE_DEPTH]; + for i in (0..TREE_DEPTH).rev() { + test_vector.read_u8().unwrap(); + + let sibling = get_u256(&mut test_vector); + + auth_path[i] = Some((sibling, false)); + } + let mut position = test_vector.read_u64::().unwrap(); + for i in 0..TREE_DEPTH { + auth_path[i].as_mut().map(|p| { + p.1 = (position & 1) == 1 + }); + + position >>= 1; + } + + // a_pk + let _ = Some(SpendingKey(get_u256(&mut test_vector))); + let value = Some(test_vector.read_u64::().unwrap()); + let rho = Some(UniqueRandomness(get_u256(&mut test_vector))); + let r = Some(CommitmentRandomness(get_u256(&mut test_vector))); + let a_sk = Some(SpendingKey(get_u256(&mut test_vector))); + + inputs.push( + JSInput { + value: value, + a_sk: a_sk, + rho: rho, + r: r, + auth_path: auth_path + } + ); + } + + let mut outputs = vec![]; + + for _ in 0..2 { + let a_pk = Some(PayingKey(get_u256(&mut test_vector))); + let value = Some(test_vector.read_u64::().unwrap()); + get_u256(&mut test_vector); + let r = Some(CommitmentRandomness(get_u256(&mut test_vector))); + + outputs.push( + JSOutput { + value: value, + a_pk: a_pk, + r: r + } + ); + } + + let vpub_old = Some(test_vector.read_u64::().unwrap()); + let vpub_new = Some(test_vector.read_u64::().unwrap()); + + let nf1 = get_u256(&mut test_vector); + let nf2 = get_u256(&mut test_vector); + + let cm1 = get_u256(&mut test_vector); + let cm2 = get_u256(&mut test_vector); + + let mac1 = get_u256(&mut test_vector); + let mac2 = get_u256(&mut test_vector); + + let js = JoinSplit { + vpub_old: vpub_old, + vpub_new: vpub_new, + h_sig: h_sig, + phi: phi, + inputs: inputs, + outputs: outputs, + rt: rt + }; + + js.synthesize(&mut cs).unwrap(); + + if let Some(s) = cs.which_is_unsatisfied() { + panic!("{:?}", s); + } + assert!(cs.is_satisfied()); + assert_eq!(cs.num_constraints(), 1989085); + assert_eq!(cs.num_inputs(), 10); + assert_eq!(cs.hash(), "1a228d3c6377130d1778c7885811dc8b8864049cb5af8aff7e6cd46c5bc4b84c"); + + let mut expected_inputs = vec![]; + expected_inputs.extend(rt.unwrap().to_vec()); + expected_inputs.extend(h_sig.unwrap().to_vec()); + expected_inputs.extend(nf1.to_vec()); + expected_inputs.extend(mac1.to_vec()); + expected_inputs.extend(nf2.to_vec()); + expected_inputs.extend(mac2.to_vec()); + expected_inputs.extend(cm1.to_vec()); + expected_inputs.extend(cm2.to_vec()); + expected_inputs.write_u64::(vpub_old.unwrap()).unwrap(); + expected_inputs.write_u64::(vpub_new.unwrap()).unwrap(); + + use circuit::multipack; + + let expected_inputs = multipack::bytes_to_bits(&expected_inputs); + let expected_inputs = multipack::compute_multipacking::(&expected_inputs); + + assert!(cs.verify(&expected_inputs)); + } +} diff --git a/src/circuit/sprout/output.rs b/src/circuit/sprout/output.rs new file mode 100644 index 0000000..9cdbf52 --- /dev/null +++ b/src/circuit/sprout/output.rs @@ -0,0 +1,54 @@ +use pairing::{Engine}; +use bellman::{ConstraintSystem, SynthesisError}; +use circuit::boolean::{Boolean}; + +use super::*; +use super::prfs::*; +use super::commitment::note_comm; + +pub struct OutputNote { + pub cm: Vec +} + +impl OutputNote { + pub fn compute<'a, E, CS>( + mut cs: CS, + a_pk: Option, + value: &NoteValue, + r: Option, + phi: &[Boolean], + h_sig: &[Boolean], + nonce: bool + ) -> Result + where E: Engine, CS: ConstraintSystem, + { + let rho = prf_rho( + cs.namespace(|| "rho"), + phi, + h_sig, + nonce + )?; + + let a_pk = witness_u256( + cs.namespace(|| "a_pk"), + a_pk.as_ref().map(|a_pk| &a_pk.0[..]) + )?; + + let r = witness_u256( + cs.namespace(|| "r"), + r.as_ref().map(|r| &r.0[..]) + )?; + + let cm = note_comm( + cs.namespace(|| "cm computation"), + &a_pk, + &value.bits_le(), + &rho, + &r + )?; + + Ok(OutputNote { + cm: cm + }) + } +} diff --git a/src/circuit/sprout/prfs.rs b/src/circuit/sprout/prfs.rs new file mode 100644 index 0000000..fff8648 --- /dev/null +++ b/src/circuit/sprout/prfs.rs @@ -0,0 +1,79 @@ +use pairing::{Engine}; +use bellman::{ConstraintSystem, SynthesisError}; +use circuit::sha256::{ + sha256_block_no_padding +}; +use circuit::boolean::{ + Boolean +}; + +fn prf( + cs: CS, + a: bool, + b: bool, + c: bool, + d: bool, + x: &[Boolean], + y: &[Boolean] +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + assert_eq!(x.len(), 252); + assert_eq!(y.len(), 256); + + let mut image = vec![]; + image.push(Boolean::constant(a)); + image.push(Boolean::constant(b)); + image.push(Boolean::constant(c)); + image.push(Boolean::constant(d)); + image.extend(x.iter().cloned()); + image.extend(y.iter().cloned()); + + assert_eq!(image.len(), 512); + + sha256_block_no_padding( + cs, + &image + ) +} + +pub fn prf_a_pk( + cs: CS, + a_sk: &[Boolean] +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + prf(cs, true, true, false, false, a_sk, &(0..256).map(|_| Boolean::constant(false)).collect::>()) +} + +pub fn prf_nf( + cs: CS, + a_sk: &[Boolean], + rho: &[Boolean] +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + prf(cs, true, true, true, false, a_sk, rho) +} + +pub fn prf_pk( + cs: CS, + a_sk: &[Boolean], + h_sig: &[Boolean], + nonce: bool +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + prf(cs, false, nonce, false, false, a_sk, h_sig) +} + +pub fn prf_rho( + cs: CS, + phi: &[Boolean], + h_sig: &[Boolean], + nonce: bool +) -> Result, SynthesisError> + where E: Engine, CS: ConstraintSystem +{ + prf(cs, false, nonce, true, false, phi, h_sig) +} diff --git a/src/circuit/sprout/test_vectors.dat b/src/circuit/sprout/test_vectors.dat new file mode 100644 index 0000000..1316955 Binary files /dev/null and b/src/circuit/sprout/test_vectors.dat differ diff --git a/src/circuit/uint32.rs b/src/circuit/uint32.rs index 4714724..e254132 100644 --- a/src/circuit/uint32.rs +++ b/src/circuit/uint32.rs @@ -87,6 +87,31 @@ impl UInt32 { }) } + pub fn into_bits_be(&self) -> Vec { + self.bits.iter().rev().cloned().collect() + } + + pub fn from_bits_be(bits: &[Boolean]) -> Self { + assert_eq!(bits.len(), 32); + + let mut value = Some(0u32); + for b in bits { + value.as_mut().map(|v| *v <<= 1); + + match b.get_value() { + Some(true) => { value.as_mut().map(|v| *v |= 1); }, + Some(false) => {}, + None => { value = None; } + } + } + + UInt32 { + value: value, + bits: bits.iter().rev().cloned().collect() + } + } + + /// Turns this `UInt32` into its little-endian byte order representation. pub fn into_bits(&self) -> Vec { self.bits.chunks(8) @@ -155,6 +180,104 @@ impl UInt32 { } } + pub fn shr(&self, by: usize) -> Self { + let by = by % 32; + + let fill = Boolean::constant(false); + + let new_bits = self.bits + .iter() // The bits are least significant first + .skip(by) // Skip the bits that will be lost during the shift + .chain(Some(&fill).into_iter().cycle()) // Rest will be zeros + .take(32) // Only 32 bits needed! + .cloned() + .collect(); + + UInt32 { + bits: new_bits, + value: self.value.map(|v| v >> by as u32) + } + } + + fn triop( + mut cs: CS, + a: &Self, + b: &Self, + c: &Self, + tri_fn: F, + circuit_fn: U + ) -> Result + where E: Engine, + CS: ConstraintSystem, + F: Fn(u32, u32, u32) -> u32, + U: Fn(&mut CS, usize, &Boolean, &Boolean, &Boolean) -> Result + { + let new_value = match (a.value, b.value, c.value) { + (Some(a), Some(b), Some(c)) => { + Some(tri_fn(a, b, c)) + }, + _ => None + }; + + let bits = a.bits.iter() + .zip(b.bits.iter()) + .zip(c.bits.iter()) + .enumerate() + .map(|(i, ((a, b), c))| circuit_fn(&mut cs, i, a, b, c)) + .collect::>()?; + + Ok(UInt32 { + bits: bits, + value: new_value + }) + } + + /// Compute the `maj` value (a and b) xor (a and c) xor (b and c) + /// during SHA256. + pub fn sha256_maj( + cs: CS, + a: &Self, + b: &Self, + c: &Self + ) -> Result + where E: Engine, + CS: ConstraintSystem + { + Self::triop(cs, a, b, c, |a, b, c| (a & b) ^ (a & c) ^ (b & c), + |cs, i, a, b, c| { + Boolean::sha256_maj( + cs.namespace(|| format!("maj {}", i)), + a, + b, + c + ) + } + ) + } + + /// Compute the `ch` value `(a and b) xor ((not a) and c)` + /// during SHA256. + pub fn sha256_ch( + cs: CS, + a: &Self, + b: &Self, + c: &Self + ) -> Result + where E: Engine, + CS: ConstraintSystem + { + Self::triop(cs, a, b, c, |a, b, c| (a & b) ^ ((!a) & c), + |cs, i, a, b, c| { + Boolean::sha256_ch( + cs.namespace(|| format!("ch {}", i)), + a, + b, + c + ) + } + ) + } + /// XOR this `UInt32` with another `UInt32` pub fn xor( &self, @@ -304,6 +427,37 @@ mod test { use bellman::{ConstraintSystem}; use circuit::multieq::MultiEq; + #[test] + fn test_uint32_from_bits_be() { + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0653]); + + for _ in 0..1000 { + let mut v = (0..32).map(|_| Boolean::constant(rng.gen())).collect::>(); + + let b = UInt32::from_bits_be(&v); + + for (i, bit) in b.bits.iter().enumerate() { + match bit { + &Boolean::Constant(bit) => { + assert!(bit == ((b.value.unwrap() >> i) & 1 == 1)); + }, + _ => unreachable!() + } + } + + let expected_to_be_same = b.into_bits_be(); + + for x in v.iter().zip(expected_to_be_same.iter()) + { + match x { + (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, + (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, + _ => unreachable!() + } + } + } + } + #[test] fn test_uint32_from_bits() { let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0653]); @@ -483,6 +637,7 @@ mod test { for i in 0..32 { let b = a.rotr(i); + assert_eq!(a.bits.len(), b.bits.len()); assert!(b.value.unwrap() == num); @@ -501,4 +656,106 @@ mod test { num = num.rotate_right(1); } } + + #[test] + fn test_uint32_shr() { + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0654]); + + for _ in 0..50 { + for i in 0..60 { + let num = rng.gen(); + let a = UInt32::constant(num).shr(i); + let b = UInt32::constant(num >> i); + + assert_eq!(a.value.unwrap(), num >> i); + + assert_eq!(a.bits.len(), b.bits.len()); + for (a, b) in a.bits.iter().zip(b.bits.iter()) { + assert_eq!(a.get_value().unwrap(), b.get_value().unwrap()); + } + } + } + } + + #[test] + fn test_uint32_sha256_maj() { + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0653]); + + for _ in 0..1000 { + let mut cs = TestConstraintSystem::::new(); + + let a: u32 = rng.gen(); + let b: u32 = rng.gen(); + let c: u32 = rng.gen(); + + let mut expected = (a & b) ^ (a & c) ^ (b & c); + + let a_bit = UInt32::alloc(cs.namespace(|| "a_bit"), Some(a)).unwrap(); + let b_bit = UInt32::constant(b); + let c_bit = UInt32::alloc(cs.namespace(|| "c_bit"), Some(c)).unwrap(); + + let r = UInt32::sha256_maj(&mut cs, &a_bit, &b_bit, &c_bit).unwrap(); + + assert!(cs.is_satisfied()); + + assert!(r.value == Some(expected)); + + for b in r.bits.iter() { + match b { + &Boolean::Is(ref b) => { + assert!(b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Not(ref b) => { + assert!(!b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Constant(b) => { + assert!(b == (expected & 1 == 1)); + } + } + + expected >>= 1; + } + } + } + + #[test] + fn test_uint32_sha256_ch() { + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0653]); + + for _ in 0..1000 { + let mut cs = TestConstraintSystem::::new(); + + let a: u32 = rng.gen(); + let b: u32 = rng.gen(); + let c: u32 = rng.gen(); + + let mut expected = (a & b) ^ ((!a) & c); + + let a_bit = UInt32::alloc(cs.namespace(|| "a_bit"), Some(a)).unwrap(); + let b_bit = UInt32::constant(b); + let c_bit = UInt32::alloc(cs.namespace(|| "c_bit"), Some(c)).unwrap(); + + let r = UInt32::sha256_ch(&mut cs, &a_bit, &b_bit, &c_bit).unwrap(); + + assert!(cs.is_satisfied()); + + assert!(r.value == Some(expected)); + + for b in r.bits.iter() { + match b { + &Boolean::Is(ref b) => { + assert!(b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Not(ref b) => { + assert!(!b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Constant(b) => { + assert!(b == (expected & 1 == 1)); + } + } + + expected >>= 1; + } + } + } } diff --git a/src/lib.rs b/src/lib.rs index a053dd1..44e10c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,9 @@ extern crate byteorder; #[macro_use] extern crate hex_literal; +#[cfg(test)] +extern crate crypto; + pub mod jubjub; pub mod group_hash; pub mod circuit;