change: Refactor & optimize the NAF (#63)
* Make the NAF function generic * Use the `jubjub` prefix for Jubjub types in tests * Add tests for the NAF for Jubjub & Pallas scalars * Use Rust's TryInto for [u8; 32] Co-authored-by: Conrado Gouvea <conradoplg@gmail.com> * Simplify the scalar conversion * Revert "Simplify the scalar conversion" This reverts commitf50ff9dd8a
. * Revert "Use Rust's TryInto for [u8; 32]" This reverts commit282c3b16ac
. --------- Co-authored-by: Deirdre Connolly <deirdre@zfnd.org> Co-authored-by: Conrado Gouvea <conradoplg@gmail.com>
This commit is contained in:
parent
c31c5c4a4f
commit
4f8ce48cd5
|
@ -51,6 +51,8 @@ rand = "0.8"
|
|||
rand_chacha = "0.3"
|
||||
serde_json = "1.0"
|
||||
frost-rerandomized = { version = "0.2", features=["test-impl"] }
|
||||
num-bigint = "0.4.3"
|
||||
num-traits = "0.2.15"
|
||||
|
||||
# `alloc` is only used in test code
|
||||
[dev-dependencies.pasta_curves]
|
||||
|
|
|
@ -87,63 +87,14 @@ impl private::Sealed<Binding> for Binding {
|
|||
|
||||
#[cfg(feature = "alloc")]
|
||||
impl NonAdjacentForm for pallas::Scalar {
|
||||
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
|
||||
///
|
||||
/// Thanks to curve25519-dalek
|
||||
fn non_adjacent_form(&self, w: usize) -> [i8; 256] {
|
||||
// required by the NAF definition
|
||||
debug_assert!(w >= 2);
|
||||
// required so that the NAF digits fit in i8
|
||||
debug_assert!(w <= 8);
|
||||
fn inner_to_bytes(&self) -> [u8; 32] {
|
||||
self.to_repr()
|
||||
}
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
let mut naf = [0i8; 256];
|
||||
|
||||
let mut x_u64 = [0u64; 5];
|
||||
LittleEndian::read_u64_into(self.to_repr().as_ref(), &mut x_u64[0..4]);
|
||||
|
||||
let width = 1 << w;
|
||||
let window_mask = width - 1;
|
||||
|
||||
let mut pos = 0;
|
||||
let mut carry = 0;
|
||||
while pos < 256 {
|
||||
// Construct a buffer of bits of the scalar, starting at bit `pos`
|
||||
let u64_idx = pos / 64;
|
||||
let bit_idx = pos % 64;
|
||||
let bit_buf = if bit_idx < 64 - w {
|
||||
// This window's bits are contained in a single u64
|
||||
x_u64[u64_idx] >> bit_idx
|
||||
} else {
|
||||
// Combine the current u64's bits with the bits from the next u64
|
||||
(x_u64[u64_idx] >> bit_idx) | (x_u64[1 + u64_idx] << (64 - bit_idx))
|
||||
};
|
||||
|
||||
// Add the carry into the current window
|
||||
let window = carry + (bit_buf & window_mask);
|
||||
|
||||
if window & 1 == 0 {
|
||||
// If the window value is even, preserve the carry and continue.
|
||||
// Why is the carry preserved?
|
||||
// If carry == 0 and window & 1 == 0, then the next carry should be 0
|
||||
// If carry == 1 and window & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1
|
||||
pos += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if window < width / 2 {
|
||||
carry = 0;
|
||||
naf[pos] = window as i8;
|
||||
} else {
|
||||
carry = 1;
|
||||
naf[pos] = (window as i8).wrapping_sub(width as i8);
|
||||
}
|
||||
|
||||
pos += w;
|
||||
}
|
||||
|
||||
naf
|
||||
/// The NAF length for Pallas is 255 since Pallas' order is about 2<sup>254</sup> +
|
||||
/// 2<sup>125.1</sup>.
|
||||
fn naf_length() -> usize {
|
||||
255
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -184,14 +135,15 @@ impl VartimeMultiscalarMul for pallas::Point {
|
|||
.collect::<Option<Vec<_>>>()?;
|
||||
|
||||
let mut r = pallas::Point::identity();
|
||||
let naf_size = Self::Scalar::naf_length();
|
||||
|
||||
for i in (0..256).rev() {
|
||||
for i in (0..naf_size).rev() {
|
||||
let mut t = r.double();
|
||||
|
||||
for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
|
||||
#[allow(clippy::comparison_chain)]
|
||||
if naf[i] > 0 {
|
||||
t += lookup_table.select(naf[i] as usize)
|
||||
t += lookup_table.select(naf[i] as usize);
|
||||
} else if naf[i] < 0 {
|
||||
t -= lookup_table.select(-naf[i] as usize);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use crate::scalar_mul::VartimeMultiscalarMul;
|
||||
use crate::scalar_mul::{self, VartimeMultiscalarMul};
|
||||
use alloc::vec::Vec;
|
||||
use group::ff::Field;
|
||||
use group::{ff::PrimeField, GroupEncoding};
|
||||
use rand::thread_rng;
|
||||
|
||||
use pasta_curves::arithmetic::CurveExt;
|
||||
use pasta_curves::pallas;
|
||||
|
@ -27,8 +29,7 @@ fn orchard_binding_basepoint() {
|
|||
// #[test]
|
||||
#[allow(dead_code)]
|
||||
fn gen_pallas_test_vectors() {
|
||||
use group::{ff::Field, Group};
|
||||
use rand::thread_rng;
|
||||
use group::Group;
|
||||
use std::println;
|
||||
|
||||
let rng = thread_rng();
|
||||
|
@ -105,3 +106,12 @@ fn test_pallas_vartime_multiscalar_mul() {
|
|||
let product = pallas::Point::vartime_multiscalar_mul(scalars, points);
|
||||
assert_eq!(expected_product, product);
|
||||
}
|
||||
|
||||
/// Tests the non-adjacent form for a Pallas scalar.
|
||||
#[test]
|
||||
fn test_non_adjacent_form() {
|
||||
let rng = thread_rng();
|
||||
|
||||
let scalar = pallas::Scalar::random(rng);
|
||||
scalar_mul::tests::test_non_adjacent_form_for_scalar(5, scalar);
|
||||
}
|
||||
|
|
|
@ -17,12 +17,8 @@ use core::{borrow::Borrow, fmt::Debug};
|
|||
|
||||
use jubjub::{ExtendedNielsPoint, ExtendedPoint};
|
||||
|
||||
pub trait NonAdjacentForm {
|
||||
fn non_adjacent_form(&self, w: usize) -> [i8; 256];
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
pub(crate) mod tests;
|
||||
|
||||
/// A trait for variable-time multiscalar multiplication without precomputation.
|
||||
pub trait VartimeMultiscalarMul {
|
||||
|
@ -67,13 +63,33 @@ pub trait VartimeMultiscalarMul {
|
|||
}
|
||||
}
|
||||
|
||||
impl NonAdjacentForm for jubjub::Scalar {
|
||||
/// Compute a width-\\(w\\) "Non-Adjacent Form" of this scalar.
|
||||
/// Produces the non-adjacent form (NAF) of a 32-byte scalar.
|
||||
pub trait NonAdjacentForm {
|
||||
/// Returns the scalar represented as a little-endian byte array.
|
||||
fn inner_to_bytes(&self) -> [u8; 32];
|
||||
|
||||
/// Returns the number of coefficients in the NAF.
|
||||
///
|
||||
/// Claim: The length of the NAF requires at most one more coefficient than the length of the
|
||||
/// binary representation of the scalar. [^1]
|
||||
///
|
||||
/// This trait works with scalars of at most 256 binary bits, so the default implementation
|
||||
/// returns 257. However, some (sub)groups' orders don't reach 256 bits and their scalars don't
|
||||
/// need the full 256 bits. Setting the corresponding NAF length for a particular curve will
|
||||
/// speed up the multiscalar multiplication since the number of loop iterations required for the
|
||||
/// multiplication is equal to the length of the NAF.
|
||||
///
|
||||
/// [^1]: The proof is left as an exercise to the reader.
|
||||
fn naf_length() -> usize {
|
||||
257
|
||||
}
|
||||
|
||||
/// Computes the width-`w` non-adjacent form (width-`w` NAF) of the scalar.
|
||||
///
|
||||
/// Thanks to [`curve25519-dalek`].
|
||||
///
|
||||
/// [`curve25519-dalek`]: https://github.com/dalek-cryptography/curve25519-dalek/blob/3e189820da03cc034f5fa143fc7b2ccb21fffa5e/src/scalar.rs#L907
|
||||
fn non_adjacent_form(&self, w: usize) -> [i8; 256] {
|
||||
fn non_adjacent_form(&self, w: usize) -> Vec<i8> {
|
||||
// required by the NAF definition
|
||||
debug_assert!(w >= 2);
|
||||
// required so that the NAF digits fit in i8
|
||||
|
@ -81,17 +97,19 @@ impl NonAdjacentForm for jubjub::Scalar {
|
|||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
let mut naf = [0i8; 256];
|
||||
let naf_length = Self::naf_length();
|
||||
let mut naf = vec![0; naf_length];
|
||||
|
||||
let mut x_u64 = [0u64; 5];
|
||||
LittleEndian::read_u64_into(&self.to_bytes(), &mut x_u64[0..4]);
|
||||
LittleEndian::read_u64_into(&self.inner_to_bytes(), &mut x_u64[0..4]);
|
||||
|
||||
let width = 1 << w;
|
||||
let window_mask = width - 1;
|
||||
|
||||
let mut pos = 0;
|
||||
let mut carry = 0;
|
||||
while pos < 256 {
|
||||
|
||||
while pos < naf_length {
|
||||
// Construct a buffer of bits of the scalar, starting at bit `pos`
|
||||
let u64_idx = pos / 64;
|
||||
let bit_idx = pos % 64;
|
||||
|
@ -130,6 +148,17 @@ impl NonAdjacentForm for jubjub::Scalar {
|
|||
}
|
||||
}
|
||||
|
||||
impl NonAdjacentForm for jubjub::Scalar {
|
||||
fn inner_to_bytes(&self) -> [u8; 32] {
|
||||
self.to_bytes()
|
||||
}
|
||||
|
||||
/// The NAF length for Jubjub is 253 since Jubjub's order is about 2<sup>251.85</sup>.
|
||||
fn naf_length() -> usize {
|
||||
253
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds odd multiples 1A, 3A, ..., 15A of a point A.
|
||||
#[derive(Copy, Clone)]
|
||||
pub(crate) struct LookupTable5<T>(pub(crate) [T; 8]);
|
||||
|
@ -195,8 +224,9 @@ impl VartimeMultiscalarMul for ExtendedPoint {
|
|||
.collect::<Option<Vec<_>>>()?;
|
||||
|
||||
let mut r = ExtendedPoint::identity();
|
||||
let naf_size = Self::Scalar::naf_length();
|
||||
|
||||
for i in (0..256).rev() {
|
||||
for i in (0..naf_size).rev() {
|
||||
let mut t = r.double();
|
||||
|
||||
for (naf, lookup_table) in nafs.iter().zip(lookup_tables.iter()) {
|
||||
|
|
|
@ -1,35 +1,41 @@
|
|||
use alloc::vec::Vec;
|
||||
use group::GroupEncoding;
|
||||
use jubjub::{ExtendedPoint, Scalar};
|
||||
use group::{ff::Field, GroupEncoding};
|
||||
use num_bigint::BigInt;
|
||||
use num_traits::Zero;
|
||||
use rand::thread_rng;
|
||||
|
||||
use crate::scalar_mul::VartimeMultiscalarMul;
|
||||
|
||||
use super::NonAdjacentForm;
|
||||
|
||||
/// Generates test vectors for [`test_jubjub_vartime_multiscalar_mul`].
|
||||
// #[test]
|
||||
#[allow(dead_code)]
|
||||
fn gen_jubjub_test_vectors() {
|
||||
use group::{ff::Field, Group};
|
||||
use rand::thread_rng;
|
||||
use group::Group;
|
||||
use std::println;
|
||||
|
||||
let rng = thread_rng();
|
||||
|
||||
let scalars = [Scalar::random(rng.clone()), Scalar::random(rng.clone())];
|
||||
let scalars = [
|
||||
jubjub::Scalar::random(rng.clone()),
|
||||
jubjub::Scalar::random(rng.clone()),
|
||||
];
|
||||
println!("Scalars:");
|
||||
for scalar in scalars {
|
||||
println!("{:?}", scalar.to_bytes());
|
||||
}
|
||||
|
||||
let points = [
|
||||
ExtendedPoint::random(rng.clone()),
|
||||
ExtendedPoint::random(rng),
|
||||
jubjub::ExtendedPoint::random(rng.clone()),
|
||||
jubjub::ExtendedPoint::random(rng),
|
||||
];
|
||||
println!("Points:");
|
||||
for point in points {
|
||||
println!("{:?}", point.to_bytes());
|
||||
}
|
||||
|
||||
let res = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
|
||||
let res = jubjub::ExtendedPoint::vartime_multiscalar_mul(scalars, points);
|
||||
println!("Result:");
|
||||
println!("{:?}", res.to_bytes());
|
||||
}
|
||||
|
@ -65,21 +71,94 @@ fn test_jubjub_vartime_multiscalar_mul() {
|
|||
131, 180, 48, 148, 72, 212, 148, 212, 240, 77, 244, 91, 213,
|
||||
];
|
||||
|
||||
let scalars: Vec<Scalar> = scalars
|
||||
let scalars: Vec<jubjub::Scalar> = scalars
|
||||
.into_iter()
|
||||
.map(|s| Scalar::from_bytes(&s).expect("Could not deserialize a `jubjub::Scalar`."))
|
||||
.map(|s| jubjub::Scalar::from_bytes(&s).expect("Could not deserialize a `jubjub::Scalar`."))
|
||||
.collect();
|
||||
|
||||
let points: Vec<ExtendedPoint> = points
|
||||
let points: Vec<jubjub::ExtendedPoint> = points
|
||||
.into_iter()
|
||||
.map(|p| {
|
||||
ExtendedPoint::from_bytes(&p).expect("Could not deserialize a `jubjub::ExtendedPoint`.")
|
||||
jubjub::ExtendedPoint::from_bytes(&p)
|
||||
.expect("Could not deserialize a `jubjub::ExtendedPoint`.")
|
||||
})
|
||||
.collect();
|
||||
|
||||
let expected_product = ExtendedPoint::from_bytes(&expected_product)
|
||||
let expected_product = jubjub::ExtendedPoint::from_bytes(&expected_product)
|
||||
.expect("Could not deserialize a `jubjub::ExtendedPoint`.");
|
||||
|
||||
let product = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
|
||||
let product = jubjub::ExtendedPoint::vartime_multiscalar_mul(scalars, points);
|
||||
assert_eq!(expected_product, product);
|
||||
}
|
||||
|
||||
/// Tests the non-adjacent form for a Jubjub scalar.
|
||||
#[test]
|
||||
fn test_non_adjacent_form() {
|
||||
let rng = thread_rng();
|
||||
|
||||
let scalar = jubjub::Scalar::random(rng);
|
||||
test_non_adjacent_form_for_scalar(5, scalar);
|
||||
}
|
||||
|
||||
/// Tests the non-adjacent form for a particular scalar.
|
||||
pub(crate) fn test_non_adjacent_form_for_scalar<Scalar: NonAdjacentForm>(w: usize, scalar: Scalar) {
|
||||
let naf = scalar.non_adjacent_form(w);
|
||||
let naf_length = Scalar::naf_length();
|
||||
|
||||
// Check that the computed w-NAF has the intended length.
|
||||
assert_eq!(naf.len(), naf_length);
|
||||
|
||||
let w = u32::try_from(w).expect("The window `w` did not fit into `u32`.");
|
||||
|
||||
// `bound` <- 2^(w-1)
|
||||
let bound = 2_i32.pow(w - 1);
|
||||
|
||||
// `valid_coeffs` <- a range of odd integers from -2^(w-1) to 2^(w-1)
|
||||
let valid_coeffs: Vec<i32> = (-bound..bound).filter(|x| x.rem_euclid(2) == 1).collect();
|
||||
|
||||
let mut reconstructed_scalar: BigInt = Zero::zero();
|
||||
|
||||
// Reconstruct the original scalar, and check two general invariants for any w-NAF along the
|
||||
// way.
|
||||
let mut i = 0;
|
||||
while i < naf_length {
|
||||
if naf[i] != 0 {
|
||||
// In a w-NAF, every nonzero coefficient `naf[i]` is an odd signed integer with
|
||||
// -2^(w-1) < `naf[i]` < 2^(w-1).
|
||||
assert!(valid_coeffs.contains(&i32::from(naf[i])));
|
||||
|
||||
// Incrementally keep reconstructing the original scalar.
|
||||
reconstructed_scalar += naf[i] * BigInt::from(2).pow(i.try_into().unwrap());
|
||||
|
||||
// In a w-NAF, at most one of any `w` consecutive coefficients is nonzero.
|
||||
for _ in 1..w {
|
||||
i += 1;
|
||||
if i >= naf_length {
|
||||
break;
|
||||
}
|
||||
assert_eq!(naf[i], 0)
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// Check that the reconstructed scalar is not negative, and convert it to little-endian bytes.
|
||||
let reconstructed_scalar = reconstructed_scalar
|
||||
.to_biguint()
|
||||
.expect("The reconstructed scalar is negative.")
|
||||
.to_bytes_le();
|
||||
|
||||
// Check that the reconstructed scalar is not too big.
|
||||
assert!(reconstructed_scalar.len() <= 32);
|
||||
|
||||
// Convert the reconstructed scalar to a fixed byte array so we can compare it with the orginal
|
||||
// scalar.
|
||||
let mut reconstructed_scalar_bytes: [u8; 32] = [0; 32];
|
||||
for (i, byte) in reconstructed_scalar.iter().enumerate() {
|
||||
reconstructed_scalar_bytes[i] = *byte;
|
||||
}
|
||||
|
||||
// Check that the reconstructed scalar matches the original one.
|
||||
assert_eq!(reconstructed_scalar_bytes, scalar.inner_to_bytes());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue