Move fixed bases utils into halo2_ecc crate.

This commit is contained in:
therealyingtong 2021-08-25 12:26:33 +08:00
parent f762ee342e
commit a997364545
13 changed files with 261 additions and 247 deletions

View File

@ -0,0 +1,125 @@
//! Utilities to compute associated constants for fixed bases.
use super::{FIXED_BASE_WINDOW_SIZE, H};
use arrayvec::ArrayVec;
use ff::Field;
use group::Curve;
use halo2::arithmetic::lagrange_interpolate;
use pasta_curves::arithmetic::{CurveAffine, FieldExt};
/// For each fixed base, we calculate its scalar multiples in three-bit windows.
/// Each window will have $2^3 = 8$ points.
pub fn compute_window_table<C: CurveAffine>(base: C, num_windows: usize) -> Vec<[C; H]> {
let mut window_table: Vec<[C; H]> = Vec::with_capacity(num_windows);
// Generate window table entries for all windows but the last.
// For these first `num_windows - 1` windows, we compute the multiple [(k+2)*(2^3)^w]B.
// Here, w ranges from [0..`num_windows - 1`)
for w in 0..(num_windows - 1) {
window_table.push(
(0..H)
.map(|k| {
// scalar = (k+2)*(8^w)
let scalar = C::ScalarExt::from_u64(k as u64 + 2)
* C::ScalarExt::from_u64(H as u64).pow(&[w as u64, 0, 0, 0]);
(base * scalar).to_affine()
})
.collect::<ArrayVec<C, H>>()
.into_inner()
.unwrap(),
);
}
// Generate window table entries for the last window, w = `num_windows - 1`.
// For the last window, we compute [k * (2^3)^w - sum]B, where sum is defined
// as sum = \sum_{j = 0}^{`num_windows - 2`} 2^{3j+1}
let sum = (0..(num_windows - 1)).fold(C::ScalarExt::zero(), |acc, j| {
acc + C::ScalarExt::from_u64(2).pow(&[
FIXED_BASE_WINDOW_SIZE as u64 * j as u64 + 1,
0,
0,
0,
])
});
window_table.push(
(0..H)
.map(|k| {
// scalar = k * (2^3)^w - sum, where w = `num_windows - 1`
let scalar = C::ScalarExt::from_u64(k as u64)
* C::ScalarExt::from_u64(H as u64).pow(&[(num_windows - 1) as u64, 0, 0, 0])
- sum;
(base * scalar).to_affine()
})
.collect::<ArrayVec<C, H>>()
.into_inner()
.unwrap(),
);
window_table
}
/// For each window, we interpolate the $x$-coordinate.
/// Here, we pre-compute and store the coefficients of the interpolation polynomial.
pub fn compute_lagrange_coeffs<C: CurveAffine>(base: C, num_windows: usize) -> Vec<[C::Base; H]> {
// We are interpolating over the 3-bit window, k \in [0..8)
let points: Vec<_> = (0..H).map(|i| C::Base::from_u64(i as u64)).collect();
let window_table = compute_window_table(base, num_windows);
window_table
.iter()
.map(|window_points| {
let x_window_points: Vec<_> = window_points
.iter()
.map(|point| *point.coordinates().unwrap().x())
.collect();
lagrange_interpolate(&points, &x_window_points)
.into_iter()
.collect::<ArrayVec<C::Base, H>>()
.into_inner()
.unwrap()
})
.collect()
}
/// For each window, $z$ is a field element such that for each point $(x, y)$ in the window:
/// - $z + y = u^2$ (some square in the field); and
/// - $z - y$ is not a square.
/// If successful, return a vector of `(z: u64, us: [C::Base; H])` for each window.
///
/// This function was used to generate the `z`s and `u`s for the Orchard fixed
/// bases. The outputs of this function have been stored as constants, and it
/// is not called anywhere in this codebase. However, we keep this function here
/// as a utility for those who wish to use it with different parameters.
pub fn find_zs_and_us<C: CurveAffine>(
base: C,
num_windows: usize,
) -> Option<Vec<(u64, [[u8; 32]; H])>> {
// Closure to find z and u's for one window
let find_z_and_us = |window_points: &[C]| {
assert_eq!(H, window_points.len());
let ys: Vec<_> = window_points
.iter()
.map(|point| *point.coordinates().unwrap().y())
.collect();
(0..(1000 * (1 << (2 * H)))).find_map(|z| {
ys.iter()
.map(|&y| {
let u = if (-y + C::Base::from_u64(z)).sqrt().is_none().into() {
(y + C::Base::from_u64(z)).sqrt().into()
} else {
None
};
u.map(|u: C::Base| u.to_bytes())
})
.collect::<Option<ArrayVec<[u8; 32], H>>>()
.map(|us| (z, us.into_inner().unwrap()))
})
};
let window_table = compute_window_table(base, num_windows);
window_table
.iter()
.map(|window_points| find_z_and_us(window_points))
.collect()
}

View File

@ -164,9 +164,17 @@ pub fn decompose_word<F: PrimeFieldBits>(
/// Takes in an FnMut closure and returns a constant-length array with elements of
/// type `Output`.
pub fn gen_const_array<Output: Copy + Default, const LEN: usize>(
closure: impl FnMut(usize) -> Output,
) -> [Output; LEN] {
gen_const_array_with_default(Default::default(), closure)
}
/// Uses gen_const_array with a given default.
pub fn gen_const_array_with_default<Output: Copy, const LEN: usize>(
default_value: Output,
mut closure: impl FnMut(usize) -> Output,
) -> [Output; LEN] {
let mut ret: [Output; LEN] = [Default::default(); LEN];
let mut ret: [Output; LEN] = [default_value; LEN];
for (bit, val) in ret.iter_mut().zip((0..LEN).map(|idx| closure(idx))) {
*bit = val;
}

View File

@ -1,11 +1,9 @@
//! Constants used in the Orchard protocol.
pub mod fixed_bases;
pub mod sinsemilla;
pub mod util;
pub use self::sinsemilla::{OrchardCommitDomains, OrchardHashDomains};
pub use fixed_bases::OrchardFixedBases;
pub use util::{evaluate, gen_const_array};
/// $\mathsf{MerkleDepth^{Orchard}}$
pub(crate) const MERKLE_DEPTH_ORCHARD: usize = 32;
@ -61,3 +59,98 @@ mod tests {
assert_eq!(t_p + two_pow_254, pallas::Base::zero());
}
}
#[cfg(test)]
use pasta_curves::arithmetic::{CurveAffine, FieldExt};
#[cfg(test)]
/// Test that Lagrange interpolation coefficients reproduce the correct x-coordinate
/// for each fixed-base multiple in each window.
fn test_lagrange_coeffs<C: CurveAffine>(base: C, num_windows: usize) {
use ecc::{chip::compute_lagrange_coeffs, gadget::FIXED_BASE_WINDOW_SIZE};
use ff::Field;
use group::Curve;
fn evaluate<C: CurveAffine>(x: u8, coeffs: &[C::Base]) -> C::Base {
let x = C::Base::from_u64(x as u64);
coeffs
.iter()
.rev()
.cloned()
.reduce(|acc, coeff| acc * x + coeff)
.unwrap_or_else(C::Base::zero)
}
let lagrange_coeffs = compute_lagrange_coeffs(base, num_windows);
// Check first 84 windows, i.e. `k_0, k_1, ..., k_83`
for (idx, coeffs) in lagrange_coeffs[0..(num_windows - 1)].iter().enumerate() {
// Test each three-bit chunk in this window.
for bits in 0..(1 << FIXED_BASE_WINDOW_SIZE) {
{
// Interpolate the x-coordinate using this window's coefficients
let interpolated_x = evaluate::<C>(bits, coeffs);
// Compute the actual x-coordinate of the multiple [(k+2)*(8^w)]B.
let point = base
* C::Scalar::from_u64(bits as u64 + 2)
* C::Scalar::from_u64(fixed_bases::H as u64).pow(&[idx as u64, 0, 0, 0]);
let x = *point.to_affine().coordinates().unwrap().x();
// Check that the interpolated x-coordinate matches the actual one.
assert_eq!(x, interpolated_x);
}
}
}
// Check last window.
for bits in 0..(1 << FIXED_BASE_WINDOW_SIZE) {
// Interpolate the x-coordinate using the last window's coefficients
let interpolated_x = evaluate::<C>(bits, &lagrange_coeffs[num_windows - 1]);
// Compute the actual x-coordinate of the multiple [k * (8^84) - offset]B,
// where offset = \sum_{j = 0}^{83} 2^{3j+1}
let offset = (0..(num_windows - 1)).fold(C::Scalar::zero(), |acc, w| {
acc + C::Scalar::from_u64(2).pow(&[
FIXED_BASE_WINDOW_SIZE as u64 * w as u64 + 1,
0,
0,
0,
])
});
let scalar = C::Scalar::from_u64(bits as u64)
* C::Scalar::from_u64(fixed_bases::H as u64).pow(&[(num_windows - 1) as u64, 0, 0, 0])
- offset;
let point = base * scalar;
let x = *point.to_affine().coordinates().unwrap().x();
// Check that the interpolated x-coordinate matches the actual one.
assert_eq!(x, interpolated_x);
}
}
#[cfg(test)]
// Test that the z-values and u-values satisfy the conditions:
// 1. z + y = u^2,
// 2. z - y is not a square
// for the y-coordinate of each fixed-base multiple in each window.
fn test_zs_and_us<C: CurveAffine>(
base: C,
z: &[u64],
u: &[[[u8; 32]; ecc::gadget::H]],
num_windows: usize,
) {
use ecc::chip::compute_window_table;
use ff::Field;
let window_table = compute_window_table(base, num_windows);
for ((u, z), window_points) in u.iter().zip(z.iter()).zip(window_table) {
for (u, point) in u.iter().zip(window_points.iter()) {
let y = *point.coordinates().unwrap().y();
let u = C::Base::from_bytes(u).unwrap();
assert_eq!(C::Base::from_u64(*z) + y, u * u); // allow either square root
assert!(bool::from((C::Base::from_u64(*z) - y).sqrt().is_none()));
}
}
}

View File

@ -1,15 +1,8 @@
//! Orchard fixed bases.
use super::{L_ORCHARD_SCALAR, L_VALUE};
use ecc::gadget::FixedPoints;
use ecc::{chip::compute_lagrange_coeffs, gadget::FixedPoints};
use arrayvec::ArrayVec;
use ff::Field;
use group::Curve;
use halo2::arithmetic::lagrange_interpolate;
use pasta_curves::{
arithmetic::{CurveAffine, FieldExt},
pallas,
};
use pasta_curves::pallas;
pub mod commit_ivk_r;
pub mod note_commit_r;
@ -51,120 +44,6 @@ pub const NUM_WINDOWS: usize =
pub const NUM_WINDOWS_SHORT: usize =
(L_VALUE + FIXED_BASE_WINDOW_SIZE - 1) / FIXED_BASE_WINDOW_SIZE;
/// For each fixed base, we calculate its scalar multiples in three-bit windows.
/// Each window will have $2^3 = 8$ points.
fn compute_window_table<C: CurveAffine>(base: C, num_windows: usize) -> Vec<[C; H]> {
let mut window_table: Vec<[C; H]> = Vec::with_capacity(num_windows);
// Generate window table entries for all windows but the last.
// For these first `num_windows - 1` windows, we compute the multiple [(k+2)*(2^3)^w]B.
// Here, w ranges from [0..`num_windows - 1`)
for w in 0..(num_windows - 1) {
window_table.push(
(0..H)
.map(|k| {
// scalar = (k+2)*(8^w)
let scalar = C::ScalarExt::from_u64(k as u64 + 2)
* C::ScalarExt::from_u64(H as u64).pow(&[w as u64, 0, 0, 0]);
(base * scalar).to_affine()
})
.collect::<ArrayVec<C, H>>()
.into_inner()
.unwrap(),
);
}
// Generate window table entries for the last window, w = `num_windows - 1`.
// For the last window, we compute [k * (2^3)^w - sum]B, where sum is defined
// as sum = \sum_{j = 0}^{`num_windows - 2`} 2^{3j+1}
let sum = (0..(num_windows - 1)).fold(C::ScalarExt::zero(), |acc, j| {
acc + C::ScalarExt::from_u64(2).pow(&[
FIXED_BASE_WINDOW_SIZE as u64 * j as u64 + 1,
0,
0,
0,
])
});
window_table.push(
(0..H)
.map(|k| {
// scalar = k * (2^3)^w - sum, where w = `num_windows - 1`
let scalar = C::ScalarExt::from_u64(k as u64)
* C::ScalarExt::from_u64(H as u64).pow(&[(num_windows - 1) as u64, 0, 0, 0])
- sum;
(base * scalar).to_affine()
})
.collect::<ArrayVec<C, H>>()
.into_inner()
.unwrap(),
);
window_table
}
/// For each window, we interpolate the $x$-coordinate.
/// Here, we pre-compute and store the coefficients of the interpolation polynomial.
fn compute_lagrange_coeffs<C: CurveAffine>(base: C, num_windows: usize) -> Vec<[C::Base; H]> {
// We are interpolating over the 3-bit window, k \in [0..8)
let points: Vec<_> = (0..H).map(|i| C::Base::from_u64(i as u64)).collect();
let window_table = compute_window_table(base, num_windows);
window_table
.iter()
.map(|window_points| {
let x_window_points: Vec<_> = window_points
.iter()
.map(|point| *point.coordinates().unwrap().x())
.collect();
lagrange_interpolate(&points, &x_window_points)
.into_iter()
.collect::<ArrayVec<C::Base, H>>()
.into_inner()
.unwrap()
})
.collect()
}
/// For each window, $z$ is a field element such that for each point $(x, y)$ in the window:
/// - $z + y = u^2$ (some square in the field); and
/// - $z - y$ is not a square.
/// If successful, return a vector of `(z: u64, us: [C::Base; H])` for each window.
///
/// This function was used to generate the `z`s and `u`s for the Orchard fixed
/// bases. The outputs of this function have been stored as constants, and it
/// is not called anywhere in this codebase. However, we keep this function here
/// as a utility for those who wish to use it with different parameters.
fn find_zs_and_us<C: CurveAffine>(base: C, num_windows: usize) -> Option<Vec<(u64, [C::Base; H])>> {
// Closure to find z and u's for one window
let find_z_and_us = |window_points: &[C]| {
assert_eq!(H, window_points.len());
let ys: Vec<_> = window_points
.iter()
.map(|point| *point.coordinates().unwrap().y())
.collect();
(0..(1000 * (1 << (2 * H)))).find_map(|z| {
ys.iter()
.map(|&y| {
if (-y + C::Base::from_u64(z)).sqrt().is_none().into() {
(y + C::Base::from_u64(z)).sqrt().into()
} else {
None
}
})
.collect::<Option<ArrayVec<C::Base, H>>>()
.map(|us| (z, us.into_inner().unwrap()))
})
};
let window_table = compute_window_table(base, num_windows);
window_table
.iter()
.map(|window_points| find_z_and_us(window_points))
.collect()
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
/// The fixed bases used in the Orchard protocol.
pub enum OrchardFixedBases {
@ -231,73 +110,3 @@ impl FixedPoints<pallas::Affine> for OrchardFixedBases {
}
}
}
#[cfg(test)]
// Test that Lagrange interpolation coefficients reproduce the correct x-coordinate
// for each fixed-base multiple in each window.
fn test_lagrange_coeffs<C: CurveAffine>(base: C, num_windows: usize) {
let lagrange_coeffs = compute_lagrange_coeffs(base, num_windows);
// Check first 84 windows, i.e. `k_0, k_1, ..., k_83`
for (idx, coeffs) in lagrange_coeffs[0..(num_windows - 1)].iter().enumerate() {
// Test each three-bit chunk in this window.
for bits in 0..(1 << FIXED_BASE_WINDOW_SIZE) {
{
// Interpolate the x-coordinate using this window's coefficients
let interpolated_x = super::evaluate::<C>(bits, coeffs);
// Compute the actual x-coordinate of the multiple [(k+2)*(8^w)]B.
let point = base
* C::Scalar::from_u64(bits as u64 + 2)
* C::Scalar::from_u64(H as u64).pow(&[idx as u64, 0, 0, 0]);
let x = *point.to_affine().coordinates().unwrap().x();
// Check that the interpolated x-coordinate matches the actual one.
assert_eq!(x, interpolated_x);
}
}
}
// Check last window.
for bits in 0..(1 << FIXED_BASE_WINDOW_SIZE) {
// Interpolate the x-coordinate using the last window's coefficients
let interpolated_x = super::evaluate::<C>(bits, &lagrange_coeffs[num_windows - 1]);
// Compute the actual x-coordinate of the multiple [k * (8^84) - offset]B,
// where offset = \sum_{j = 0}^{83} 2^{3j+1}
let offset = (0..(num_windows - 1)).fold(C::Scalar::zero(), |acc, w| {
acc + C::Scalar::from_u64(2).pow(&[
FIXED_BASE_WINDOW_SIZE as u64 * w as u64 + 1,
0,
0,
0,
])
});
let scalar = C::Scalar::from_u64(bits as u64)
* C::Scalar::from_u64(H as u64).pow(&[(num_windows - 1) as u64, 0, 0, 0])
- offset;
let point = base * scalar;
let x = *point.to_affine().coordinates().unwrap().x();
// Check that the interpolated x-coordinate matches the actual one.
assert_eq!(x, interpolated_x);
}
}
#[cfg(test)]
// Test that the z-values and u-values satisfy the conditions:
// 1. z + y = u^2,
// 2. z - y is not a square
// for the y-coordinate of each fixed-base multiple in each window.
fn test_zs_and_us<C: CurveAffine>(base: C, z: &[u64], u: &[[[u8; 32]; H]], num_windows: usize) {
let window_table = compute_window_table(base, num_windows);
for ((u, z), window_points) in u.iter().zip(z.iter()).zip(window_table) {
for (u, point) in u.iter().zip(window_points.iter()) {
let y = *point.coordinates().unwrap().y();
let u = C::Base::from_bytes(u).unwrap();
assert_eq!(C::Base::from_u64(*z) + y, u * u); // allow either square root
assert!(bool::from((C::Base::from_u64(*z) - y).sqrt().is_none()));
}
}
}

View File

@ -2952,8 +2952,9 @@ fn test_generator() {
#[cfg(test)]
mod tests {
use super::super::{test_lagrange_coeffs, test_zs_and_us, NUM_WINDOWS};
use super::super::NUM_WINDOWS;
use super::*;
use crate::constants::{test_lagrange_coeffs, test_zs_and_us};
#[test]
fn lagrange_coeffs() {

View File

@ -2952,8 +2952,9 @@ fn test_generator() {
#[cfg(test)]
mod tests {
use super::super::{test_lagrange_coeffs, test_zs_and_us, NUM_WINDOWS};
use super::super::NUM_WINDOWS;
use super::*;
use crate::constants::{test_lagrange_coeffs, test_zs_and_us};
#[test]
fn lagrange_coeffs() {

View File

@ -2933,10 +2933,10 @@ pub fn generator() -> pallas::Affine {
#[cfg(test)]
mod tests {
use super::super::{
test_lagrange_coeffs, test_zs_and_us, NUM_WINDOWS, ORCHARD_PERSONALIZATION,
};
use super::super::{NUM_WINDOWS, ORCHARD_PERSONALIZATION};
use super::*;
use crate::constants::{test_lagrange_coeffs, test_zs_and_us};
use group::Curve;
use pasta_curves::{
arithmetic::{CurveExt, FieldExt},

View File

@ -2934,10 +2934,10 @@ pub fn generator() -> pallas::Affine {
#[cfg(test)]
mod tests {
use super::super::{
test_lagrange_coeffs, test_zs_and_us, NUM_WINDOWS, ORCHARD_PERSONALIZATION,
};
use super::super::{NUM_WINDOWS, ORCHARD_PERSONALIZATION};
use super::*;
use crate::constants::{test_lagrange_coeffs, test_zs_and_us};
use group::Curve;
use pasta_curves::{
arithmetic::{CurveAffine, CurveExt, FieldExt},

View File

@ -2934,10 +2934,10 @@ pub fn generator() -> pallas::Affine {
#[cfg(test)]
mod tests {
use super::super::{
test_lagrange_coeffs, test_zs_and_us, NUM_WINDOWS, VALUE_COMMITMENT_PERSONALIZATION,
};
use super::super::{NUM_WINDOWS, VALUE_COMMITMENT_PERSONALIZATION};
use super::*;
use crate::constants::{test_lagrange_coeffs, test_zs_and_us};
use group::Curve;
use pasta_curves::{
arithmetic::{CurveAffine, CurveExt, FieldExt},

View File

@ -787,10 +787,10 @@ pub fn generator() -> pallas::Affine {
#[cfg(test)]
mod tests {
use super::super::{
test_lagrange_coeffs, test_zs_and_us, NUM_WINDOWS_SHORT, VALUE_COMMITMENT_PERSONALIZATION,
};
use super::super::{NUM_WINDOWS_SHORT, VALUE_COMMITMENT_PERSONALIZATION};
use super::*;
use crate::constants::{test_lagrange_coeffs, test_zs_and_us};
use group::Curve;
use pasta_curves::{
arithmetic::{CurveAffine, CurveExt, FieldExt},

View File

@ -1,34 +0,0 @@
//! Utilities used in the constants module.
use ff::Field;
use halo2::arithmetic::{CurveAffine, FieldExt};
/// Evaluate y = f(x) given the coefficients of f(x)
pub fn evaluate<C: CurveAffine>(x: u8, coeffs: &[C::Base]) -> C::Base {
let x = C::Base::from_u64(x as u64);
coeffs
.iter()
.rev()
.cloned()
.reduce(|acc, coeff| acc * x + coeff)
.unwrap_or_else(C::Base::zero)
}
/// Takes in an FnMut closure and returns a constant-length array with elements of
/// type `Output`.
pub fn gen_const_array<Output: Copy + Default, const LEN: usize>(
closure: impl FnMut(usize) -> Output,
) -> [Output; LEN] {
gen_const_array_with_default(Default::default(), closure)
}
pub(crate) fn gen_const_array_with_default<Output: Copy, const LEN: usize>(
default_value: Output,
mut closure: impl FnMut(usize) -> Output,
) -> [Output; LEN] {
let mut ret: [Output; LEN] = [default_value; LEN];
for (bit, val) in ret.iter_mut().zip((0..LEN).map(|idx| closure(idx))) {
*bit = val;
}
ret
}

View File

@ -11,8 +11,7 @@ use pasta_curves::pallas;
use subtle::{ConditionallySelectable, CtOption};
use crate::constants::{
fixed_bases::COMMIT_IVK_PERSONALIZATION, util::gen_const_array,
KEY_DIVERSIFICATION_PERSONALIZATION, L_ORCHARD_BASE,
fixed_bases::COMMIT_IVK_PERSONALIZATION, KEY_DIVERSIFICATION_PERSONALIZATION, L_ORCHARD_BASE,
};
use poseidon::primitive as poseidon;
use sinsemilla::primitive as sinsemilla;
@ -276,6 +275,18 @@ pub fn i2lebsp<const NUM_BITS: usize>(int: u64) -> [bool; NUM_BITS] {
gen_const_array(|mask: usize| (int & (1 << mask)) != 0)
}
/// Takes in an FnMut closure and returns a constant-length array with elements of
/// type `Output`.
pub fn gen_const_array<Output: Copy + Default, const LEN: usize>(
mut closure: impl FnMut(usize) -> Output,
) -> [Output; LEN] {
let mut ret: [Output; LEN] = [Default::default(); LEN];
for (bit, val) in ret.iter_mut().zip((0..LEN).map(|idx| closure(idx))) {
*bit = val;
}
ret
}
#[cfg(test)]
mod tests {
use super::{i2lebsp, lebs2ip};

View File

@ -3,11 +3,11 @@
use crate::{
constants::{
sinsemilla::{i2lebsp_k, L_ORCHARD_MERKLE, MERKLE_CRH_PERSONALIZATION},
util::gen_const_array_with_default,
MERKLE_DEPTH_ORCHARD,
},
note::commitment::ExtractedNoteCommitment,
};
use halo2_utilities::utilities::gen_const_array_with_default;
use incrementalmerkletree::{Altitude, Hashable};
use pasta_curves::{arithmetic::FieldExt, pallas};
use sinsemilla::primitive::HashDomain;