change: Refactor & optimize the NAF (#63)

* Make the NAF function generic

* Use the `jubjub` prefix for Jubjub types in tests

* Add tests for the NAF for Jubjub & Pallas scalars

* Use Rust's TryInto for [u8; 32]

Co-authored-by: Conrado Gouvea <conradoplg@gmail.com>

* Simplify the scalar conversion

* Revert "Simplify the scalar conversion"

This reverts commit f50ff9dd8a.

* Revert "Use Rust's TryInto for [u8; 32]"

This reverts commit 282c3b16ac.

---------

Co-authored-by: Deirdre Connolly <deirdre@zfnd.org>
Co-authored-by: Conrado Gouvea <conradoplg@gmail.com>
This commit is contained in:
Marek 2023-04-25 19:51:13 +02:00 committed by GitHub
parent c31c5c4a4f
commit 4f8ce48cd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 160 additions and 87 deletions

View File

@ -51,6 +51,8 @@ rand = "0.8"
rand_chacha = "0.3"
serde_json = "1.0"
frost-rerandomized = { version = "0.2", features=["test-impl"] }
num-bigint = "0.4.3"
num-traits = "0.2.15"
# `alloc` is only used in test code
[dev-dependencies.pasta_curves]

View File

@ -87,63 +87,14 @@ impl private::Sealed<Binding> for Binding {
#[cfg(feature = "alloc")]
impl NonAdjacentForm for pallas::Scalar {
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
///
/// Thanks to curve25519-dalek
fn non_adjacent_form(&self, w: usize) -> [i8; 256] {
// required by the NAF definition
debug_assert!(w >= 2);
// required so that the NAF digits fit in i8
debug_assert!(w <= 8);
fn inner_to_bytes(&self) -> [u8; 32] {
self.to_repr()
}
use byteorder::{ByteOrder, LittleEndian};
let mut naf = [0i8; 256];
let mut x_u64 = [0u64; 5];
LittleEndian::read_u64_into(self.to_repr().as_ref(), &mut x_u64[0..4]);
let width = 1 << w;
let window_mask = width - 1;
let mut pos = 0;
let mut carry = 0;
while pos < 256 {
// Construct a buffer of bits of the scalar, starting at bit `pos`
let u64_idx = pos / 64;
let bit_idx = pos % 64;
let bit_buf = if bit_idx < 64 - w {
// This window's bits are contained in a single u64
x_u64[u64_idx] >> bit_idx
} else {
// Combine the current u64's bits with the bits from the next u64
(x_u64[u64_idx] >> bit_idx) | (x_u64[1 + u64_idx] << (64 - bit_idx))
};
// Add the carry into the current window
let window = carry + (bit_buf & window_mask);
if window & 1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0, then the next carry should be 0
// If carry == 1 and window & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1
pos += 1;
continue;
}
if window < width / 2 {
carry = 0;
naf[pos] = window as i8;
} else {
carry = 1;
naf[pos] = (window as i8).wrapping_sub(width as i8);
}
pos += w;
}
naf
/// The NAF length for Pallas is 255 since Pallas' order is about 2<sup>254</sup> +
/// 2<sup>125.1</sup>.
fn naf_length() -> usize {
255
}
}
@ -184,14 +135,15 @@ impl VartimeMultiscalarMul for pallas::Point {
.collect::<Option<Vec<_>>>()?;
let mut r = pallas::Point::identity();
let naf_size = Self::Scalar::naf_length();
for i in (0..256).rev() {
for i in (0..naf_size).rev() {
let mut t = r.double();
for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
#[allow(clippy::comparison_chain)]
if naf[i] > 0 {
t += lookup_table.select(naf[i] as usize)
t += lookup_table.select(naf[i] as usize);
} else if naf[i] < 0 {
t -= lookup_table.select(-naf[i] as usize);
}

View File

@ -1,6 +1,8 @@
use crate::scalar_mul::VartimeMultiscalarMul;
use crate::scalar_mul::{self, VartimeMultiscalarMul};
use alloc::vec::Vec;
use group::ff::Field;
use group::{ff::PrimeField, GroupEncoding};
use rand::thread_rng;
use pasta_curves::arithmetic::CurveExt;
use pasta_curves::pallas;
@ -27,8 +29,7 @@ fn orchard_binding_basepoint() {
// #[test]
#[allow(dead_code)]
fn gen_pallas_test_vectors() {
use group::{ff::Field, Group};
use rand::thread_rng;
use group::Group;
use std::println;
let rng = thread_rng();
@ -105,3 +106,12 @@ fn test_pallas_vartime_multiscalar_mul() {
let product = pallas::Point::vartime_multiscalar_mul(scalars, points);
assert_eq!(expected_product, product);
}
/// Checks that the NAF of a randomly chosen Pallas scalar is well-formed.
#[test]
fn test_non_adjacent_form() {
    let random_scalar = pallas::Scalar::random(thread_rng());
    scalar_mul::tests::test_non_adjacent_form_for_scalar(5, random_scalar);
}

View File

@ -17,12 +17,8 @@ use core::{borrow::Borrow, fmt::Debug};
use jubjub::{ExtendedNielsPoint, ExtendedPoint};
pub trait NonAdjacentForm {
fn non_adjacent_form(&self, w: usize) -> [i8; 256];
}
#[cfg(test)]
mod tests;
pub(crate) mod tests;
/// A trait for variable-time multiscalar multiplication without precomputation.
pub trait VartimeMultiscalarMul {
@ -67,13 +63,33 @@ pub trait VartimeMultiscalarMul {
}
}
impl NonAdjacentForm for jubjub::Scalar {
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
/// Produces the non-adjacent form (NAF) of a 32-byte scalar.
pub trait NonAdjacentForm {
/// Returns the scalar represented as a little-endian byte array.
fn inner_to_bytes(&self) -> [u8; 32];
/// Returns the number of coefficients in the NAF.
///
/// Claim: The length of the NAF requires at most one more coefficient than the length of the
/// binary representation of the scalar. [^1]
///
/// This trait works with scalars of at most 256 binary bits, so the default implementation
/// returns 257. However, some (sub)groups' orders don't reach 256 bits and their scalars don't
/// need the full 256 bits. Setting the corresponding NAF length for a particular curve will
/// speed up the multiscalar multiplication since the number of loop iterations required for the
/// multiplication is equal to the length of the NAF.
///
/// [^1]: The proof is left as an exercise to the reader.
fn naf_length() -> usize {
257
}
/// Computes the width-`w` non-adjacent form (width-`w` NAF) of the scalar.
///
/// Thanks to [`curve25519-dalek`].
///
/// [`curve25519-dalek`]: https://github.com/dalek-cryptography/curve25519-dalek/blob/3e189820da03cc034f5fa143fc7b2ccb21fffa5e/src/scalar.rs#L907
fn non_adjacent_form(&self, w: usize) -> [i8; 256] {
fn non_adjacent_form(&self, w: usize) -> Vec<i8> {
// required by the NAF definition
debug_assert!(w >= 2);
// required so that the NAF digits fit in i8
@ -81,17 +97,19 @@ impl NonAdjacentForm for jubjub::Scalar {
use byteorder::{ByteOrder, LittleEndian};
let mut naf = [0i8; 256];
let naf_length = Self::naf_length();
let mut naf = vec![0; naf_length];
let mut x_u64 = [0u64; 5];
LittleEndian::read_u64_into(&self.to_bytes(), &mut x_u64[0..4]);
LittleEndian::read_u64_into(&self.inner_to_bytes(), &mut x_u64[0..4]);
let width = 1 << w;
let window_mask = width - 1;
let mut pos = 0;
let mut carry = 0;
while pos < 256 {
while pos < naf_length {
// Construct a buffer of bits of the scalar, starting at bit `pos`
let u64_idx = pos / 64;
let bit_idx = pos % 64;
@ -130,6 +148,17 @@ impl NonAdjacentForm for jubjub::Scalar {
}
}
impl NonAdjacentForm for jubjub::Scalar {
fn inner_to_bytes(&self) -> [u8; 32] {
self.to_bytes()
}
/// The NAF length for Jubjub is 253 since Jubjub's order is about 2<sup>251.85</sup>.
fn naf_length() -> usize {
253
}
}
/// Holds odd multiples 1A, 3A, ..., 15A of a point A.
#[derive(Copy, Clone)]
pub(crate) struct LookupTable5<T>(pub(crate) [T; 8]);
@ -195,8 +224,9 @@ impl VartimeMultiscalarMul for ExtendedPoint {
.collect::<Option<Vec<_>>>()?;
let mut r = ExtendedPoint::identity();
let naf_size = Self::Scalar::naf_length();
for i in (0..256).rev() {
for i in (0..naf_size).rev() {
let mut t = r.double();
for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {

View File

@ -1,35 +1,41 @@
use alloc::vec::Vec;
use group::GroupEncoding;
use jubjub::{ExtendedPoint, Scalar};
use group::{ff::Field, GroupEncoding};
use num_bigint::BigInt;
use num_traits::Zero;
use rand::thread_rng;
use crate::scalar_mul::VartimeMultiscalarMul;
use super::NonAdjacentForm;
/// Generates test vectors for [`test_jubjub_vartime_multiscalar_mul`].
// #[test]
#[allow(dead_code)]
fn gen_jubjub_test_vectors() {
use group::{ff::Field, Group};
use rand::thread_rng;
use group::Group;
use std::println;
let rng = thread_rng();
let scalars = [Scalar::random(rng.clone()), Scalar::random(rng.clone())];
let scalars = [
jubjub::Scalar::random(rng.clone()),
jubjub::Scalar::random(rng.clone()),
];
println!("Scalars:");
for scalar in scalars {
println!("{:?}", scalar.to_bytes());
}
let points = [
ExtendedPoint::random(rng.clone()),
ExtendedPoint::random(rng),
jubjub::ExtendedPoint::random(rng.clone()),
jubjub::ExtendedPoint::random(rng),
];
println!("Points:");
for point in points {
println!("{:?}", point.to_bytes());
}
let res = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
let res = jubjub::ExtendedPoint::vartime_multiscalar_mul(scalars, points);
println!("Result:");
println!("{:?}", res.to_bytes());
}
@ -65,21 +71,94 @@ fn test_jubjub_vartime_multiscalar_mul() {
131, 180, 48, 148, 72, 212, 148, 212, 240, 77, 244, 91, 213,
];
let scalars: Vec<Scalar> = scalars
let scalars: Vec<jubjub::Scalar> = scalars
.into_iter()
.map(|s| Scalar::from_bytes(&s).expect("Could not deserialize a `jubjub::Scalar`."))
.map(|s| jubjub::Scalar::from_bytes(&s).expect("Could not deserialize a `jubjub::Scalar`."))
.collect();
let points: Vec<ExtendedPoint> = points
let points: Vec<jubjub::ExtendedPoint> = points
.into_iter()
.map(|p| {
ExtendedPoint::from_bytes(&p).expect("Could not deserialize a `jubjub::ExtendedPoint`.")
jubjub::ExtendedPoint::from_bytes(&p)
.expect("Could not deserialize a `jubjub::ExtendedPoint`.")
})
.collect();
let expected_product = ExtendedPoint::from_bytes(&expected_product)
let expected_product = jubjub::ExtendedPoint::from_bytes(&expected_product)
.expect("Could not deserialize a `jubjub::ExtendedPoint`.");
let product = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
let product = jubjub::ExtendedPoint::vartime_multiscalar_mul(scalars, points);
assert_eq!(expected_product, product);
}
/// Checks that the NAF of a randomly chosen Jubjub scalar is well-formed.
#[test]
fn test_non_adjacent_form() {
    let random_scalar = jubjub::Scalar::random(thread_rng());
    test_non_adjacent_form_for_scalar(5, random_scalar);
}
/// Tests the non-adjacent form for a particular scalar.
///
/// Verifies the two defining invariants of a width-`w` NAF — every nonzero
/// coefficient is an odd signed integer in the open interval
/// (-2^(w-1), 2^(w-1)), and at most one of any `w` consecutive coefficients is
/// nonzero — and that evaluating the NAF as a base-2 polynomial reconstructs
/// the original scalar.
///
/// # Panics
///
/// Panics (via assertions) if the NAF has the wrong length, violates either
/// invariant, evaluates to a negative or >32-byte value, or does not match the
/// original scalar's little-endian bytes.
pub(crate) fn test_non_adjacent_form_for_scalar<Scalar: NonAdjacentForm>(w: usize, scalar: Scalar) {
    let naf = scalar.non_adjacent_form(w);
    let naf_length = Scalar::naf_length();

    // Check that the computed w-NAF has the intended length.
    assert_eq!(naf.len(), naf_length);

    let w = u32::try_from(w).expect("The window `w` did not fit into `u32`.");

    // `bound` <- 2^(w-1)
    let bound = 2_i32.pow(w - 1);

    // `valid_coeffs` <- a range of odd integers from -2^(w-1) to 2^(w-1)
    let valid_coeffs: Vec<i32> = (-bound..bound).filter(|x| x.rem_euclid(2) == 1).collect();

    let mut reconstructed_scalar: BigInt = Zero::zero();

    // Reconstruct the original scalar, and check two general invariants for any w-NAF along the
    // way.
    let mut i = 0;
    while i < naf_length {
        if naf[i] != 0 {
            // In a w-NAF, every nonzero coefficient `naf[i]` is an odd signed integer with
            // -2^(w-1) < `naf[i]` < 2^(w-1).
            assert!(valid_coeffs.contains(&i32::from(naf[i])));

            // Incrementally keep reconstructing the original scalar.
            reconstructed_scalar += naf[i] * BigInt::from(2).pow(i.try_into().unwrap());

            // In a w-NAF, at most one of any `w` consecutive coefficients is nonzero.
            for _ in 1..w {
                i += 1;
                if i >= naf_length {
                    break;
                }
                assert_eq!(naf[i], 0);
            }
        }
        i += 1;
    }

    // Check that the reconstructed scalar is not negative, and convert it to little-endian bytes.
    let reconstructed_scalar = reconstructed_scalar
        .to_biguint()
        .expect("The reconstructed scalar is negative.")
        .to_bytes_le();

    // Check that the reconstructed scalar is not too big.
    assert!(reconstructed_scalar.len() <= 32);

    // Zero-pad the little-endian bytes into a fixed 32-byte array so we can compare them with the
    // original scalar. `copy_from_slice` is safe here: the length was asserted to be <= 32 above.
    let mut reconstructed_scalar_bytes = [0u8; 32];
    reconstructed_scalar_bytes[..reconstructed_scalar.len()].copy_from_slice(&reconstructed_scalar);

    // Check that the reconstructed scalar matches the original one.
    assert_eq!(reconstructed_scalar_bytes, scalar.inner_to_bytes());
}