feat: Add tests for vartime multiscalar multiplication (#57)

* Move Pallas tests to their own file

* Add tests for multiplication on Pallas

* Add tests for multiplication on Jubjub

* Use `assert_eq` instead of `assert`

* Apply suggestions from code review

Co-authored-by: Conrado Gouvea <conradoplg@gmail.com>

* Refactor Pallas tests

* Refactor Jubjub tests

* Use `product` instead of `res`

---------

Co-authored-by: Conrado Gouvea <conradoplg@gmail.com>
This commit is contained in:
Marek 2023-04-22 01:24:11 +02:00 committed by GitHub
parent f8ad8ea992
commit cc558d4f79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 198 additions and 27 deletions

View File

@ -15,6 +15,9 @@ use crate::{private, SigType};
#[cfg(feature = "alloc")]
use crate::scalar_mul::{LookupTable5, NonAdjacentForm, VartimeMultiscalarMul};
#[cfg(test)]
mod tests;
/// The byte-encoding of the basepoint for the Orchard `SpendAuthSig` on the [Pallas curve][pallasandvesta].
///
/// [pallasandvesta]: https://zips.z.cash/protocol/nu5.pdf#pallasandvesta
@ -200,30 +203,3 @@ impl VartimeMultiscalarMul for pallas::Point {
Some(r)
}
}
#[cfg(test)]
mod tests {
#[test]
fn orchard_spendauth_basepoint() {
use super::ORCHARD_SPENDAUTHSIG_BASEPOINT_BYTES;
use group::GroupEncoding;
use pasta_curves::{arithmetic::CurveExt, pallas};
assert_eq!(
pallas::Point::hash_to_curve("z.cash:Orchard")(b"G").to_bytes(),
ORCHARD_SPENDAUTHSIG_BASEPOINT_BYTES
);
}
#[test]
fn orchard_binding_basepoint() {
use super::ORCHARD_BINDINGSIG_BASEPOINT_BYTES;
use group::GroupEncoding;
use pasta_curves::{arithmetic::CurveExt, pallas};
assert_eq!(
pallas::Point::hash_to_curve("z.cash:Orchard-cv")(b"r").to_bytes(),
ORCHARD_BINDINGSIG_BASEPOINT_BYTES
);
}
}

107
src/orchard/tests.rs Normal file
View File

@ -0,0 +1,107 @@
use crate::scalar_mul::VartimeMultiscalarMul;
use alloc::vec::Vec;
use group::{ff::PrimeField, GroupEncoding};
use pasta_curves::arithmetic::CurveExt;
use pasta_curves::pallas;
#[test]
fn orchard_spendauth_basepoint() {
use super::ORCHARD_SPENDAUTHSIG_BASEPOINT_BYTES;
assert_eq!(
pallas::Point::hash_to_curve("z.cash:Orchard")(b"G").to_bytes(),
ORCHARD_SPENDAUTHSIG_BASEPOINT_BYTES
);
}
#[test]
fn orchard_binding_basepoint() {
use super::ORCHARD_BINDINGSIG_BASEPOINT_BYTES;
assert_eq!(
pallas::Point::hash_to_curve("z.cash:Orchard-cv")(b"r").to_bytes(),
ORCHARD_BINDINGSIG_BASEPOINT_BYTES
);
}
/// Generates test vectors for [`test_pallas_vartime_multiscalar_mul`].
// #[test]
#[allow(dead_code)]
fn gen_pallas_test_vectors() {
use group::{ff::Field, Group};
use rand::thread_rng;
use std::println;
let rng = thread_rng();
let scalars = [
pallas::Scalar::random(rng.clone()),
pallas::Scalar::random(rng.clone()),
];
println!("Scalars:");
for scalar in scalars {
println!("{:?}", scalar.to_repr());
}
let points = [
pallas::Point::random(rng.clone()),
pallas::Point::random(rng),
];
println!("Points:");
for point in points {
println!("{:?}", point.to_bytes());
}
let res = pallas::Point::vartime_multiscalar_mul(scalars, points);
println!("Result:");
println!("{:?}", res.to_bytes());
}
/// Checks if the vartime multiscalar multiplication on Pallas produces the expected product.
/// The test vectors were generated by [`gen_pallas_test_vectors`].
#[test]
fn test_pallas_vartime_multiscalar_mul() {
let scalars: [[u8; 32]; 2] = [
[
235, 211, 155, 231, 188, 225, 161, 143, 148, 66, 177, 18, 246, 175, 177, 55, 1, 185,
115, 175, 208, 12, 252, 5, 168, 198, 26, 166, 129, 252, 158, 8,
],
[
1, 8, 55, 59, 168, 56, 248, 199, 77, 230, 228, 96, 35, 65, 191, 56, 137, 226, 161, 184,
105, 223, 98, 166, 248, 160, 156, 74, 18, 228, 122, 44,
],
];
let points: [[u8; 32]; 2] = [
[
81, 113, 73, 111, 90, 141, 91, 248, 252, 201, 109, 74, 99, 75, 11, 228, 152, 144, 254,
104, 240, 69, 211, 23, 201, 128, 236, 187, 233, 89, 59, 133,
],
[
177, 3, 100, 162, 246, 15, 81, 236, 51, 73, 69, 43, 45, 202, 226, 99, 27, 58, 133, 52,
231, 244, 125, 221, 88, 155, 192, 4, 164, 102, 34, 143,
],
];
let expected_product: [u8; 32] = [
68, 54, 98, 93, 238, 28, 229, 186, 127, 154, 101, 209, 216, 214, 66, 45, 141, 210, 70, 119,
100, 245, 164, 155, 213, 45, 126, 17, 199, 8, 84, 143,
];
let scalars: Vec<pallas::Scalar> = scalars
.into_iter()
.map(|s| {
pallas::Scalar::from_repr_vartime(s).expect("Could not deserialize a `pallas::Scalar`.")
})
.collect();
let points: Vec<pallas::Point> = points
.into_iter()
.map(|p| pallas::Point::from_bytes(&p).expect("Could not deserialize a `pallas::Point`."))
.collect();
let expected_product = pallas::Point::from_bytes(&expected_product)
.expect("Could not deserialize a `pallas::Point`.");
let product = pallas::Point::vartime_multiscalar_mul(scalars, points);
assert_eq!(expected_product, product);
}

View File

@ -21,6 +21,9 @@ pub trait NonAdjacentForm {
fn non_adjacent_form(&self, w: usize) -> [i8; 256];
}
#[cfg(test)]
mod tests;
/// A trait for variable-time multiscalar multiplication without precomputation.
pub trait VartimeMultiscalarMul {
/// The type of scalar being multiplied, e.g., `jubjub::Scalar`.

85
src/scalar_mul/tests.rs Normal file
View File

@ -0,0 +1,85 @@
use alloc::vec::Vec;
use group::GroupEncoding;
use jubjub::{ExtendedPoint, Scalar};
use crate::scalar_mul::VartimeMultiscalarMul;
/// Generates test vectors for [`test_jubjub_vartime_multiscalar_mul`].
// #[test]
#[allow(dead_code)]
fn gen_jubjub_test_vectors() {
use group::{ff::Field, Group};
use rand::thread_rng;
use std::println;
let rng = thread_rng();
let scalars = [Scalar::random(rng.clone()), Scalar::random(rng.clone())];
println!("Scalars:");
for scalar in scalars {
println!("{:?}", scalar.to_bytes());
}
let points = [
ExtendedPoint::random(rng.clone()),
ExtendedPoint::random(rng),
];
println!("Points:");
for point in points {
println!("{:?}", point.to_bytes());
}
let res = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
println!("Result:");
println!("{:?}", res.to_bytes());
}
/// Checks if the vartime multiscalar multiplication on Jubjub produces the expected product.
/// The test vectors were generated by [`gen_jubjub_test_vectors`].
#[test]
fn test_jubjub_vartime_multiscalar_mul() {
let scalars: [[u8; 32]; 2] = [
[
147, 209, 135, 83, 133, 175, 29, 28, 22, 161, 0, 220, 100, 218, 103, 47, 134, 242, 49,
19, 254, 204, 107, 185, 189, 155, 33, 110, 100, 141, 59, 0,
],
[
138, 136, 196, 249, 144, 2, 9, 103, 233, 93, 253, 46, 181, 12, 41, 158, 62, 201, 35,
198, 108, 139, 136, 78, 210, 12, 1, 223, 231, 22, 92, 13,
],
];
let points: [[u8; 32]; 2] = [
[
93, 252, 67, 45, 63, 170, 103, 247, 53, 37, 164, 250, 32, 210, 38, 71, 162, 68, 205,
176, 116, 46, 209, 66, 131, 209, 107, 193, 210, 153, 222, 31,
],
[
139, 112, 204, 231, 187, 141, 159, 122, 210, 164, 7, 162, 185, 171, 47, 199, 5, 33, 80,
207, 129, 24, 165, 90, 204, 253, 38, 27, 55, 86, 225, 52,
],
];
let expected_product: [u8; 32] = [
64, 228, 212, 168, 76, 90, 248, 218, 86, 22, 182, 130, 227, 52, 170, 88, 220, 193, 166,
131, 180, 48, 148, 72, 212, 148, 212, 240, 77, 244, 91, 213,
];
let scalars: Vec<Scalar> = scalars
.into_iter()
.map(|s| Scalar::from_bytes(&s).expect("Could not deserialize a `jubjub::Scalar`."))
.collect();
let points: Vec<ExtendedPoint> = points
.into_iter()
.map(|p| {
ExtendedPoint::from_bytes(&p).expect("Could not deserialize a `jubjub::ExtendedPoint`.")
})
.collect();
let expected_product = ExtendedPoint::from_bytes(&expected_product)
.expect("Could not deserialize a `jubjub::ExtendedPoint`.");
let product = ExtendedPoint::vartime_multiscalar_mul(scalars, points);
assert_eq!(expected_product, product);
}