Make multiscalar multiplication generic for any scalar field size (#347)

* Make multiscalar multiplication generic for any scalar field size, within some limits

* Passes ed448 tests

* remove extreme comparison

* Typo

* Typo

* small optimizations

---------

Co-authored-by: Conrado Gouvea <conradoplg@gmail.com>
This commit is contained in:
Deirdre Connolly 2023-05-16 19:48:51 -04:00 committed by GitHub
parent a8275e12dd
commit 53a30278b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 62 additions and 12 deletions

View File

@ -8,9 +8,30 @@ use std::{
use crate::{Ciphersuite, Element, Field, Group, Scalar};
/// Calculates the quotient of `self` and `rhs`, rounding the result towards positive infinity.
///
/// # Panics
///
/// This function will panic if `rhs` is 0 or the division results in overflow.
///
/// This function is similar to `div_ceil` that is [available on
/// Nightly](https://github.com/rust-lang/rust/issues/88581).
///
// TODO: remove this function and use `div_ceil()` instead when `int_roundings`
// is stabilized.
const fn div_ceil(lhs: usize, rhs: usize) -> usize {
let d = lhs / rhs;
let r = lhs % rhs;
if r > 0 && rhs > 0 {
d + 1
} else {
d
}
}
/// A trait for transforming a scalar generic over a ciphersuite to a non-adjacent form (NAF).
pub trait NonAdjacentForm<C: Ciphersuite> {
fn non_adjacent_form(&self, w: usize) -> [i8; 257];
fn non_adjacent_form(&self, w: usize) -> Vec<i8>;
}
impl<C> NonAdjacentForm<C> for Scalar<C>
@ -24,7 +45,7 @@ where
/// # Safety
///
/// The full scalar field MUST fit in 256 bits in this implementation.
fn non_adjacent_form(&self, w: usize) -> [i8; 257] {
fn non_adjacent_form(&self, w: usize) -> Vec<i8> {
// required by the NAF definition
debug_assert!(w >= 2);
// required so that the NAF digits fit in i8
@ -32,22 +53,45 @@ where
use byteorder::{ByteOrder, LittleEndian};
// Safety: assumes a scalar that fits in 256 bits.
// The length of the NAF is at most one more than the bit length.
let mut naf = [0i8; 257];
let serialized_scalar = <<C::Group as Group>::Field>::little_endian_serialize(self);
// The canonical serialization length of this `Scalar` in bytes.
let serialization_len = serialized_scalar.as_ref().len();
let mut x_u64 = [0u64; 5];
LittleEndian::read_u64_into(
<<C::Group as Group>::Field>::little_endian_serialize(self).as_ref(),
&mut x_u64[0..4],
);
// Compute the size of the non-adjacent form from the number of bytes needed to serialize
// `Scalar`s, plus 1 bit.
//
// The length of the NAF is at most one more than the bit length.
let naf_length: usize = serialization_len * u8::BITS as usize + 1;
// Safety:
//
// The max value of `naf_length` (the number of bits to represent the
// scalar plus 1) _should_ have plenty of room in systems where usize is
// greater than 8 bits (aka, not a u8). If you are able to compile this
// code on a system with 8-bit pointers, well done, but this code will
// probably not compute the right thing for you, use a 16-bit or above
// system. Since the rest of this code uses u64's for limbs, we
// recommend a 64-bit system.
let mut naf = vec![0; naf_length];
// Get the number of 64-bit limbs we need.
let num_limbs: usize = div_ceil(naf_length, u64::BITS as usize);
let mut x_u64 = vec![0u64; num_limbs];
// This length needs to be 8*destination.len(), so pad out to length num_limbs * 8.
let mut padded_le_serialized = vec![0u8; num_limbs * 8];
padded_le_serialized[..serialization_len].copy_from_slice(serialized_scalar.as_ref());
LittleEndian::read_u64_into(padded_le_serialized.as_ref(), &mut x_u64[0..num_limbs]);
let width = 1 << w;
let window_mask = width - 1;
let mut pos = 0;
let mut carry = 0;
while pos < 257 {
while pos < naf_length {
// Construct a buffer of bits of the scalar, starting at bit `pos`
let u64_idx = pos / 64;
let bit_idx = pos % 64;
@ -149,7 +193,13 @@ where
let mut r = <C::Group>::identity();
for i in (0..257).rev() {
// All NAFs will have the same size, so get it from the first
if nafs.is_empty() {
return Some(r);
}
let naf_length = nafs[0].len();
for i in (0..naf_length).rev() {
let mut t = r + r;
for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {