diff --git a/src/groth16/domain.rs b/src/groth16/domain.rs index 6db3acc0d..c20e164f7 100644 --- a/src/groth16/domain.rs +++ b/src/groth16/domain.rs @@ -51,7 +51,7 @@ impl EvaluationDomain { pub fn ifft>(&self, e: &E, v: &mut [T]) { assert!(v.len() == self.m as usize); - parallel_fft(e, v, &self.omegainv, self.exp); + best_fft(e, v, &self.omegainv, self.exp); let chunk = (v.len() / num_cpus::get()) + 1; @@ -112,58 +112,105 @@ impl EvaluationDomain { }); } - pub fn fft>(&self, e: &E, a: &mut [T]) - { - parallel_fft(e, a, &self.omega, self.exp); - } -} + pub fn mul_assign(&self, e: &E, a: &mut [E::Fr], b: Vec) { + assert_eq!(a.len(), b.len()); -fn parallel_fft>(e: &E, a: &mut [T], omega: &E::Fr, log_n: u64) -{ - let log_cpus = get_log_cpus(); - let num_cpus = 1 << log_cpus; - - if log_n < log_cpus { - serial_fft(e, a, omega, log_n) - } else { - // Shuffle - let log_new_n = log_n - log_cpus; - let mut tmp = vec![vec![T::group_zero(e); 1 << log_new_n]; num_cpus]; - let omega_num_cpus = omega.pow(e, &[num_cpus as u64]); + let chunk = (a.len() / num_cpus::get()) + 1; crossbeam::scope(|scope| { - let a = &*a; - - for (j, tmp) in tmp.iter_mut().enumerate() { + for (a, b) in a.chunks_mut(chunk).zip(b.chunks(chunk)) { scope.spawn(move || { - let omega_j = omega.pow(e, &[j as u64]); - let omega_step = omega.pow(e, &[(j as u64) << log_new_n]); - - let mut elt = E::Fr::one(e); - for i in 0..(1 << log_new_n) { - for s in 0..num_cpus { - let idx = (i + (s << log_new_n)) % (1 << log_n); - let mut t = a[idx]; - t.group_mul_assign(e, &elt); - tmp[i].group_add_assign(e, &t); - elt.mul_assign(e, &omega_step); - } - elt.mul_assign(e, &omega_j); + for (a, b) in a.iter_mut().zip(b.iter()) { + a.mul_assign(e, b); } - - serial_fft(e, tmp, &omega_num_cpus, log_new_n); }); } }); - - // TODO: parallelize - // Unshuffle - for i in 0..num_cpus { - for j in 0..(1 << log_new_n) { - a[(j << log_cpus) + i] = tmp[i][j]; - } - } } + + pub fn sub_assign(&self, e: &E, a: &mut [E::Fr], b: Vec) { + assert_eq!(a.len(), b.len()); + + let chunk = (a.len() / num_cpus::get()) + 1; + + crossbeam::scope(|scope| { + for (a, b) in a.chunks_mut(chunk).zip(b.chunks(chunk)) { + scope.spawn(move || { + for (a, b) in a.iter_mut().zip(b.iter()) { + a.sub_assign(e, b); + } + }); + } + }); + } + + pub fn fft>(&self, e: &E, a: &mut [T]) + { + best_fft(e, a, &self.omega, self.exp); + } +} + +fn best_fft>(e: &E, a: &mut [T], omega: &E::Fr, log_n: u64) +{ + let log_cpus = get_log_cpus(); + + if log_n < log_cpus { + serial_fft(e, a, omega, log_n); + } else { + parallel_fft(e, a, omega, log_n, log_cpus); + } +} + +fn parallel_fft>(e: &E, a: &mut [T], omega: &E::Fr, log_n: u64, log_cpus: u64) +{ + assert!(log_n >= log_cpus); + + let num_cpus = 1 << log_cpus; + let log_new_n = log_n - log_cpus; + let mut tmp = vec![vec![T::group_zero(e); 1 << log_new_n]; num_cpus]; + let omega_num_cpus = omega.pow(e, &[num_cpus as u64]); + + crossbeam::scope(|scope| { + let a = &*a; + + for (j, tmp) in tmp.iter_mut().enumerate() { + scope.spawn(move || { + let omega_j = omega.pow(e, &[j as u64]); + let omega_step = omega.pow(e, &[(j as u64) << log_new_n]); + + let mut elt = E::Fr::one(e); + for i in 0..(1 << log_new_n) { + for s in 0..num_cpus { + let idx = (i + (s << log_new_n)) % (1 << log_n); + let mut t = a[idx]; + t.group_mul_assign(e, &elt); + tmp[i].group_add_assign(e, &t); + elt.mul_assign(e, &omega_step); + } + elt.mul_assign(e, &omega_j); + } + + serial_fft(e, tmp, &omega_num_cpus, log_new_n); + }); + } + }); + + let chunk = (a.len() / num_cpus) + 1; + + crossbeam::scope(|scope| { + let tmp = &tmp; + + for (idx, a) in a.chunks_mut(chunk).enumerate() { + scope.spawn(move || { + let mut idx = idx * chunk; + let mask = (1 << log_cpus) - 1; + for a in a { + *a = tmp[idx & mask][idx >> log_cpus]; + idx += 1; + } + }); + } + }); } fn serial_fft>(e: &E, a: &mut [T], omega: &E::Fr, log_n: u64) @@ -298,3 +345,30 @@ fn test_log2_floor() { assert_eq!(log2_floor(7), 2); assert_eq!(log2_floor(8), 3); } + +#[test] +fn parallel_fft_consistency() { + use curves::*; + use curves::bls381::{Bls381, Fr}; + use std::cmp::min; + use rand; + + let e = &Bls381::new(); + let rng = &mut rand::thread_rng(); + + for log_d in 0..10 { + let d = 1 << log_d; + let domain = EvaluationDomain::new(e, d); + assert_eq!(domain.m, d); + + for log_cpus in 0..min(log_d, 3) { + let mut v1 = (0..d).map(|_| Fr::random(e, rng)).collect::>(); + let mut v2 = v1.clone(); + + parallel_fft(e, &mut v1, &domain.omega, log_d, log_cpus); + serial_fft(e, &mut v2, &domain.omega, log_d); + + assert_eq!(v1, v2); + } + } +} diff --git a/src/groth16/mod.rs b/src/groth16/mod.rs index cee2db0c4..d2f0a4eb8 100644 --- a/src/groth16/mod.rs +++ b/src/groth16/mod.rs @@ -464,12 +464,8 @@ pub fn prove>( domain.coset_fft(e, &mut prover.c); let mut h = prover.a; - for (h, b) in h.iter_mut().zip(prover.b.into_iter()) { - h.mul_assign(e, &b); - } - for (h, c) in h.iter_mut().zip(prover.c.into_iter()) { - h.sub_assign(e, &c); - } + domain.mul_assign(e, &mut h, prover.b); + domain.sub_assign(e, &mut h, prover.c); domain.divide_by_z_on_coset(e, &mut h); domain.icoset_fft(e, &mut h);