arithmetic::best_multiexp refactor buckets

This commit is contained in:
ashWhiteHat 2023-09-06 12:57:57 +09:00
parent e00f0d1233
commit 24e3ec3633
1 changed files with 49 additions and 54 deletions

View File

@ -56,24 +56,58 @@ impl<C: CurveAffine> Bucket<C> {
} }
} }
fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize { #[derive(Clone)]
let skip_bits = segment * c; struct Buckets<C: CurveAffine> {
let skip_bytes = skip_bits / 8; c: usize,
coeffs: Vec<Bucket<C>>,
}
if skip_bytes >= 32 { impl<C: CurveAffine> Buckets<C> {
return 0; fn new(c: usize) -> Self {
Self {
c,
coeffs: vec![Bucket::None; (1 << c) - 1],
}
} }
let mut v = [0; 8]; fn sum(&mut self, coeffs: &[C::Scalar], bases: &[C], i: usize) -> C::Curve {
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { // get segmentation and add coeff to buckets content
*v = *o; for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let seg = self.get_at::<C::Scalar>(i, &coeff.to_repr());
if seg != 0 {
self.coeffs[seg - 1].add_assign(base);
}
}
// Summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut acc = C::Curve::identity();
let mut sum = C::Curve::identity();
self.coeffs.iter().rev().for_each(|b| {
sum = b.add(sum);
acc += sum;
});
acc
} }
let mut tmp = u64::from_le_bytes(v); fn get_at<F: PrimeField>(&self, segment: usize, bytes: &F::Repr) -> usize {
tmp >>= skip_bits - (skip_bytes * 8); let skip_bits = segment * self.c;
tmp %= 1 << c; let skip_bytes = skip_bits / 8;
tmp as usize if skip_bytes >= 32 {
0
} else {
let mut v = [0; 8];
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
*v = *o;
}
let mut tmp = u64::from_le_bytes(v);
tmp >>= skip_bits - (skip_bytes * 8);
(tmp % (1 << self.c)) as usize
}
}
} }
/// Performs a small multi-exponentiation operation. /// Performs a small multi-exponentiation operation.
@ -116,9 +150,7 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
(f64::from(bases.len() as u32)).ln().ceil() as usize (f64::from(bases.len() as u32)).ln().ceil() as usize
}; };
let mut multi_buckets: Vec<Vec<Bucket<C>>> = let mut multi_buckets: Vec<Buckets<C>> = vec![Buckets::new(c); (256 / c) + 1];
vec![vec![Bucket::None; (1 << c) - 1]; (256 / c) + 1];
let num_threads = multicore::current_num_threads(); let num_threads = multicore::current_num_threads();
if coeffs.len() > num_threads { if coeffs.len() > num_threads {
multi_buckets multi_buckets
@ -126,24 +158,7 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
.enumerate() .enumerate()
.rev() .rev()
.map(|(i, buckets)| { .map(|(i, buckets)| {
// get segmentation and add coeff to buckets content let mut acc = buckets.sum(coeffs, bases, i);
for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let seg = get_at::<C::Scalar>(i, c, &coeff.to_repr());
if seg != 0 {
buckets[seg - 1].add_assign(base);
}
}
// Summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut acc = C::Curve::identity();
let mut sum = C::Curve::identity();
buckets.iter().rev().for_each(|b| {
sum = b.add(sum);
acc += sum;
});
(0..c * i).for_each(|_| acc = acc.double()); (0..c * i).for_each(|_| acc = acc.double());
acc acc
}) })
@ -153,27 +168,7 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
.iter_mut() .iter_mut()
.enumerate() .enumerate()
.rev() .rev()
.map(|(i, buckets)| { .map(|(i, buckets)| buckets.sum(coeffs, bases, i))
// get segmentation and add coeff to buckets content
for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let seg = get_at::<C::Scalar>(i, c, &coeff.to_repr());
if seg != 0 {
buckets[seg - 1].add_assign(base);
}
}
// Summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut acc = C::Curve::identity();
let mut sum = C::Curve::identity();
buckets.iter().rev().for_each(|b| {
sum = b.add(sum);
acc += sum;
});
acc
})
.fold(C::Curve::identity(), |mut sum, bucket| { .fold(C::Curve::identity(), |mut sum, bucket| {
// restore original evaluation point // restore original evaluation point
(0..c).for_each(|_| sum = sum.double()); (0..c).for_each(|_| sum = sum.double());