From 061ad0656b49ef2f9a11062c9fea64145bf07768 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Fri, 26 Mar 2021 08:56:25 +1300 Subject: [PATCH] Refactor Poseidon primitive to use const generics --- src/primitives/poseidon.rs | 214 +++++++++++++-------------- src/primitives/poseidon/mds.rs | 37 +++-- src/primitives/poseidon/nullifier.rs | 34 ++--- 3 files changed, 131 insertions(+), 154 deletions(-) diff --git a/src/primitives/poseidon.rs b/src/primitives/poseidon.rs index f2e04935..b80ec0b8 100644 --- a/src/primitives/poseidon.rs +++ b/src/primitives/poseidon.rs @@ -11,23 +11,17 @@ pub use nullifier::OrchardNullifier; use grain::SboxType; +/// The type used to hold permutation state. +pub type State = [F; T]; + +/// The type used to hold duplex sponge state. +pub type SpongeState = [Option; RATE]; + +/// The type used to hold the MDS matrix and its inverse. +pub type Mds = [[F; T]; T]; + /// A specification for a Poseidon permutation. -pub trait Spec { - /// The type used to hold permutation state, or equivalent-length constant values. - /// - /// This must be an array of length [`Spec::width`], that defaults to all-zeroes. - type State: Default + AsRef<[F]> + AsMut<[F]>; - - /// The type used to hold duplex sponge state. - /// - /// This must be an array of length equal to the rate of the duplex sponge (allowing - /// for a capacity consistent with this specification's security level), that defaults - /// to `[None; RATE]`. - type Rate: Default + AsRef<[Option]> + AsMut<[Option]>; - - /// The width of this specification. - fn width() -> usize; - +pub trait Spec { /// The number of full rounds for this specification. /// /// This must be an even number. @@ -47,20 +41,18 @@ pub trait Spec { fn secure_mds(&self) -> usize; /// Generates `(round_constants, mds, mds^-1)` corresponding to this specification. - fn constants(&self) -> (Vec, Vec, Vec) { - let t = Self::width(); + fn constants(&self) -> (Vec<[F; T]>, Mds, Mds) { let r_f = Self::full_rounds(); let r_p = Self::partial_rounds(); - let mut grain = grain::Grain::new(SboxType::Pow, t as u16, r_f as u16, r_p as u16); + let mut grain = grain::Grain::new(SboxType::Pow, T as u16, r_f as u16, r_p as u16); let round_constants = (0..(r_f + r_p)) .map(|_| { - let mut rc_row = Self::State::default(); + let mut rc_row = [F::zero(); T]; for (rc, value) in rc_row - .as_mut() .iter_mut() - .zip((0..t).map(|_| grain.next_field_element())) + .zip((0..T).map(|_| grain.next_field_element())) { *rc = value; } @@ -68,74 +60,53 @@ pub trait Spec { }) .collect(); - let (mds, mds_inv) = mds::generate_mds(&mut grain, t, self.secure_mds()); + let (mds, mds_inv) = mds::generate_mds::(&mut grain, self.secure_mds()); - ( - round_constants, - mds.into_iter() - .map(|row| { - let mut mds_row = Self::State::default(); - for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) { - *entry = value; - } - mds_row - }) - .collect(), - mds_inv - .into_iter() - .map(|row| { - let mut mds_row = Self::State::default(); - for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) { - *entry = value; - } - mds_row - }) - .collect(), - ) + (round_constants, mds, mds_inv) } } /// Runs the Poseidon permutation on the given state. -fn permute>( - state: &mut S::State, - mds: &[S::State], - round_constants: &[S::State], +fn permute, const T: usize, const RATE: usize>( + state: &mut State, + mds: &Mds, + round_constants: &[[F; T]], ) { let r_f = S::full_rounds() / 2; let r_p = S::partial_rounds(); - let apply_mds = |state: &mut S::State| { - let mut new_state = S::State::default(); + let apply_mds = |state: &mut State| { + let mut new_state = [F::zero(); T]; // Matrix multiplication #[allow(clippy::needless_range_loop)] - for i in 0..S::width() { - for j in 0..S::width() { - new_state.as_mut()[i] += mds[i].as_ref()[j] * state.as_ref()[j]; + for i in 0..T { + for j in 0..T { + new_state[i] += mds[i][j] * state[j]; } } *state = new_state; }; - let full_round = |state: &mut S::State, rcs: &S::State| { - for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) { + let full_round = |state: &mut State, rcs: &[F; T]| { + for (word, rc) in state.iter_mut().zip(rcs.iter()) { *word = S::sbox(*word + rc); } apply_mds(state); }; - let part_round = |state: &mut S::State, rcs: &S::State| { - for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) { + let part_round = |state: &mut State, rcs: &[F; T]| { + for (word, rc) in state.iter_mut().zip(rcs.iter()) { *word += rc; } // In a partial round, the S-box is only applied to the first state word. - state.as_mut()[0] = S::sbox(state.as_ref()[0]); + state[0] = S::sbox(state[0]); apply_mds(state); }; iter::empty() - .chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f)) - .chain(iter::repeat(&part_round as &dyn Fn(&mut S::State, &S::State)).take(r_p)) - .chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f)) + .chain(iter::repeat(&full_round as &dyn Fn(&mut State, &[F; T])).take(r_f)) + .chain(iter::repeat(&part_round as &dyn Fn(&mut State, &[F; T])).take(r_p)) + .chain(iter::repeat(&full_round as &dyn Fn(&mut State, &[F; T])).take(r_f)) .zip(round_constants.iter()) .fold(state, |state, (round, rcs)| { round(state, rcs); @@ -143,62 +114,62 @@ fn permute>( }); } -fn poseidon_duplex>( - state: &mut S::State, - input: &S::Rate, - pad_and_add: &dyn Fn(&mut S::State, &S::Rate), - mds_matrix: &[S::State], - round_constants: &[S::State], -) -> S::Rate { +fn poseidon_duplex, const T: usize, const RATE: usize>( + state: &mut State, + input: &SpongeState, + pad_and_add: &dyn Fn(&mut State, &SpongeState), + mds_matrix: &Mds, + round_constants: &[[F; T]], +) -> SpongeState { pad_and_add(state, input); - permute::(state, mds_matrix, round_constants); + permute::(state, mds_matrix, round_constants); - let mut output = S::Rate::default(); - for (word, value) in output.as_mut().iter_mut().zip(state.as_ref().iter()) { + let mut output = [None; RATE]; + for (word, value) in output.iter_mut().zip(state.iter()) { *word = Some(*value); } output } -enum SpongeState> { - Absorbing(S::Rate), - Squeezing(S::Rate), +enum Sponge { + Absorbing(SpongeState), + Squeezing(SpongeState), } -impl> SpongeState { +impl Sponge { fn absorb(val: F) -> Self { - let mut input = S::Rate::default(); - input.as_mut()[0] = Some(val); - SpongeState::Absorbing(input) + let mut input = [None; RATE]; + input[0] = Some(val); + Sponge::Absorbing(input) } } /// A Poseidon duplex sponge. -pub struct Duplex> { - sponge: SpongeState, - state: S::State, - pad_and_add: Box, - mds_matrix: Vec, - round_constants: Vec, +pub struct Duplex, const T: usize, const RATE: usize> { + sponge: Sponge, + state: State, + pad_and_add: Box, &SpongeState)>, + mds_matrix: Mds, + round_constants: Vec<[F; T]>, _marker: PhantomData, } -impl> Duplex { +impl, const T: usize, const RATE: usize> Duplex { /// Constructs a new duplex sponge for the given Poseidon specification. pub fn new( spec: S, initial_capacity_element: F, - pad_and_add: Box, + pad_and_add: Box, &SpongeState)>, ) -> Self { let (round_constants, mds_matrix, _) = spec.constants(); - let input = S::Rate::default(); - let mut state = S::State::default(); - state.as_mut()[input.as_ref().len()] = initial_capacity_element; + let input = [None; RATE]; + let mut state = [F::zero(); T]; + state[RATE] = initial_capacity_element; Duplex { - sponge: SpongeState::Absorbing(input), + sponge: Sponge::Absorbing(input), state, pad_and_add, mds_matrix, @@ -210,8 +181,8 @@ impl> Duplex { /// Absorbs an element into the sponge. pub fn absorb(&mut self, value: F) { match self.sponge { - SpongeState::Absorbing(ref mut input) => { - for entry in input.as_mut().iter_mut() { + Sponge::Absorbing(ref mut input) => { + for entry in input.iter_mut() { if entry.is_none() { *entry = Some(value); return; @@ -219,18 +190,18 @@ impl> Duplex { } // We've already absorbed as many elements as we can - let _ = poseidon_duplex::( + let _ = poseidon_duplex::( &mut self.state, &input, &self.pad_and_add, &self.mds_matrix, &self.round_constants, ); - self.sponge = SpongeState::absorb(value); + self.sponge = Sponge::absorb(value); } - SpongeState::Squeezing(_) => { + Sponge::Squeezing(_) => { // Drop the remaining output elements - self.sponge = SpongeState::absorb(value); + self.sponge = Sponge::absorb(value); } } } @@ -239,8 +210,8 @@ impl> Duplex { pub fn squeeze(&mut self) -> F { loop { match self.sponge { - SpongeState::Absorbing(ref input) => { - self.sponge = SpongeState::Squeezing(poseidon_duplex::( + Sponge::Absorbing(ref input) => { + self.sponge = Sponge::Squeezing(poseidon_duplex::( &mut self.state, &input, &self.pad_and_add, @@ -248,15 +219,15 @@ impl> Duplex { &self.round_constants, )); } - SpongeState::Squeezing(ref mut output) => { - for entry in output.as_mut().iter_mut() { + Sponge::Squeezing(ref mut output) => { + for entry in output.iter_mut() { if let Some(e) = entry.take() { return e; } } // We've already squeezed out all available elements - self.sponge = SpongeState::Absorbing(S::Rate::default()); + self.sponge = Sponge::Absorbing([None; RATE]); } } } @@ -264,13 +235,15 @@ impl> Duplex { } /// A domain in which a Poseidon hash function is being used. -pub trait Domain>: Copy { +pub trait Domain, const T: usize, const RATE: usize>: + Copy +{ /// The initial capacity element, encoding this domain. fn initial_capacity_element(&self) -> F; /// Returns a function that will update the given state with the given input to a /// duplex permutation round, applying padding according to this domain specification. - fn pad_and_add(&self) -> Box; + fn pad_and_add(&self) -> Box, &SpongeState)>; } /// A Poseidon hash function used with constant input length. @@ -279,18 +252,20 @@ pub trait Domain>: Copy { #[derive(Clone, Copy, Debug)] pub struct ConstantLength(pub usize); -impl> Domain for ConstantLength { +impl, const T: usize, const RATE: usize> Domain + for ConstantLength +{ fn initial_capacity_element(&self) -> F { // Capacity value is $length \cdot 2^64 + (o-1)$ where o is the output length. // We hard-code an output length of 1. F::from_u128((self.0 as u128) << 64) } - fn pad_and_add(&self) -> Box { + fn pad_and_add(&self) -> Box, &SpongeState)> { Box::new(|state, input| { // `Iterator::zip` short-circuits when one iterator completes, so this will only // mutate the rate portion of the state. - for (word, value) in state.as_mut().iter_mut().zip(input.as_ref().iter()) { + for (word, value) in state.iter_mut().zip(input.iter()) { // For constant-input-length hashing, padding consists of the field // elements being zero, so we don't add anything to the state. if let Some(value) = value { @@ -302,12 +277,25 @@ impl> Domain for ConstantLength { } /// A Poseidon hash function, built around a duplex sponge. -pub struct Hash, D: Domain> { - duplex: Duplex, +pub struct Hash< + F: FieldExt, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, +> { + duplex: Duplex, domain: D, } -impl, D: Domain> Hash { +impl< + F: FieldExt, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, + > Hash +{ /// Initializes a new hasher. pub fn init(spec: S, domain: D) -> Self { Hash { @@ -321,7 +309,9 @@ impl, D: Domain> Hash { } } -impl> Hash { +impl, const T: usize, const RATE: usize> + Hash +{ /// Hashes the given input. /// /// # Panics @@ -357,7 +347,7 @@ mod tests { // The result should be equivalent to just directly applying the permutation and // taking the first state element as the output. let mut state = [message[0], message[1], pallas::Base::from_u128(2 << 64)]; - permute::<_, OrchardNullifier>(&mut state, &mds, &round_constants); + permute::<_, OrchardNullifier, 3, 2>(&mut state, &mds, &round_constants); assert_eq!(state[0], result); } } diff --git a/src/primitives/poseidon/mds.rs b/src/primitives/poseidon/mds.rs index 65b0b5c9..62ecb312 100644 --- a/src/primitives/poseidon/mds.rs +++ b/src/primitives/poseidon/mds.rs @@ -1,16 +1,15 @@ use halo2::arithmetic::FieldExt; -use super::grain::Grain; +use super::{grain::Grain, Mds}; -pub(super) fn generate_mds( +pub(super) fn generate_mds( grain: &mut Grain, - width: usize, mut select: usize, -) -> (Vec>, Vec>) { +) -> (Mds, Mds) { let (xs, ys, mds) = loop { - // Generate two [F; width] arrays of unique field elements. + // Generate two [F; T] arrays of unique field elements. let (xs, ys) = loop { - let mut vals: Vec<_> = (0..2 * width) + let mut vals: Vec<_> = (0..2 * T) .map(|_| grain.next_field_element_without_rejection()) .collect(); @@ -19,7 +18,7 @@ pub(super) fn generate_mds( unique.sort_unstable(); unique.dedup(); if vals.len() == unique.len() { - let rhs = vals.split_off(width); + let rhs = vals.split_off(T); break (vals, rhs); } }; @@ -49,10 +48,10 @@ pub(super) fn generate_mds( // However, the Poseidon paper and reference impl use the positive formulation, // and we want to rely on the reference impl for MDS security, so we use the same // formulation. - let mut mds = vec![vec![F::zero(); width]; width]; + let mut mds = [[F::zero(); T]; T]; #[allow(clippy::needless_range_loop)] - for i in 0..width { - for j in 0..width { + for i in 0..T { + for j in 0..T { let sum = xs[i] + ys[j]; // We leverage the secure MDS selection counter to also check this. assert!(!sum.is_zero()); @@ -75,7 +74,7 @@ pub(super) fn generate_mds( // where A_i(x) and B_i(x) are the Lagrange polynomials for xs and ys respectively. // // We adapt this to the positive Cauchy formulation by negating ys. - let mut mds_inv = vec![vec![F::zero(); width]; width]; + let mut mds_inv = [[F::zero(); T]; T]; let l = |xs: &[F], j, x: F| { let x_j = xs[j]; xs.iter().enumerate().fold(F::one(), |acc, (m, x_m)| { @@ -88,8 +87,8 @@ pub(super) fn generate_mds( }) }; let neg_ys: Vec<_> = ys.iter().map(|y| -*y).collect(); - for i in 0..width { - for j in 0..width { + for i in 0..T { + for j in 0..T { mds_inv[i][j] = (xs[j] - neg_ys[i]) * l(&xs, j, neg_ys[i]) * l(&neg_ys, i, xs[j]); } } @@ -105,17 +104,17 @@ mod tests { #[test] fn poseidon_mds() { - let width = 3; - let mut grain = Grain::new(super::super::grain::SboxType::Pow, width as u16, 8, 56); - let (mds, mds_inv) = generate_mds::(&mut grain, width, 0); + const T: usize = 3; + let mut grain = Grain::new(super::super::grain::SboxType::Pow, T as u16, 8, 56); + let (mds, mds_inv) = generate_mds::(&mut grain, 0); // Verify that MDS * MDS^-1 = I. #[allow(clippy::needless_range_loop)] - for i in 0..width { - for j in 0..width { + for i in 0..T { + for j in 0..T { let expected = if i == j { Fp::one() } else { Fp::zero() }; assert_eq!( - (0..width).fold(Fp::zero(), |acc, k| acc + (mds[i][k] * mds_inv[k][j])), + (0..T).fold(Fp::zero(), |acc, k| acc + (mds[i][k] * mds_inv[k][j])), expected ); } diff --git a/src/primitives/poseidon/nullifier.rs b/src/primitives/poseidon/nullifier.rs index d77b4771..67603577 100644 --- a/src/primitives/poseidon/nullifier.rs +++ b/src/primitives/poseidon/nullifier.rs @@ -1,7 +1,7 @@ use halo2::arithmetic::Field; use pasta_curves::pallas; -use super::Spec; +use super::{Mds, Spec}; /// Poseidon-128 using the $x^5$ S-box, with a width of 3 field elements, and an extra /// partial round compared to the standard specification. @@ -11,14 +11,7 @@ use super::Spec; #[derive(Debug)] pub struct OrchardNullifier; -impl Spec for OrchardNullifier { - type State = [pallas::Base; 3]; - type Rate = [Option; 2]; - - fn width() -> usize { - 3 - } - +impl Spec for OrchardNullifier { fn full_rounds() -> usize { 8 } @@ -35,12 +28,14 @@ impl Spec for OrchardNullifier { unimplemented!() } - fn constants(&self) -> (Vec, Vec, Vec) { - let round_constants = ROUND_CONSTANTS[..].to_vec(); - let mds = MDS[..].to_vec(); - let mds_inv = MDS_INV[..].to_vec(); - - (round_constants, mds, mds_inv) + fn constants( + &self, + ) -> ( + Vec<[pallas::Base; 3]>, + Mds, + Mds, + ) { + (ROUND_CONSTANTS[..].to_vec(), MDS, MDS_INV) } } @@ -1536,14 +1531,7 @@ mod tests { } } - impl Spec for P128Pow5T3Plus { - type State = [F; 3]; - type Rate = [Option; 2]; - - fn width() -> usize { - 3 - } - + impl Spec for P128Pow5T3Plus { fn full_rounds() -> usize { 8 }