diff --git a/src/lib.rs b/src/lib.rs index 480cb96fb..ba2390168 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,21 +7,22 @@ use std::error::Error; /// This represents a linear combination of some variables, with coefficients /// in the scalar field of a pairing-friendly elliptic curve group. -pub struct LinearCombination(Vec<(T, E::Fr)>); +#[derive(Clone)] +pub struct LinearCombination(Vec<(T, E::Fr)>); -impl AsRef<[(T, E::Fr)]> for LinearCombination { +impl AsRef<[(T, E::Fr)]> for LinearCombination { fn as_ref(&self) -> &[(T, E::Fr)] { &self.0 } } -impl LinearCombination { +impl LinearCombination { pub fn zero() -> LinearCombination { LinearCombination(vec![]) } } -impl Add<(E::Fr, T)> for LinearCombination { +impl Add<(E::Fr, T)> for LinearCombination { type Output = LinearCombination; fn add(mut self, (coeff, var): (E::Fr, T)) -> LinearCombination { @@ -31,7 +32,7 @@ impl Add<(E::Fr, T)> for LinearCombination { } } -impl Sub<(E::Fr, T)> for LinearCombination { +impl Sub<(E::Fr, T)> for LinearCombination { type Output = LinearCombination; fn sub(self, (mut coeff, var): (E::Fr, T)) -> LinearCombination { @@ -41,7 +42,7 @@ impl Sub<(E::Fr, T)> for LinearCombination { } } -impl Add for LinearCombination { +impl Add for LinearCombination { type Output = LinearCombination; fn add(self, other: T) -> LinearCombination { @@ -49,7 +50,7 @@ impl Add for LinearCombination { } } -impl Sub for LinearCombination { +impl Sub for LinearCombination { type Output = LinearCombination; fn sub(self, other: T) -> LinearCombination { @@ -81,9 +82,38 @@ impl<'a, T: Copy, E: Engine> Sub<&'a LinearCombination> for LinearCombinat } } +impl<'a, T: Copy, E: Engine> Add<(E::Fr, &'a LinearCombination)> for LinearCombination { + type Output = LinearCombination; + + fn add(mut self, (coeff, other): (E::Fr, &'a LinearCombination)) -> LinearCombination { + for s in &other.0 { + let mut tmp = s.1; + tmp.mul_assign(&coeff); + self = self + (tmp, s.0); + } + + self + } +} + +impl<'a, T: Copy, E: Engine> Sub<(E::Fr, &'a LinearCombination)> for LinearCombination { + type Output = LinearCombination; + + fn sub(mut self, (coeff, other): (E::Fr, &'a LinearCombination)) -> LinearCombination { + for s in &other.0 { + let mut tmp = s.1; + tmp.mul_assign(&coeff); + self = self - (tmp, s.0); + } + + self + } +} + #[test] fn test_lc() { use pairing::bls12_381::{Bls12, Fr}; + use pairing::PrimeField; let a = LinearCombination::::zero() + 0usize + 1usize + 2usize - 3usize; @@ -94,7 +124,7 @@ fn test_lc() { let x = LinearCombination::::zero() + (Fr::one(), 0usize) - (Fr::one(), 1usize); let y = LinearCombination::::zero() + (Fr::one(), 2usize) - (Fr::one(), 3usize); - let z = x + &y - &y; + let z = x.clone() + &y - &y; assert_eq!(z.0, vec![ (0usize, Fr::one()), @@ -104,6 +134,20 @@ fn test_lc() { (2usize, negone), (3usize, Fr::one()) ]); + + let coeff = Fr::from_str("3").unwrap(); + let mut neg_coeff = coeff; + neg_coeff.negate(); + let z = x + (coeff, &y) - (coeff, &y); + + assert_eq!(z.0, vec![ + (0usize, Fr::one()), + (1usize, negone), + (2usize, Fr::from_str("3").unwrap()), + (3usize, neg_coeff), + (2usize, neg_coeff), + (3usize, Fr::from_str("3").unwrap()) + ]); } /// This is an error that could occur during circuit synthesis contexts,