diff --git a/src/constants.rs b/src/constants.rs index 59f4b62a..1a695ad3 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -41,6 +41,9 @@ pub const MERKLE_CRH_PERSONALIZATION: &str = "z.cash:Orchard-MerkleCRH"; /// Window size for fixed-base scalar multiplication pub const FIXED_BASE_WINDOW_SIZE: usize = 3; +/// 2^{FIXED_BASE_WINDOW_SIZE} +pub const H: usize = 1 << FIXED_BASE_WINDOW_SIZE; + /// Number of windows pub const NUM_WINDOWS: usize = pallas::Base::NUM_BITS as usize / FIXED_BASE_WINDOW_SIZE; @@ -59,21 +62,11 @@ pub enum OrchardFixedBases { impl std::hash::Hash for OrchardFixedBases { fn hash(&self, state: &mut H) { match *self { - OrchardFixedBases::CommitIvkR(_) => { - state.write(&format!("{:?}", "CommitIvkR").as_bytes()) - } - OrchardFixedBases::NoteCommitR(_) => { - state.write(&format!("{:?}", "NoteCommitR").as_bytes()) - } - OrchardFixedBases::NullifierK(_) => { - state.write(&format!("{:?}", "NullifierK").as_bytes()) - } - OrchardFixedBases::ValueCommitR(_) => { - state.write(&format!("{:?}", "ValueCommitR").as_bytes()) - } - OrchardFixedBases::ValueCommitV(_) => { - state.write(&format!("{:?}", "ValueCommitV").as_bytes()) - } + OrchardFixedBases::CommitIvkR(_) => state.write(b"CommitIvkR"), + OrchardFixedBases::NoteCommitR(_) => state.write(b"NoteCommitR"), + OrchardFixedBases::NullifierK(_) => state.write(b"NullifierK"), + OrchardFixedBases::ValueCommitR(_) => state.write(b"ValueCommitR"), + OrchardFixedBases::ValueCommitV(_) => state.write(b"ValueCommitV"), } } } @@ -114,11 +107,11 @@ impl OrchardFixedBase { pub trait FixedBase { /// For each fixed base, we calculate its scalar multiples in three-bit windows. /// Each window will have 2^3 = 8 points. - fn compute_window_table(&self) -> Vec>; + fn compute_window_table(&self) -> Vec<[C; H]>; /// For each window, we interpolate the x-coordinate. /// Here, we pre-compute and store the coefficients of the interpolation polynomial. - fn compute_lagrange_coeffs(&self) -> Vec>; + fn compute_lagrange_coeffs(&self) -> Vec<[C::Base; H]>; /// For each window, z is a field element /// such that for each point (x, y) in the window: @@ -128,23 +121,26 @@ pub trait FixedBase { } impl FixedBase for OrchardFixedBase { - fn compute_window_table(&self) -> Vec> { - let h: usize = 1 << FIXED_BASE_WINDOW_SIZE; - let mut window_table: Vec> = Vec::with_capacity(NUM_WINDOWS); + fn compute_window_table(&self) -> Vec<[C; H]> { + let mut window_table: Vec<[C; H]> = Vec::with_capacity(NUM_WINDOWS); // Generate window table entries for all windows but the last. // For these first 84 windows, we compute the multiple [(k+1)*(8^w)]B. // Here, w ranges from [0..84) for w in 0..(NUM_WINDOWS - 1) { window_table.push( - (0..h) + (0..H) .map(|k| { // scalar = (k+1)*(8^w) let scalar = C::ScalarExt::from_u64(k as u64 + 1) - * C::ScalarExt::from_u64(h as u64).pow(&[w as u64, 0, 0, 0]); + * C::ScalarExt::from_u64(H as u64).pow(&[w as u64, 0, 0, 0]); (self.0 * scalar).to_affine() }) - .collect(), + .enumerate() + .fold([C::identity(); H], |mut window, (index, entry)| { + window[index] = entry; + window + }), ); } @@ -152,14 +148,14 @@ impl FixedBase for OrchardFixedBase { // For the last window, we compute [k * (8^w) - sum]B, where sum is defined // as sum = \sum_{j = 0}^{83} 8^j let sum = (0..(NUM_WINDOWS - 1)).fold(C::ScalarExt::zero(), |acc, w| { - acc + C::ScalarExt::from_u64(h as u64).pow(&[w as u64, 0, 0, 0]) + acc + C::ScalarExt::from_u64(H as u64).pow(&[w as u64, 0, 0, 0]) }); window_table.push( - (0..h) + (0..H) .map(|k| { // scalar = k * (8^w) - sum, where w = 84 let scalar = C::ScalarExt::from_u64(k as u64) - * C::ScalarExt::from_u64(h as u64).pow(&[ + * C::ScalarExt::from_u64(H as u64).pow(&[ (NUM_WINDOWS - 1) as u64, 0, 0, @@ -168,17 +164,19 @@ impl FixedBase for OrchardFixedBase { - sum; (self.0 * scalar).to_affine() }) - .collect(), + .enumerate() + .fold([C::identity(); H], |mut window, (index, entry)| { + window[index] = entry; + window + }), ); window_table } - fn compute_lagrange_coeffs(&self) -> Vec> { - let h: usize = 1 << FIXED_BASE_WINDOW_SIZE; - + fn compute_lagrange_coeffs(&self) -> Vec<[C::Base; 8]> { // We are interpolating over the 3-bit window, k \in [0..8) - let points: Vec<_> = (0..h).map(|i| C::Base::from_u64(i as u64)).collect(); + let points: Vec<_> = (0..H).map(|i| C::Base::from_u64(i as u64)).collect(); let window_table = self.compute_window_table(); @@ -190,8 +188,14 @@ impl FixedBase for OrchardFixedBase { .map(|point| point.get_xy().unwrap().0) .collect(); lagrange_interpolate(&points, &x_window_points) + .iter() + .enumerate() + .fold([C::Base::default(); H], |mut window, (index, entry)| { + window[index] = *entry; + window + }) }) - .collect::>>() + .collect() } /// For each window, z is a field element @@ -201,8 +205,7 @@ impl FixedBase for OrchardFixedBase { fn find_zs(&self) -> Option> { // Closure to find z for one window let find_z = |window_points: &[C]| { - let h: usize = 1 << FIXED_BASE_WINDOW_SIZE; - assert_eq!(h, window_points.len()); + assert_eq!(H, window_points.len()); let ys: Vec<_> = window_points .iter() @@ -214,8 +217,8 @@ impl FixedBase for OrchardFixedBase { (sum_y_is_square && !sum_neg_y_is_square) as usize }; - for z in 0..(1000 * (1 << (2 * h))) { - if ys.iter().map(|y| z_for_single_y(*y, z)).sum::() == h { + for z in 0..(1000 * (1 << (2 * H))) { + if ys.iter().map(|y| z_for_single_y(*y, z)).sum::() == H { return Some(z); } } @@ -238,7 +241,6 @@ pub trait TestFixedBase { impl TestFixedBase for OrchardFixedBase { fn test_lagrange_coeffs(&self) { - let h = 1 << FIXED_BASE_WINDOW_SIZE; let lagrange_coeffs = self.compute_lagrange_coeffs(); let mut points = Vec::::with_capacity(NUM_WINDOWS); @@ -260,7 +262,7 @@ impl TestFixedBase for OrchardFixedBase { // [(k+1)*(8^w)]B let point = self.0 * C::Scalar::from_u64(*bits as u64 + 1) - * C::Scalar::from_u64(h as u64).pow(&[idx as u64, 0, 0, 0]); + * C::Scalar::from_u64(H as u64).pow(&[idx as u64, 0, 0, 0]); let x = point.to_affine().get_xy().unwrap().0; assert_eq!(x, interpolated_x); @@ -274,10 +276,10 @@ impl TestFixedBase for OrchardFixedBase { // [k * (8^w) - offset]B, where offset = \sum_{j = 0}^{83} 8^j let offset = (0..(NUM_WINDOWS - 1)).fold(C::Scalar::zero(), |acc, w| { - acc + C::Scalar::from_u64(h as u64).pow(&[w as u64, 0, 0, 0]) + acc + C::Scalar::from_u64(H as u64).pow(&[w as u64, 0, 0, 0]) }); let scalar = C::Scalar::from_u64(last_bits as u64) - * C::Scalar::from_u64(h as u64).pow(&[(NUM_WINDOWS - 1) as u64, 0, 0, 0]) + * C::Scalar::from_u64(H as u64).pow(&[(NUM_WINDOWS - 1) as u64, 0, 0, 0]) - offset; let point = self.0 * scalar; let x = point.to_affine().get_xy().unwrap().0;