diff --git a/src/circuit/boolean.rs b/src/circuit/boolean.rs index 3575cf2..6b5b9cd 100644 --- a/src/circuit/boolean.rs +++ b/src/circuit/boolean.rs @@ -270,6 +270,25 @@ impl Boolean { } } + pub fn lc(&self, one: Var, coeff: E::Fr) -> LinearCombination + { + match self { + &Boolean::Constant(c) => { + if c { + LinearCombination::::zero() + (coeff, one) + } else { + LinearCombination::::zero() + } + }, + &Boolean::Is(ref v) => { + LinearCombination::::zero() + (coeff, v.get_variable()) + }, + &Boolean::Not(ref v) => { + LinearCombination::::zero() + (coeff, one) - (coeff, v.get_variable()) + } + } + } + /// Construct a boolean from a known constant pub fn constant(b: bool) -> Self { Boolean::Constant(b) diff --git a/src/circuit/num.rs b/src/circuit/num.rs index aeda7c5..cce455c 100644 --- a/src/circuit/num.rs +++ b/src/circuit/num.rs @@ -292,6 +292,39 @@ impl AllocatedNum { Ok(()) } + pub fn conditionally_negate( + &self, + mut cs: CS, + condition: &Boolean + ) -> Result + where CS: ConstraintSystem + { + let r = Self::alloc( + cs.namespace(|| "conditional negation result"), + || { + let mut tmp = *self.value.get()?; + if *condition.get_value().get()? { + tmp.negate(); + } + Ok(tmp) + } + )?; + + // (1-c)(x) + (c)(-x) = r + // x - 2cx = r + // (2x) * (c) = x - r + + let one = cs.one(); + cs.enforce( + || "conditional negation", + LinearCombination::zero() + self.variable + self.variable, + condition.lc(one, E::Fr::one()), + LinearCombination::zero() + self.variable - r.variable + ); + + Ok(r) + } + pub fn get_value(&self) -> Option { self.value } @@ -349,6 +382,107 @@ mod test { assert!(!cs.is_satisfied()); } + #[test] + fn test_num_conditional_negation() { + { + let mut cs = TestConstraintSystem::::new(); + + let n = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::one())).unwrap(); + let b = Boolean::constant(true); + let n2 = n.conditionally_negate(&mut cs, &b).unwrap(); + + let mut negone = Fr::one(); + negone.negate(); + + assert!(cs.is_satisfied()); + assert!(cs.get("conditional negation result/num") == negone); + assert!(n2.value.unwrap() == negone); + cs.set("conditional negation result/num", Fr::from_str("1").unwrap()); + assert!(!cs.is_satisfied()); + } + { + let mut cs = TestConstraintSystem::::new(); + + let n = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::one())).unwrap(); + let b = Boolean::constant(false); + let n2 = n.conditionally_negate(&mut cs, &b).unwrap(); + + assert!(cs.is_satisfied()); + assert!(cs.get("conditional negation result/num") == Fr::one()); + assert!(n2.value.unwrap() == Fr::one()); + cs.set("conditional negation result/num", Fr::from_str("2").unwrap()); + assert!(!cs.is_satisfied()); + } + + { + let mut cs = TestConstraintSystem::::new(); + + let n = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::one())).unwrap(); + let b = Boolean::from( + AllocatedBit::alloc(cs.namespace(|| "condition"), Some(true)).unwrap() + ); + let n2 = n.conditionally_negate(&mut cs, &b).unwrap(); + + let mut negone = Fr::one(); + negone.negate(); + + assert!(cs.is_satisfied()); + assert!(cs.get("conditional negation result/num") == negone); + assert!(n2.value.unwrap() == negone); + cs.set("conditional negation result/num", Fr::from_str("1").unwrap()); + assert!(!cs.is_satisfied()); + } + { + let mut cs = TestConstraintSystem::::new(); + + let n = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::one())).unwrap(); + let b = Boolean::from( + AllocatedBit::alloc(cs.namespace(|| "condition"), Some(false)).unwrap() + ); + let n2 = n.conditionally_negate(&mut cs, &b).unwrap(); + + assert!(cs.is_satisfied()); + assert!(cs.get("conditional negation result/num") == Fr::one()); + assert!(n2.value.unwrap() == Fr::one()); + cs.set("conditional negation result/num", Fr::from_str("2").unwrap()); + assert!(!cs.is_satisfied()); + } + + { + let mut cs = TestConstraintSystem::::new(); + + let n = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::one())).unwrap(); + let b = Boolean::from( + AllocatedBit::alloc(cs.namespace(|| "condition"), Some(false)).unwrap() + ).not(); + let n2 = n.conditionally_negate(&mut cs, &b).unwrap(); + + let mut negone = Fr::one(); + negone.negate(); + + assert!(cs.is_satisfied()); + assert!(cs.get("conditional negation result/num") == negone); + assert!(n2.value.unwrap() == negone); + cs.set("conditional negation result/num", Fr::from_str("1").unwrap()); + assert!(!cs.is_satisfied()); + } + { + let mut cs = TestConstraintSystem::::new(); + + let n = AllocatedNum::alloc(cs.namespace(|| "a"), || Ok(Fr::one())).unwrap(); + let b = Boolean::from( + AllocatedBit::alloc(cs.namespace(|| "condition"), Some(true)).unwrap() + ).not(); + let n2 = n.conditionally_negate(&mut cs, &b).unwrap(); + + assert!(cs.is_satisfied()); + assert!(cs.get("conditional negation result/num") == Fr::one()); + assert!(n2.value.unwrap() == Fr::one()); + cs.set("conditional negation result/num", Fr::from_str("2").unwrap()); + assert!(!cs.is_satisfied()); + } + } + #[test] fn test_num_nonzero() { {