diff --git a/src/curves/bls381/ec.rs b/src/curves/bls381/ec.rs index 06ade0b..4d4db81 100644 --- a/src/curves/bls381/ec.rs +++ b/src/curves/bls381/ec.rs @@ -58,7 +58,8 @@ macro_rules! curve_impl { } } - impl CurveAffine<$engine, $name> for $name_affine { + impl CurveAffine<$engine> for $name_affine { + type Jacobian = $name; type Uncompressed = $name_uncompressed; fn is_valid(&self, e: &$engine) -> bool { @@ -111,10 +112,10 @@ macro_rules! curve_impl { self.infinity } - fn mul>(&self, e: &$engine, other: &S) -> $name { + fn mul>::Repr, $engine>>(&self, e: &$engine, other: &S) -> $name { let mut res = $name::zero(e); - for i in BitIterator::from((*other.convert(e)).borrow()) + for i in BitIterator::new((*other.convert(e)).borrow()) { res.double(e); @@ -133,6 +134,32 @@ macro_rules! curve_impl { } } + impl multiexp::Projective<$engine> for $name { + type WindowTable = wnaf::WindowTable<$engine, $name>; + + fn identity(e: &$engine) -> Self { + Self::zero(e) + } + + fn add_to_projective(&self, e: &$engine, projective: &mut Self) { + projective.add_assign(e, self); + } + + fn exponentiate(&mut self, + e: &$engine, + scalar: <$scalarfield as PrimeField<$engine>>::Repr, + table: &mut Self::WindowTable, + scratch: &mut wnaf::WNAFTable + ) + { + *self = self.optimal_exp(e, scalar, table, scratch); + } + + fn new_window_table(e: &$engine) -> Self::WindowTable { + wnaf::WindowTable::<$engine, $name>::new(e, $name::zero(e), 2) + } + } + impl Curve<$engine> for $name { type Affine = $name_affine; type Prepared = $name_prepared; @@ -147,7 +174,7 @@ macro_rules! curve_impl { None } - fn optimal_window_batch(&self, engine: &$engine, scalars: usize) -> WindowTable<$engine, $name, Vec<$name>> { + fn optimal_window_batch(&self, engine: &$engine, scalars: usize) -> wnaf::WindowTable<$engine, $name> { let mut window = engine.$params_field.batch_windows.0; for i in &engine.$params_field.batch_windows.1 { @@ -158,10 +185,7 @@ macro_rules! curve_impl { } } - let mut table = WindowTable::new(); - table.set_base(engine, self, window); - - table + wnaf::WindowTable::new(engine, *self, window) } fn zero(engine: &$engine) -> Self { @@ -290,10 +314,10 @@ macro_rules! curve_impl { } } - fn mul_assign>(&mut self, engine: &$engine, other: &S) { + fn mul_assign>::Repr, $engine>>(&mut self, engine: &$engine, other: &S) { let mut res = Self::zero(engine); - for i in BitIterator::from((*other.convert(engine)).borrow()) + for i in BitIterator::new((*other.convert(engine)).borrow()) { res.double(engine); diff --git a/src/curves/bls381/fp.rs b/src/curves/bls381/fp.rs index 0612e1a..43fc77a 100644 --- a/src/curves/bls381/fp.rs +++ b/src/curves/bls381/fp.rs @@ -183,6 +183,7 @@ macro_rules! fp_impl { engine = $engine:ident, params = $params_field:ident : $params_name:ident, arith = $arith_mod:ident, + repr = $repr:ident, limbs = $limbs:expr, $($params:tt)* ) => { @@ -218,15 +219,72 @@ macro_rules! fp_impl { #[repr(C)] pub struct $name([u64; $limbs]); + #[derive(Copy, Clone, PartialEq, Eq)] + #[repr(C)] + pub struct $repr([u64; $limbs]); + + impl PrimeFieldRepr for $repr { + fn from_u64(a: u64) -> Self { + let mut tmp: [u64; $limbs] = Default::default(); + tmp[0] = a; + $repr(tmp) + } + + fn sub_noborrow(&mut self, other: &Self) { + $arith_mod::sub_noborrow(&mut self.0, &other.0); + } + + fn add_nocarry(&mut self, other: &Self) { + $arith_mod::add_nocarry(&mut self.0, &other.0); + } + + fn num_bits(&self) -> usize { + $arith_mod::num_bits(&self.0) + } + + fn is_zero(&self) -> bool { + self.0.iter().all(|&e| e==0) + } + + fn is_odd(&self) -> bool { + $arith_mod::odd(&self.0) + } + + fn div2(&mut self) { + $arith_mod::div2(&mut self.0); + } + } + + impl AsRef<[u64]> for $repr { + fn as_ref(&self) -> &[u64] { + &self.0 + } + } + + impl Ord for $repr { + fn cmp(&self, other: &$repr) -> Ordering { + if $arith_mod::lt(&self.0, &other.0) { + Ordering::Less + } else if self.0 == other.0 { + Ordering::Equal + } else { + Ordering::Greater + } + } + } + + impl PartialOrd for $repr { + fn partial_cmp(&self, other: &$repr) -> Option { + Some(self.cmp(other)) + } + } + impl fmt::Debug for $name { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { ENGINE.with(|e| { - let mut repr = self.into_repr(&e); - repr.reverse(); - try!(write!(f, "Fp(0x")); - for i in &repr { + for i in self.into_repr(&e).0.iter().rev() { try!(write!(f, "{:016x}", *i)); } write!(f, ")") @@ -260,21 +318,21 @@ macro_rules! fp_impl { } } - impl Convert<[u64], $engine> for $name + impl Convert<$repr, $engine> for $name { - type Target = [u64; $limbs]; + type Target = $repr; - fn convert(&self, engine: &$engine) -> Cow<[u64; $limbs]> { + fn convert(&self, engine: &$engine) -> Cow<$repr> { Cow::Owned(self.into_repr(engine)) } } impl PrimeField<$engine> for $name { - type Repr = [u64; $limbs]; + type Repr = $repr; fn from_repr(engine: &$engine, repr: Self::Repr) -> Result { - let mut tmp = $name(repr); + let mut tmp = $name(repr.0); if $arith_mod::lt(&tmp.0, &engine.$params_field.modulus) { tmp.mul_assign(engine, &engine.$params_field.r2); Ok(tmp) @@ -286,14 +344,14 @@ macro_rules! fp_impl { fn into_repr(&self, engine: &$engine) -> Self::Repr { let mut tmp = *self; tmp.mul_assign(engine, &engine.$params_field.one); - tmp.0 + $repr(tmp.0) } fn from_u64(engine: &$engine, n: u64) -> Self { let mut r = [0; $limbs]; r[0] = n; - Self::from_repr(engine, r).unwrap() + Self::from_repr(engine, $repr(r)).unwrap() } fn from_str(engine: &$engine, s: &str) -> Result { @@ -313,12 +371,8 @@ macro_rules! fp_impl { Ok(res) } - fn bits(&self, engine: &$engine) -> BitIterator { - self.into_repr(engine).into() - } - fn char(engine: &$engine) -> Self::Repr { - engine.$params_field.modulus + $repr(engine.$params_field.modulus) } fn num_bits(engine: &$engine) -> usize { @@ -457,19 +511,20 @@ macro_rules! fp_impl { } mod $arith_mod { - use super::BitIterator; // Arithmetic #[allow(dead_code)] pub fn num_bits(v: &[u64; $limbs]) -> usize { - // TODO: optimize - for (i, b) in BitIterator::from(&v[..]).enumerate() { - if b { - return ($limbs*64) - i; + let mut ret = 64 * $limbs; + for i in v.iter().rev() { + let leading = i.leading_zeros() as usize; + ret -= leading; + if leading != 64 { + break; } } - 0 + ret } #[inline] diff --git a/src/curves/bls381/mod.rs b/src/curves/bls381/mod.rs index 91bfc8e..2a56151 100644 --- a/src/curves/bls381/mod.rs +++ b/src/curves/bls381/mod.rs @@ -1,21 +1,24 @@ use rand; use std::fmt; +use std::cmp::Ordering; use std::borrow::Borrow; +use ::BitIterator; use super::{ - WindowTable, Engine, Group, Curve, CurveAffine, CurveRepresentation, PrimeField, + PrimeFieldRepr, Field, SnarkField, SqrtField, - BitIterator, Convert, - Cow + Cow, + multiexp, + wnaf }; use serde::ser::{Serialize, Serializer, SerializeTuple}; @@ -61,6 +64,7 @@ fp_impl!( engine = Bls381, params = fqparams: FqParams, arith = fq_arith, + repr = FqRepr, limbs = 6, // q = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 modulus = [ 0xb9feffffffffaaab, 0x1eabfffeb153ffff, 0x6730d2a0f6b0f624, 0x64774b84f38512bf, 0x4b1ba7b6434bacd7, 0x1a0111ea397fe69a ], @@ -80,6 +84,7 @@ fp_impl!( engine = Bls381, params = frparams: FrParams, arith = fr_arith, + repr = FrRepr, limbs = 4, // r = 52435875175126190479447740508185965837690552500527637822603658699938581184513 modulus = [ 0xffffffff00000001, 0x53bda402fffe5bfe, 0x3339d80809a1d805, 0x73eda753299d7d48 ], @@ -350,7 +355,9 @@ impl<'a> Deserialize<'a> for G2Uncompressed { } } -impl CurveRepresentation for G1Uncompressed { +impl CurveRepresentation for G1Uncompressed { + type Affine = G1Affine; + fn to_affine_unchecked(&self, e: &Bls381) -> Result { match self { &G1Uncompressed::Infinity => { @@ -372,8 +379,8 @@ impl CurveRepresentation for G1Uncompressed { } Ok(G1Affine { - x: try!(Fq::from_repr(e, x)), - y: try!(Fq::from_repr(e, y)), + x: try!(Fq::from_repr(e, FqRepr(x))), + y: try!(Fq::from_repr(e, FqRepr(y))), infinity: false }) } @@ -381,7 +388,9 @@ impl CurveRepresentation for G1Uncompressed { } } -impl CurveRepresentation for G2Uncompressed { +impl CurveRepresentation for G2Uncompressed { + type Affine = G2Affine; + fn to_affine_unchecked(&self, e: &Bls381) -> Result { match self { &G2Uncompressed::Infinity => { @@ -406,12 +415,12 @@ impl CurveRepresentation for G2Uncompressed { if let (Some(y_c1), y_c0) = fq_arith::divrem(&y, &e.fqparams.modulus) { return Ok(G2Affine { x: Fq2 { - c0: try!(Fq::from_repr(e, x_c0)), - c1: try!(Fq::from_repr(e, x_c1)) + c0: try!(Fq::from_repr(e, FqRepr(x_c0))), + c1: try!(Fq::from_repr(e, FqRepr(x_c1))) }, y: Fq2 { - c0: try!(Fq::from_repr(e, y_c0)), - c1: try!(Fq::from_repr(e, y_c1)) + c0: try!(Fq::from_repr(e, FqRepr(y_c0))), + c1: try!(Fq::from_repr(e, FqRepr(y_c1))) }, infinity: false }); @@ -435,14 +444,14 @@ impl G1Uncompressed { { let mut tmp = &mut tmp[0..]; - for &digit in p.x.into_repr(e).iter().rev() { + for &digit in p.x.into_repr(e).0.iter().rev() { tmp.write_u64::(digit).unwrap(); } } { let mut tmp = &mut tmp[48..]; - for &digit in p.y.into_repr(e).iter().rev() { + for &digit in p.y.into_repr(e).0.iter().rev() { tmp.write_u64::(digit).unwrap(); } } @@ -464,8 +473,8 @@ impl G2Uncompressed { { let mut tmp = &mut tmp[0..]; let mut x = [0; 12]; - fq_arith::mac3(&mut x, &p.x.c1.into_repr(e), &e.fqparams.modulus); - fq_arith::add_carry(&mut x, &p.x.c0.into_repr(e)); + fq_arith::mac3(&mut x, &p.x.c1.into_repr(e).0, &e.fqparams.modulus); + fq_arith::add_carry(&mut x, &p.x.c0.into_repr(e).0); for &digit in x.iter().rev() { tmp.write_u64::(digit).unwrap(); @@ -475,8 +484,8 @@ impl G2Uncompressed { { let mut tmp = &mut tmp[96..]; let mut y = [0; 12]; - fq_arith::mac3(&mut y, &p.y.c1.into_repr(e), &e.fqparams.modulus); - fq_arith::add_carry(&mut y, &p.y.c0.into_repr(e)); + fq_arith::mac3(&mut y, &p.y.c1.into_repr(e).0, &e.fqparams.modulus); + fq_arith::add_carry(&mut y, &p.y.c0.into_repr(e).0); for &digit in y.iter().rev() { tmp.write_u64::(digit).unwrap(); @@ -685,7 +694,7 @@ impl G2Prepared { let mut r = q.to_jacobian(e); let mut found_one = false; - for i in BitIterator::from([BLS_X >> 1]) { + for i in BitIterator::new(&[BLS_X >> 1]) { if !found_one { found_one = i; continue; @@ -999,7 +1008,7 @@ impl Engine for Bls381 { let mut f = Fq12::one(self); let mut found_one = false; - for i in BitIterator::from([BLS_X >> 1]) { + for i in BitIterator::new(&[BLS_X >> 1]) { if !found_one { found_one = i; continue; @@ -1037,7 +1046,8 @@ impl Engine for Bls381 { crossbeam::scope(|scope| { for (g, s) in g.chunks_mut(chunk).zip(scalars.as_ref().chunks(chunk)) { scope.spawn(move || { - let mut table = WindowTable::new(); + let mut table = wnaf::WindowTable::new(self, G::zero(self), 2); + let mut scratch = wnaf::WNAFTable::new(); for (g, s) in g.iter_mut().zip(s.iter()) { let mut s = *s; @@ -1047,16 +1057,16 @@ impl Engine for Bls381 { }, _ => {} }; - let mut newg = g.to_jacobian(self); - opt_exp(self, &mut newg, s.into_repr(self), &mut table); - *g = newg.to_affine(self); + *g = g.to_jacobian(self) + .optimal_exp(self, s.into_repr(self), &mut table, &mut scratch) + .to_affine(self); } }); } }); } - fn batch_baseexp, S: AsRef<[Self::Fr]>>(&self, table: &WindowTable>, s: S) -> Vec + fn batch_baseexp, S: AsRef<[Self::Fr]>>(&self, table: &wnaf::WindowTable, s: S) -> Vec { use crossbeam; use num_cpus; @@ -1068,13 +1078,12 @@ impl Engine for Bls381 { let chunk = (s.len() / num_cpus::get()) + 1; for (s, b) in s.chunks(chunk).zip(ret.chunks_mut(chunk)) { - let mut table = table.shared(); - scope.spawn(move || { + let mut scratch = wnaf::WNAFTable::new(); + for (s, b) in s.iter().zip(b.iter_mut()) { - let mut tmp = G::zero(self); - table.exp(self, &mut tmp, s.into_repr(self)); - *b = tmp.to_affine(self); + scratch.set_scalar(table, s.into_repr(self)); + *b = table.exp(self, &scratch).to_affine(self); } }); } @@ -1084,232 +1093,7 @@ impl Engine for Bls381 { } fn multiexp>(&self, g: &[G::Affine], s: &[Fr]) -> Result { - if g.len() != s.len() { - return Err(()); - } - - use crossbeam; - use num_cpus; - - return crossbeam::scope(|scope| { - let mut threads = vec![]; - - let chunk = (s.len() / num_cpus::get()) + 1; - - for (g, s) in g.chunks(chunk).zip(s.chunks(chunk)) { - threads.push(scope.spawn(move || { - multiexp_inner(self, g, s) - })); - } - - let mut acc = G::zero(self); - for t in threads { - acc.add_assign(self, &t.join()); - } - - Ok(acc) - }); - - fn multiexp_inner>(engine: &Bls381, g: &[G::Affine], s: &[Fr]) -> G - { - // This performs a multi-exponentiation calculation, i.e., multiplies - // each group element by the corresponding scalar and adds all of the - // terms together. We use the Bos-Coster algorithm to do this: sort - // the exponents using a max heap, and rewrite the first two terms - // a x + b y = (a-b) x + b(y+x). Reinsert the first element into the - // heap after performing cheap scalar subtraction, and perform the - // point addition. This continues until the heap is emptied as - // elements are multiplied when a certain efficiency threshold is met - // or discarded when their exponents become zero. The result of all - // the multiplications are accumulated and returned when the heap - // is empty. - - assert!(g.len() == s.len()); - - use std::cmp::Ordering; - use std::collections::BinaryHeap; - - struct Exp { - index: usize, - value: >::Repr - } - - impl Exp { - fn bits(&self) -> usize { - fr_arith::num_bits(&self.value) - } - - fn justexp(&self, sub: &Exp) -> bool { - use std::cmp::min; - - let bbits = sub.bits(); - let abits = self.bits(); - let limit = min(abits-bbits, 20); - - if bbits < (1< bool { - self.value.iter().all(|&e| e == 0) - } - } - - impl Ord for Exp { - fn cmp(&self, other: &Exp) -> Ordering { - if fr_arith::lt(&self.value, &other.value) { - Ordering::Less - } else if self.value == other.value { - Ordering::Equal - } else { - Ordering::Greater - } - } - } - - impl PartialOrd for Exp { - fn partial_cmp(&self, other: &Exp) -> Option { - Some(self.cmp(other)) - } - } - - impl PartialEq for Exp { - fn eq(&self, other: &Exp) -> bool { - self.value == other.value - } - } - - impl Eq for Exp { } - - let mut result = G::zero(engine); - let one = Fr::one(engine); - - let mut elements = Vec::with_capacity(g.len()); - let mut heap = BinaryHeap::with_capacity(g.len()); - - for (g, s) in g.iter().zip(s.iter()) { - if s.is_zero() || g.is_zero() { - // Skip. - continue; - } - - if s == &one { - // Just add. - result.add_assign_mixed(engine, &g); - continue; - } - - let index = elements.len(); - elements.push(g.to_jacobian(engine)); - - heap.push(Exp { - index: index, - value: s.into_repr(engine) - }); - } - - let mut table = WindowTable::new(); - - while let Some(mut greatest) = heap.pop() { - { - let second_greatest = heap.peek(); - if second_greatest.is_none() || greatest.justexp(second_greatest.unwrap()) { - // Either this is the last value or multiplying is considered more efficient than - // rewriting and reinsertion into the heap. - opt_exp(engine, &mut elements[greatest.index], greatest.value, &mut table); - result.add_assign(engine, &elements[greatest.index]); - continue; - } else { - // Rewrite - let second_greatest = second_greatest.unwrap(); - - fr_arith::sub_noborrow(&mut greatest.value, &second_greatest.value); - let mut tmp = elements[second_greatest.index]; - tmp.add_assign(engine, &elements[greatest.index]); - elements[second_greatest.index] = tmp; - } - } - if !greatest.is_zero() { - // Reinsert only nonzero scalars. - heap.push(greatest); - } - } - - result - } - } -} - -impl, B: Borrow<[G]>> WindowTable { - fn exp(&mut self, e: &Bls381, into: &mut G, mut c: >::Repr) { - assert!(self.window > 1); - - self.wnaf.truncate(0); - self.wnaf.reserve(Fr::num_bits(e) + 1); - - // Convert the scalar `c` into wNAF form. - { - use std::default::Default; - let mut tmp = >::Repr::default(); - - while !c.iter().all(|&e| e==0) { - let mut u; - if fr_arith::odd(&c) { - u = (c[0] % (1 << (self.window+1))) as i64; - - if u > (1 << self.window) { - u -= 1 << (self.window+1); - } - - if u > 0 { - tmp[0] = u as u64; - fr_arith::sub_noborrow(&mut c, &tmp); - } else { - tmp[0] = (-u) as u64; - fr_arith::add_nocarry(&mut c, &tmp); - } - } else { - u = 0; - } - - self.wnaf.push(u); - - fr_arith::div2(&mut c); - } - } - - // Perform wNAF exponentiation. - *into = G::zero(e); - - for n in self.wnaf.iter().rev() { - into.double(e); - - if *n != 0 { - if *n > 0 { - into.add_assign(e, &self.table.borrow()[(n/2) as usize]); - } else { - into.sub_assign(e, &self.table.borrow()[((-n)/2) as usize]); - } - } - } - } -} - -// Performs optimal exponentiation -fn opt_exp>(e: &Bls381, base: &mut G, scalar: >::Repr, table: &mut WindowTable>) -{ - let bits = fr_arith::num_bits(&scalar); - match G::optimal_window(e, bits) { - Some(window) => { - table.set_base(e, base, window); - table.exp(e, base, scalar); - }, - None => { - base.mul_assign(e, &scalar); - } + super::multiexp::perform_multiexp(self, g, s) } } diff --git a/src/curves/bls381/tests/mod.rs b/src/curves/bls381/tests/mod.rs index 1a59782..29d98ea 100644 --- a/src/curves/bls381/tests/mod.rs +++ b/src/curves/bls381/tests/mod.rs @@ -11,7 +11,7 @@ fn test_vectors>(e: &E, expected: &[u8]) { for _ in 0..10000 { { let acc = acc.to_affine(e); - let exp: >::Uncompressed = + let exp: >::Uncompressed = bincode::deserialize_from(&mut expected_reader, bincode::Infinite).unwrap(); assert!(acc == exp.to_affine(e).unwrap()); diff --git a/src/groth16/domain.rs b/src/curves/domain.rs similarity index 99% rename from src/groth16/domain.rs rename to src/curves/domain.rs index c20e164..de0779f 100644 --- a/src/groth16/domain.rs +++ b/src/curves/domain.rs @@ -1,4 +1,4 @@ -use curves::{Engine, Field, SnarkField, PrimeField, Group}; +use super::{Engine, Field, SnarkField, PrimeField, Group}; use crossbeam; use num_cpus; diff --git a/src/curves/mod.rs b/src/curves/mod.rs index 7da0ac8..3768b87 100644 --- a/src/curves/mod.rs +++ b/src/curves/mod.rs @@ -2,17 +2,20 @@ use rand; use std::fmt; use std::borrow::Borrow; -use std::marker::PhantomData; use serde::{Serialize, Deserialize}; +use super::BitIterator; use super::{Cow, Convert}; pub mod bls381; +pub mod multiexp; +pub mod wnaf; +pub mod domain; pub trait Engine: Sized + Clone + Send + Sync { - type Fq: PrimeField; - type Fr: SnarkField; + type Fq: PrimeField + Convert<>::Repr, Self>; + type Fr: SnarkField + Convert<>::Repr, Self>; type Fqe: SqrtField; type Fqk: Field; type G1: Curve + Convert<>::Affine, Self>; @@ -43,7 +46,7 @@ pub trait Engine: Sized + Clone + Send + Sync /// Perform multi-exponentiation. g and s must have the same length. fn multiexp>(&self, g: &[G::Affine], s: &[Self::Fr]) -> Result; - fn batch_baseexp, S: AsRef<[Self::Fr]>>(&self, table: &WindowTable>, scalars: S) -> Vec; + fn batch_baseexp, S: AsRef<[Self::Fr]>>(&self, table: &wnaf::WindowTable, scalars: S) -> Vec; fn batchexp, S: AsRef<[Self::Fr]>>(&self, g: &mut [G::Affine], scalars: S, coeff: Option<&Self::Fr>); } @@ -63,9 +66,10 @@ pub trait Curve: Sized + Sync + fmt::Debug + 'static + - Group + Group + + self::multiexp::Projective { - type Affine: CurveAffine; + type Affine: CurveAffine; type Prepared: Clone + Send + Sync + 'static; fn zero(&E) -> Self; @@ -83,28 +87,53 @@ pub trait Curve: Sized + fn add_assign(&mut self, &E, other: &Self); fn sub_assign(&mut self, &E, other: &Self); fn add_assign_mixed(&mut self, &E, other: &Self::Affine); - fn mul_assign>(&mut self, &E, other: &S); + fn mul_assign>::Repr, E>>(&mut self, &E, other: &S); fn optimal_window(&E, scalar_bits: usize) -> Option; - fn optimal_window_batch(&self, &E, scalars: usize) -> WindowTable>; + fn optimal_window_batch(&self, &E, scalars: usize) -> wnaf::WindowTable; + + /// Performs optimal exponentiation of this curve element given the scalar, using + /// wNAF when necessary. + fn optimal_exp( + &self, + e: &E, + scalar: >::Repr, + table: &mut wnaf::WindowTable, + scratch: &mut wnaf::WNAFTable + ) -> Self { + let bits = scalar.num_bits(); + match Self::optimal_window(e, bits) { + Some(window) => { + table.set_base(e, *self, window); + scratch.set_scalar(table, scalar); + table.exp(e, scratch) + }, + None => { + let mut tmp = *self; + tmp.mul_assign(e, &scalar); + tmp + } + } + } } -pub trait CurveAffine>: Copy + - Clone + - Sized + - Send + - Sync + - fmt::Debug + - PartialEq + - Eq + - 'static +pub trait CurveAffine: Copy + + Clone + + Sized + + Send + + Sync + + fmt::Debug + + PartialEq + + Eq + + 'static { - type Uncompressed: CurveRepresentation; + type Jacobian: Curve; + type Uncompressed: CurveRepresentation; - fn to_jacobian(&self, &E) -> G; - fn prepare(self, &E) -> G::Prepared; + fn to_jacobian(&self, &E) -> Self::Jacobian; + fn prepare(self, &E) -> >::Prepared; fn is_zero(&self) -> bool; - fn mul>(&self, &E, other: &S) -> G; + fn mul>::Repr, E>>(&self, &E, other: &S) -> Self::Jacobian; fn negate(&mut self, &E); /// Returns true iff the point is on the curve and in the correct @@ -117,11 +146,13 @@ pub trait CurveAffine>: Copy + fn to_uncompressed(&self, &E) -> Self::Uncompressed; } -pub trait CurveRepresentation>: Serialize + for<'a> Deserialize<'a> +pub trait CurveRepresentation: Serialize + for<'a> Deserialize<'a> { + type Affine: CurveAffine; + /// If the point representation is valid (lies on the curve, correct /// subgroup) this function will return it. - fn to_affine(&self, e: &E) -> Result { + fn to_affine(&self, e: &E) -> Result { let p = try!(self.to_affine_unchecked(e)); if p.is_valid(e) { @@ -133,7 +164,7 @@ pub trait CurveRepresentation>: Serialize + for<'a> Deser /// Returns the point under the assumption that it is valid. Undefined /// behavior if `to_affine` would have rejected the point. - fn to_affine_unchecked(&self, &E) -> Result; + fn to_affine_unchecked(&self, &E) -> Result; } pub trait Field: Sized + @@ -158,11 +189,11 @@ pub trait Field: Sized + fn mul_assign(&mut self, &E, other: &Self); fn inverse(&self, &E) -> Option; fn frobenius_map(&mut self, &E, power: usize); - fn pow>(&self, engine: &E, exp: &S) -> Self + fn pow>(&self, engine: &E, exp: S) -> Self { let mut res = Self::one(engine); - for i in BitIterator::from((*exp.convert(engine)).borrow()) { + for i in BitIterator::new(exp) { res.square(engine); if i { res.mul_assign(engine, self); @@ -180,18 +211,25 @@ pub trait SqrtField: Field fn sqrt(&self, engine: &E) -> Option; } -pub trait PrimeField: SqrtField + Convert<[u64], E> +pub trait PrimeFieldRepr: Clone + Eq + Ord + AsRef<[u64]> { + fn from_u64(a: u64) -> Self; + fn sub_noborrow(&mut self, other: &Self); + fn add_nocarry(&mut self, other: &Self); + fn num_bits(&self) -> usize; + fn is_zero(&self) -> bool; + fn is_odd(&self) -> bool; + fn div2(&mut self); +} + +pub trait PrimeField: SqrtField { - /// Little endian representation of a field element. - type Repr: Convert<[u64], E> + Eq + Clone; + type Repr: PrimeFieldRepr; + fn from_u64(&E, u64) -> Self; fn from_str(&E, s: &str) -> Result; fn from_repr(&E, Self::Repr) -> Result; fn into_repr(&self, &E) -> Self::Repr; - /// Returns an interator over all bits, most significant bit first. - fn bits(&self, &E) -> BitIterator; - /// Returns the field characteristic; the modulus. fn char(&E) -> Self::Repr; @@ -211,111 +249,6 @@ pub trait SnarkField: PrimeField + Group fn root_of_unity(&E) -> Self; } -pub struct WindowTable> { - table: Table, - wnaf: Vec, - window: usize, - _marker: PhantomData<(E, G)> -} - -impl> WindowTable> { - fn new() -> Self { - WindowTable { - table: vec![], - wnaf: vec![], - window: 0, - _marker: PhantomData - } - } - - fn set_base(&mut self, e: &E, base: &G, window: usize) { - assert!(window > 1); - - self.window = window; - self.table.truncate(0); - self.table.reserve(1 << (window-1)); - - let mut tmp = *base; - let mut dbl = tmp; - dbl.double(e); - - for _ in 0..(1 << (window-1)) { - self.table.push(tmp); - tmp.add_assign(e, &dbl); - } - } - - fn shared(&self) -> WindowTable { - WindowTable { - table: &self.table[..], - wnaf: vec![], - window: self.window, - _marker: PhantomData - } - } -} - -pub struct BitIterator { - t: T, - n: usize -} - -impl> Iterator for BitIterator { - type Item = bool; - - fn next(&mut self) -> Option { - if self.n == 0 { - None - } else { - self.n -= 1; - let part = self.n / 64; - let bit = self.n - (64 * part); - - Some(self.t.as_ref()[part] & (1 << bit) > 0) - } - } -} - -impl<'a> From<&'a [u64]> for BitIterator<&'a [u64]> -{ - fn from(v: &'a [u64]) -> Self { - assert!(v.len() < 100); - - BitIterator { - t: v, - n: v.len() * 64 - } - } -} - -macro_rules! bit_iter_impl( - ($n:expr) => { - impl From<[u64; $n]> for BitIterator<[u64; $n]> { - fn from(v: [u64; $n]) -> Self { - BitIterator { - t: v, - n: $n * 64 - } - } - } - - impl Convert<[u64], E> for [u64; $n] { - type Target = [u64; $n]; - - fn convert(&self, _: &E) -> Cow<[u64; $n]> { - Cow::Borrowed(self) - } - } - }; -); - -bit_iter_impl!(1); -bit_iter_impl!(2); -bit_iter_impl!(3); -bit_iter_impl!(4); -bit_iter_impl!(5); -bit_iter_impl!(6); - #[cfg(test)] mod tests; diff --git a/src/curves/multiexp.rs b/src/curves/multiexp.rs new file mode 100644 index 0000000..37b583b --- /dev/null +++ b/src/curves/multiexp.rs @@ -0,0 +1,232 @@ +//! This module provides an abstract implementation of the Bos-Coster multi-exponentiation algorithm. + +use super::{Engine, Curve, CurveAffine, Field, PrimeField, PrimeFieldRepr}; +use super::wnaf; +use std::cmp::Ordering; +use std::collections::BinaryHeap; + +pub trait Projective: Sized + Copy + Clone + Send { + type WindowTable; + + /// Constructs an identity element. + fn identity(e: &E) -> Self; + + /// Adds this projective element to another projective element. + fn add_to_projective(&self, e: &E, projective: &mut Self); + + /// Exponentiates by a scalar. + fn exponentiate( + &mut self, + e: &E, + scalar: >::Repr, + table: &mut Self::WindowTable, + scratch: &mut wnaf::WNAFTable + ); + + /// Construct a blank window table + fn new_window_table(e: &E) -> Self::WindowTable; +} + +pub trait Chunk: Send { + type Projective: Projective; + + /// Skips the next element from the source. + fn skip(&mut self, e: &E) -> Result<(), ()>; + + /// Adds the next element from the source to a projective element + fn add_to_projective(&mut self, e: &E, acc: &mut Self::Projective) -> Result<(), ()>; + + /// Turns the next element of the source into a projective element. + fn into_projective(&mut self, e: &E) -> Result; +} + +/// An `ElementSource` is something that contains a sequence of group elements or +/// group element tuples. +pub trait ElementSource { + type Chunk: Chunk; + + /// Gets the number of elements from the source. + fn num_elements(&self) -> usize; + + /// Returns a chunk size and a vector of chunks. + fn chunks(&mut self, chunks: usize) -> (usize, Vec); +} + +impl<'a, E: Engine, G: CurveAffine> ElementSource for &'a [G] { + type Chunk = &'a [G]; + + fn num_elements(&self) -> usize { + self.len() + } + + fn chunks(&mut self, chunks: usize) -> (usize, Vec) { + let chunk_size = (self.len() / chunks) + 1; + + (chunk_size, (*self).chunks(chunk_size).collect()) + } +} + +impl<'a, E: Engine, G: CurveAffine> Chunk for &'a [G] +{ + type Projective = G::Jacobian; + + fn skip(&mut self, _: &E) -> Result<(), ()> { + if self.len() == 0 { + Err(()) + } else { + *self = &self[1..]; + Ok(()) + } + } + + /// Adds the next element from the source to a projective element + fn add_to_projective(&mut self, e: &E, acc: &mut Self::Projective) -> Result<(), ()> { + if self.len() == 0 { + Err(()) + } else { + acc.add_assign_mixed(e, &self[0]); + *self = &self[1..]; + Ok(()) + } + } + + /// Turns the next element of the accumulator into a projective element. + fn into_projective(&mut self, e: &E) -> Result { + if self.len() == 0 { + Err(()) + } else { + let ret = Ok(self[0].to_jacobian(e)); + *self = &self[1..]; + ret + } + } +} + +fn justexp( + largest: &>::Repr, + smallest: &>::Repr +) -> bool +{ + use std::cmp::min; + + let abits = largest.num_bits(); + let bbits = smallest.num_bits(); + let limit = min(abits-bbits, 20); + + if bbits < (1<>( + e: &E, + mut bases: Source, + scalars: &[E::Fr] +) -> Result<>::Projective, ()> +{ + if bases.num_elements() != scalars.len() { + return Err(()) + } + + use crossbeam; + use num_cpus; + + let (chunk_len, bases) = bases.chunks(num_cpus::get()); + + return crossbeam::scope(|scope| { + let mut threads = vec![]; + + for (mut chunk, scalars) in bases.into_iter().zip(scalars.chunks(chunk_len)) { + threads.push(scope.spawn(move || { + let mut heap: BinaryHeap> = BinaryHeap::with_capacity(scalars.len()); + let mut elements = Vec::with_capacity(scalars.len()); + + let mut acc = Projective::::identity(e); + let one = E::Fr::one(e); + + for scalar in scalars { + if scalar.is_zero() { + // Skip processing bases when we're multiplying by a zero anyway. + chunk.skip(e)?; + } else if *scalar == one { + // Just perform mixed addition when we're multiplying by one. + chunk.add_to_projective(e, &mut acc)?; + } else { + elements.push(chunk.into_projective(e)?); + heap.push(Exp { + scalar: scalar.into_repr(e), + index: elements.len() - 1 + }); + } + } + + let mut window = <>::Projective as Projective>::new_window_table(e); + let mut scratch = wnaf::WNAFTable::new(); + + // Now that the heap is populated... + while let Some(mut greatest) = heap.pop() { + { + let second_greatest = heap.peek(); + if second_greatest.is_none() || justexp::(&greatest.scalar, &second_greatest.unwrap().scalar) { + // Either this is the last value or multiplying is considered more efficient than + // rewriting and reinsertion into the heap. + //opt_exp(engine, &mut elements[greatest.index], greatest.scalar, &mut table); + elements[greatest.index].exponentiate(e, greatest.scalar, &mut window, &mut scratch); + elements[greatest.index].add_to_projective(e, &mut acc); + continue; + } else { + // Rewrite + let second_greatest = second_greatest.unwrap(); + + greatest.scalar.sub_noborrow(&second_greatest.scalar); + let mut tmp = elements[second_greatest.index]; + elements[greatest.index].add_to_projective(e, &mut tmp); + elements[second_greatest.index] = tmp; + } + } + if !greatest.scalar.is_zero() { + // Reinsert only nonzero scalars. + heap.push(greatest); + } + } + + Ok(acc) + })); + } + + + let mut acc = Projective::::identity(e); + for t in threads { + t.join()?.add_to_projective(e, &mut acc); + } + + Ok(acc) + }) +} + +struct Exp { + scalar: >::Repr, + index: usize +} + +impl Ord for Exp { + fn cmp(&self, other: &Exp) -> Ordering { + self.scalar.cmp(&other.scalar) + } +} + +impl PartialOrd for Exp { + fn partial_cmp(&self, other: &Exp) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for Exp { + fn eq(&self, other: &Exp) -> bool { + self.scalar == other.scalar + } +} + +impl Eq for Exp { } diff --git a/src/curves/tests/mod.rs b/src/curves/tests/mod.rs index 2d8da83..c615526 100644 --- a/src/curves/tests/mod.rs +++ b/src/curves/tests/mod.rs @@ -106,7 +106,7 @@ fn test_bilinearity(e: &E) { let mut test4 = e.pairing(&a, &b); assert!(test4 != test1); - test4 = test4.pow(e, &s); + test4 = test4.pow(e, &s.into_repr(e)); assert_eq!(test1, test4); } diff --git a/src/curves/wnaf.rs b/src/curves/wnaf.rs new file mode 100644 index 0000000..c4f062e --- /dev/null +++ b/src/curves/wnaf.rs @@ -0,0 +1,110 @@ +use std::marker::PhantomData; +use super::{Engine, Curve, PrimeField, PrimeFieldRepr}; + +/// Represents the scratch space for a wNAF form scalar. +pub struct WNAFTable { + window: usize, + wnaf: Vec +} + +impl WNAFTable { + pub fn new() -> WNAFTable { + WNAFTable { + window: 0, + wnaf: vec![] + } + } + + /// Convert the scalar into wNAF form. + pub fn set_scalar>(&mut self, table: &WindowTable, mut c: >::Repr) { + self.window = table.window; + self.wnaf.truncate(0); + + while !c.is_zero() { + let mut u; + if c.is_odd() { + u = (c.as_ref()[0] % (1 << (self.window+1))) as i64; + + if u > (1 << self.window) { + u -= 1 << (self.window+1); + } + + if u > 0 { + c.sub_noborrow(&<>::Repr as PrimeFieldRepr>::from_u64(u as u64)); + } else { + c.add_nocarry(&<>::Repr as PrimeFieldRepr>::from_u64((-u) as u64)); + } + } else { + u = 0; + } + + self.wnaf.push(u); + + c.div2(); + } + } +} + +/// Represents a window table for a base curve point. +pub struct WindowTable>{ + window: usize, + table: Vec, + _marker: PhantomData +} + +impl> WindowTable { + /// Construct a new window table for a given base. + pub fn new(e: &E, base: G, window: usize) -> Self { + let mut tmp = WindowTable { + window: 0, + table: vec![], + _marker: PhantomData + }; + + tmp.set_base(e, base, window); + + tmp + } + + /// Replace this window table with a new one generated by a different base. + pub fn set_base(&mut self, e: &E, mut base: G, window: usize) { + assert!(window < 23); + assert!(window > 1); + + self.window = window; + self.table.truncate(0); + self.table.reserve(1 << (window-1)); + + let mut dbl = base; + dbl.double(e); + + for _ in 0..(1 << (window-1)) { + self.table.push(base); + base.add_assign(e, &dbl); + } + } + + pub fn exp(&self, e: &E, wnaf: &WNAFTable) -> G { + assert_eq!(wnaf.window, self.window); + + let mut result = G::zero(e); + + for n in wnaf.wnaf.iter().rev() { + result.double(e); + + if *n != 0 { + if *n > 0 { + result.add_assign(e, &self.table[(n/2) as usize]); + } else { + result.sub_assign(e, &self.table[((-n)/2) as usize]); + } + } + } + + result + } + + pub fn current_window(&self) -> usize { + self.window + } +} diff --git a/src/groth16/mod.rs b/src/groth16/mod.rs index d2f0a4e..f7902c9 100644 --- a/src/groth16/mod.rs +++ b/src/groth16/mod.rs @@ -1,8 +1,6 @@ use curves::*; use super::*; -pub mod domain; - pub struct ProvingKey { a_inputs: Vec<>::Affine>, b1_inputs: Vec<>::Affine>, @@ -307,55 +305,73 @@ pub fn prepare_verifying_key( } } -pub fn verify, F: FnOnce(&mut ConstraintSystem) -> C>( - e: &E, - circuit: F, - proof: &Proof, - pvk: &PreparedVerifyingKey -) -> bool -{ - struct VerifierInput<'a, E: Engine + 'a> { - e: &'a E, - acc: E::G1, - ic: &'a [>::Affine], - insufficient_inputs: bool, - num_inputs: usize, - num_aux: usize +pub struct VerifierInput<'a, E: Engine + 'a> { + e: &'a E, + acc: E::G1, + ic: &'a [>::Affine], + insufficient_inputs: bool, + num_inputs: usize, + num_aux: usize +} + +impl<'a, E: Engine> ConstraintSystem for VerifierInput<'a, E> { + fn alloc(&mut self, _: E::Fr) -> Variable { + let index = self.num_aux; + self.num_aux += 1; + + Variable(Index::Aux(index)) } - impl<'a, E: Engine> PublicConstraintSystem for VerifierInput<'a, E> { + fn enforce( + &mut self, + _: LinearCombination, + _: LinearCombination, + _: LinearCombination + ) + { + // Do nothing; we don't care about the constraint system + // in this context. + } +} + +pub fn verify<'a, E: Engine, C: Input, F: FnOnce(&mut VerifierInput<'a, E>) -> C>( + e: &'a E, + circuit: F, + proof: &Proof, + pvk: &'a PreparedVerifyingKey +) -> bool +{ + struct InputAllocator(T); + + impl<'a, 'b, E: Engine> PublicConstraintSystem for InputAllocator<&'b mut VerifierInput<'a, E>> { fn alloc_input(&mut self, value: E::Fr) -> Variable { - if self.ic.len() == 0 { - self.insufficient_inputs = true; + if self.0.ic.len() == 0 { + self.0.insufficient_inputs = true; } else { - self.acc.add_assign(self.e, &self.ic[0].mul(self.e, &value)); - self.ic = &self.ic[1..]; + self.0.acc.add_assign(self.0.e, &self.0.ic[0].mul(self.0.e, &value)); + self.0.ic = &self.0.ic[1..]; } - let index = self.num_inputs; - self.num_inputs += 1; + let index = self.0.num_inputs; + self.0.num_inputs += 1; Variable(Index::Input(index)) } } - impl<'a, E: Engine> ConstraintSystem for VerifierInput<'a, E> { - fn alloc(&mut self, _: E::Fr) -> Variable { - let index = self.num_aux; - self.num_aux += 1; - - Variable(Index::Aux(index)) + impl<'a, 'b, E: Engine> ConstraintSystem for InputAllocator<&'b mut VerifierInput<'a, E>> { + fn alloc(&mut self, num: E::Fr) -> Variable { + self.0.alloc(num) } fn enforce( &mut self, - _: LinearCombination, - _: LinearCombination, - _: LinearCombination + a: LinearCombination, + b: LinearCombination, + c: LinearCombination ) { - // Do nothing; we don't care about the constraint system - // in this context. + self.0.enforce(a, b, c); } } @@ -368,7 +384,7 @@ pub fn verify, F: FnOnce(&mut ConstraintSystem) -> C>( num_aux: 0 }; - circuit(&mut witness).synthesize(e, &mut witness); + circuit(&mut witness).synthesize(e, &mut InputAllocator(&mut witness)); if witness.ic.len() != 0 || witness.insufficient_inputs { return false; diff --git a/src/lib.rs b/src/lib.rs index 489c98f..52ccb71 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -181,3 +181,35 @@ impl Convert for T { Cow::Borrowed(self) } } + +pub struct BitIterator { + t: T, + n: usize +} + +impl> BitIterator { + fn new(t: T) -> Self { + let bits = 64 * t.as_ref().len(); + + BitIterator { + t: t, + n: bits + } + } +} + +impl> Iterator for BitIterator { + type Item = bool; + + fn next(&mut self) -> Option { + if self.n == 0 { + None + } else { + self.n -= 1; + let part = self.n / 64; + let bit = self.n - (64 * part); + + Some(self.t.as_ref()[part] & (1 << bit) > 0) + } + } +}