Use fixed-size array for windows in tables

Co-authored-by: Jack Grigg <jack@electriccoin.co>
This commit is contained in:
therealyingtong 2021-03-24 13:18:00 +08:00
parent d915097407
commit 3381b15cd9
1 changed files with 42 additions and 40 deletions

View File

@ -41,6 +41,9 @@ pub const MERKLE_CRH_PERSONALIZATION: &str = "z.cash:Orchard-MerkleCRH";
/// Window size for fixed-base scalar multiplication
pub const FIXED_BASE_WINDOW_SIZE: usize = 3;
/// 2^{FIXED_BASE_WINDOW_SIZE}
pub const H: usize = 1 << FIXED_BASE_WINDOW_SIZE;
/// Number of windows
pub const NUM_WINDOWS: usize = pallas::Base::NUM_BITS as usize / FIXED_BASE_WINDOW_SIZE;
@ -59,21 +62,11 @@ pub enum OrchardFixedBases<C: CurveAffine> {
impl<C: CurveAffine> std::hash::Hash for OrchardFixedBases<C> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match *self {
OrchardFixedBases::CommitIvkR(_) => {
state.write(&format!("{:?}", "CommitIvkR").as_bytes())
}
OrchardFixedBases::NoteCommitR(_) => {
state.write(&format!("{:?}", "NoteCommitR").as_bytes())
}
OrchardFixedBases::NullifierK(_) => {
state.write(&format!("{:?}", "NullifierK").as_bytes())
}
OrchardFixedBases::ValueCommitR(_) => {
state.write(&format!("{:?}", "ValueCommitR").as_bytes())
}
OrchardFixedBases::ValueCommitV(_) => {
state.write(&format!("{:?}", "ValueCommitV").as_bytes())
}
OrchardFixedBases::CommitIvkR(_) => state.write(b"CommitIvkR"),
OrchardFixedBases::NoteCommitR(_) => state.write(b"NoteCommitR"),
OrchardFixedBases::NullifierK(_) => state.write(b"NullifierK"),
OrchardFixedBases::ValueCommitR(_) => state.write(b"ValueCommitR"),
OrchardFixedBases::ValueCommitV(_) => state.write(b"ValueCommitV"),
}
}
}
@ -114,11 +107,11 @@ impl<C: CurveAffine> OrchardFixedBase<C> {
pub trait FixedBase<C: CurveAffine> {
/// For each fixed base, we calculate its scalar multiples in three-bit windows.
/// Each window will have 2^3 = 8 points.
fn compute_window_table(&self) -> Vec<Vec<C>>;
fn compute_window_table(&self) -> Vec<[C; H]>;
/// For each window, we interpolate the x-coordinate.
/// Here, we pre-compute and store the coefficients of the interpolation polynomial.
fn compute_lagrange_coeffs(&self) -> Vec<Vec<C::Base>>;
fn compute_lagrange_coeffs(&self) -> Vec<[C::Base; H]>;
/// For each window, z is a field element
/// such that for each point (x, y) in the window:
@ -128,23 +121,26 @@ pub trait FixedBase<C: CurveAffine> {
}
impl<C: CurveAffine> FixedBase<C> for OrchardFixedBase<C> {
fn compute_window_table(&self) -> Vec<Vec<C>> {
let h: usize = 1 << FIXED_BASE_WINDOW_SIZE;
let mut window_table: Vec<Vec<C>> = Vec::with_capacity(NUM_WINDOWS);
fn compute_window_table(&self) -> Vec<[C; H]> {
let mut window_table: Vec<[C; H]> = Vec::with_capacity(NUM_WINDOWS);
// Generate window table entries for all windows but the last.
// For these first 84 windows, we compute the multiple [(k+1)*(8^w)]B.
// Here, w ranges from [0..84)
for w in 0..(NUM_WINDOWS - 1) {
window_table.push(
(0..h)
(0..H)
.map(|k| {
// scalar = (k+1)*(8^w)
let scalar = C::ScalarExt::from_u64(k as u64 + 1)
* C::ScalarExt::from_u64(h as u64).pow(&[w as u64, 0, 0, 0]);
* C::ScalarExt::from_u64(H as u64).pow(&[w as u64, 0, 0, 0]);
(self.0 * scalar).to_affine()
})
.collect(),
.enumerate()
.fold([C::identity(); H], |mut window, (index, entry)| {
window[index] = entry;
window
}),
);
}
@ -152,14 +148,14 @@ impl<C: CurveAffine> FixedBase<C> for OrchardFixedBase<C> {
// For the last window, we compute [k * (8^w) - sum]B, where sum is defined
// as sum = \sum_{j = 0}^{83} 8^j
let sum = (0..(NUM_WINDOWS - 1)).fold(C::ScalarExt::zero(), |acc, w| {
acc + C::ScalarExt::from_u64(h as u64).pow(&[w as u64, 0, 0, 0])
acc + C::ScalarExt::from_u64(H as u64).pow(&[w as u64, 0, 0, 0])
});
window_table.push(
(0..h)
(0..H)
.map(|k| {
// scalar = k * (8^w) - sum, where w = 84
let scalar = C::ScalarExt::from_u64(k as u64)
* C::ScalarExt::from_u64(h as u64).pow(&[
* C::ScalarExt::from_u64(H as u64).pow(&[
(NUM_WINDOWS - 1) as u64,
0,
0,
@ -168,17 +164,19 @@ impl<C: CurveAffine> FixedBase<C> for OrchardFixedBase<C> {
- sum;
(self.0 * scalar).to_affine()
})
.collect(),
.enumerate()
.fold([C::identity(); H], |mut window, (index, entry)| {
window[index] = entry;
window
}),
);
window_table
}
fn compute_lagrange_coeffs(&self) -> Vec<Vec<C::Base>> {
let h: usize = 1 << FIXED_BASE_WINDOW_SIZE;
fn compute_lagrange_coeffs(&self) -> Vec<[C::Base; 8]> {
// We are interpolating over the 3-bit window, k \in [0..8)
let points: Vec<_> = (0..h).map(|i| C::Base::from_u64(i as u64)).collect();
let points: Vec<_> = (0..H).map(|i| C::Base::from_u64(i as u64)).collect();
let window_table = self.compute_window_table();
@ -190,8 +188,14 @@ impl<C: CurveAffine> FixedBase<C> for OrchardFixedBase<C> {
.map(|point| point.get_xy().unwrap().0)
.collect();
lagrange_interpolate(&points, &x_window_points)
.iter()
.enumerate()
.fold([C::Base::default(); H], |mut window, (index, entry)| {
window[index] = *entry;
window
})
})
.collect::<Vec<Vec<_>>>()
.collect()
}
/// For each window, z is a field element
@ -201,8 +205,7 @@ impl<C: CurveAffine> FixedBase<C> for OrchardFixedBase<C> {
fn find_zs(&self) -> Option<Vec<u64>> {
// Closure to find z for one window
let find_z = |window_points: &[C]| {
let h: usize = 1 << FIXED_BASE_WINDOW_SIZE;
assert_eq!(h, window_points.len());
assert_eq!(H, window_points.len());
let ys: Vec<_> = window_points
.iter()
@ -214,8 +217,8 @@ impl<C: CurveAffine> FixedBase<C> for OrchardFixedBase<C> {
(sum_y_is_square && !sum_neg_y_is_square) as usize
};
for z in 0..(1000 * (1 << (2 * h))) {
if ys.iter().map(|y| z_for_single_y(*y, z)).sum::<usize>() == h {
for z in 0..(1000 * (1 << (2 * H))) {
if ys.iter().map(|y| z_for_single_y(*y, z)).sum::<usize>() == H {
return Some(z);
}
}
@ -238,7 +241,6 @@ pub trait TestFixedBase<C: CurveAffine> {
impl<C: CurveAffine> TestFixedBase<C> for OrchardFixedBase<C> {
fn test_lagrange_coeffs(&self) {
let h = 1 << FIXED_BASE_WINDOW_SIZE;
let lagrange_coeffs = self.compute_lagrange_coeffs();
let mut points = Vec::<C::CurveExt>::with_capacity(NUM_WINDOWS);
@ -260,7 +262,7 @@ impl<C: CurveAffine> TestFixedBase<C> for OrchardFixedBase<C> {
// [(k+1)*(8^w)]B
let point = self.0
* C::Scalar::from_u64(*bits as u64 + 1)
* C::Scalar::from_u64(h as u64).pow(&[idx as u64, 0, 0, 0]);
* C::Scalar::from_u64(H as u64).pow(&[idx as u64, 0, 0, 0]);
let x = point.to_affine().get_xy().unwrap().0;
assert_eq!(x, interpolated_x);
@ -274,10 +276,10 @@ impl<C: CurveAffine> TestFixedBase<C> for OrchardFixedBase<C> {
// [k * (8^w) - offset]B, where offset = \sum_{j = 0}^{83} 8^j
let offset = (0..(NUM_WINDOWS - 1)).fold(C::Scalar::zero(), |acc, w| {
acc + C::Scalar::from_u64(h as u64).pow(&[w as u64, 0, 0, 0])
acc + C::Scalar::from_u64(H as u64).pow(&[w as u64, 0, 0, 0])
});
let scalar = C::Scalar::from_u64(last_bits as u64)
* C::Scalar::from_u64(h as u64).pow(&[(NUM_WINDOWS - 1) as u64, 0, 0, 0])
* C::Scalar::from_u64(H as u64).pow(&[(NUM_WINDOWS - 1) as u64, 0, 0, 0])
- offset;
let point = self.0 * scalar;
let x = point.to_affine().get_xy().unwrap().0;