diff --git a/snark/src/lib.rs b/snark/src/lib.rs index b149769..a4d6326 100644 --- a/snark/src/lib.rs +++ b/snark/src/lib.rs @@ -39,7 +39,8 @@ extern "C" { Bt2: *const G2, Ct: *const G1) -> bool; fn libsnarkwrap_test_compare_tau( - i: *const G1, + i1: *const G1, + i2: *const G2, tau: *const Fr, d: libc::uint64_t, qap: *const libc::c_void) -> bool; @@ -123,8 +124,9 @@ pub fn getqap() -> (usize, usize, Fr, CS) { /// Check that the lagrange coefficients computed by tau over /// G1 equal the expected vector. -pub fn compare_tau(v: &[G1], tau: &Fr, cs: &CS) -> bool { - unsafe { libsnarkwrap_test_compare_tau(&v[0], tau, v.len() as u64, cs.0) } +pub fn compare_tau(v1: &[G1], v2: &[G2], tau: &Fr, cs: &CS) -> bool { + assert_eq!(v1.len(), v2.len()); + unsafe { libsnarkwrap_test_compare_tau(&v1[0], &v2[0], tau, v1.len() as u64, cs.0) } } pub trait Pairing { diff --git a/snark/src/libsnarkwrap.cpp b/snark/src/libsnarkwrap.cpp index f8ef8af..4cb925c 100644 --- a/snark/src/libsnarkwrap.cpp +++ b/snark/src/libsnarkwrap.cpp @@ -208,7 +208,8 @@ extern "C" void libsnarkwrap_dropcs(r1cs_constraint_system *cs) } extern "C" bool libsnarkwrap_test_compare_tau( - const curve_G1 *inputs, + const curve_G1 *inputs1, + const curve_G2 *inputs2, const curve_Fr *tau, uint64_t d, const r1cs_constraint_system *cs @@ -221,7 +222,8 @@ extern "C" bool libsnarkwrap_test_compare_tau( bool res = true; for (size_t i = 0; i < d; i++) { - res &= (coeffs[i] * curve_G1::one()) == inputs[i]; + res &= (coeffs[i] * curve_G1::one()) == inputs1[i]; + res &= (coeffs[i] * curve_G2::one()) == inputs2[i]; } return res; diff --git a/src/fft.rs b/src/fft.rs index 37636c4..a4c10fe 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,6 +1,16 @@ use snark::{Group, Fr}; -pub fn fft(v: &[G], omega: Fr) -> Vec +pub fn lagrange_coeffs(v: &[G], omega: Fr, d: usize) -> Vec +{ + let overd = Fr::from_str(&format!("{}", d)).inverse(); + fft(v, omega) + .into_iter() + .rev() // coefficients are in reverse + .map(|e| e * overd) // divide by d + .collect::>() +} + +fn fft(v: &[G], omega: Fr) -> Vec { if v.len() == 2 { vec![ @@ -37,7 +47,7 @@ pub fn fft(v: &[G], omega: Fr) -> Vec #[cfg(test)] mod test { - use super::fft; + use super::lagrange_coeffs; use snark::*; use util::*; @@ -53,25 +63,30 @@ mod test { // Generate powers of tau in G1, from 0 to d exclusive of d let powers_of_tau_g1 = TauPowers::new(tau).take(d).map(|e| G1::one() * e).collect::>(); + // Generate powers of tau in G2, from 0 to d exclusive of d let powers_of_tau_g2 = TauPowers::new(tau).take(d).map(|e| G2::one() * e).collect::>(); + // Perform FFT to compute lagrange coeffs in G1/G2 let overd = Fr::from_str(&format!("{}", d)).inverse(); - let lc1 = fft(&powers_of_tau_g1, omega) // omit tau^d - .into_iter() - .rev() // coefficients are in reverse - .map(|e| e * overd) // divide by d - .collect::>(); - let lc2 = fft(&powers_of_tau_g2, omega) // omit tau^d - .into_iter() - .rev() // coefficients are in reverse - .map(|e| e * overd) // divide by d - .collect::>(); + let lc1 = lagrange_coeffs(&powers_of_tau_g1, omega, d); + let lc2 = lagrange_coeffs(&powers_of_tau_g2, omega, d); + + { + // Perform G1 FFT with wrong omega + let lc1 = lagrange_coeffs(&powers_of_tau_g1, Fr::random(), d); + assert!(!compare_tau(&lc1, &lc2, &tau, &cs)); + } + { + // Perform G2 FFT with wrong omega + let lc2 = lagrange_coeffs(&powers_of_tau_g2, Fr::random(), d); + assert!(!compare_tau(&lc1, &lc2, &tau, &cs)); + } // Compare against libsnark - assert!(compare_tau(&lc1, &tau, &cs)); + assert!(compare_tau(&lc1, &lc2, &tau, &cs)); // Wrong tau - assert!(!compare_tau(&lc1, &Fr::random(), &cs)); + assert!(!compare_tau(&lc1, &lc2, &Fr::random(), &cs)); // Evaluate At, Ct in G1 and Bt in G1/G2 let mut At = (0..num_vars).map(|_| G1::zero()).collect::>();