diff --git a/src/primitives/poseidon.rs b/src/primitives/poseidon.rs index ab698b8c..e9eac85a 100644 --- a/src/primitives/poseidon.rs +++ b/src/primitives/poseidon.rs @@ -186,16 +186,11 @@ fn permute>( fn poseidon_duplex>( state: &mut S::State, input: &S::Rate, + pad_and_add: &dyn Fn(&mut S::State, &S::Rate), mds_matrix: &[S::State], round_constants: &[S::State], ) -> S::Rate { - // `Iterator::zip` short-circuits when one iterator completes, so this will only - // mutate the rate portion of the state. - for (word, value) in state.as_mut().iter_mut().zip(input.as_ref().iter()) { - // TODO: Decide on a padding strategy, if we ever need to use Poseidon with - // incomplete state input. - *word += value.unwrap(); - } + pad_and_add(state, input); permute::(state, mds_matrix, round_constants); @@ -223,6 +218,7 @@ impl> SpongeState { pub struct Duplex> { sponge: SpongeState, state: S::State, + pad_and_add: Box, mds_matrix: Vec, round_constants: Vec, _marker: PhantomData, @@ -230,12 +226,21 @@ pub struct Duplex> { impl> Duplex { /// Constructs a new duplex sponge for the given Poseidon specification. - pub fn new(spec: S) -> Self { + pub fn new( + spec: S, + initial_capacity_element: F, + pad_and_add: Box, + ) -> Self { let (round_constants, mds_matrix, _) = spec.constants(); + let input = S::Rate::default(); + let mut state = S::State::default(); + state.as_mut()[input.as_ref().len()] = initial_capacity_element; + Duplex { - sponge: SpongeState::Absorbing(S::Rate::default()), - state: S::State::default(), + sponge: SpongeState::Absorbing(input), + state, + pad_and_add, mds_matrix, round_constants, _marker: PhantomData::default(), @@ -257,6 +262,7 @@ impl> Duplex { let _ = poseidon_duplex::( &mut self.state, &input, + &self.pad_and_add, &self.mds_matrix, &self.round_constants, ); @@ -277,6 +283,7 @@ impl> Duplex { self.sponge = SpongeState::Squeezing(poseidon_duplex::( &mut self.state, &input, + &self.pad_and_add, &self.mds_matrix, &self.round_constants, )); @@ -296,23 +303,102 @@ impl> Duplex { } } -/// A Poseidon hash function, built around a duplex sponge. -pub struct Hash>(Duplex); +/// A domain in which a Poseidon hash function is being used. +pub trait Domain>: Copy { + /// The initial capacity element, encoding this domain. + fn initial_capacity_element(&self) -> F; -impl> Hash { - /// Initializes a new hasher. - pub fn init(spec: S) -> Self { - Hash(Duplex::new(spec)) + /// Returns a function that will update the given state with the given input to a + /// duplex permutation round, applying padding according to this domain specification. + fn pad_and_add(&self) -> Box; +} + +/// A Poseidon hash function used with constant input length. +/// +/// Domain specified in section 4.2 of https://eprint.iacr.org/2019/458.pdf +#[derive(Clone, Copy, Debug)] +pub struct ConstantLength(pub usize); + +impl> Domain for ConstantLength { + fn initial_capacity_element(&self) -> F { + // Capacity value is $length \cdot 2^64 + (o-1)$ where o is the output length. + // We hard-code an output length of 1. + F::from_u128((self.0 as u128) << 64) } - /// Updates the hasher with the given value. - pub fn update(&mut self, value: F) { - self.0.absorb(value); - } - - /// Finalizes the hasher, returning its output. - pub fn finalize(mut self) -> F { - // TODO: Check which state element other implementations use. - self.0.squeeze() + fn pad_and_add(&self) -> Box { + Box::new(|state, input| { + // `Iterator::zip` short-circuits when one iterator completes, so this will only + // mutate the rate portion of the state. + for (word, value) in state.as_mut().iter_mut().zip(input.as_ref().iter()) { + // For constant-input-length hashing, padding consists of the field + // elements being zero, so we don't add anything to the state. + if let Some(value) = value { + *word += value; + } + } + }) + } +} + +/// A Poseidon hash function, built around a duplex sponge. +pub struct Hash, D: Domain> { + duplex: Duplex, + domain: D, +} + +impl, D: Domain> Hash { + /// Initializes a new hasher. + pub fn init(spec: S, domain: D) -> Self { + Hash { + duplex: Duplex::new( + spec, + domain.initial_capacity_element(), + domain.pad_and_add(), + ), + domain, + } + } +} + +impl> Hash { + /// Hashes the given input. + /// + /// # Panics + /// + /// Panics if the message length is not the correct length. + pub fn hash(mut self, message: impl Iterator) -> F { + let mut length = 0; + for (i, value) in message.enumerate() { + length = i + 1; + self.duplex.absorb(value); + } + assert_eq!(length, self.domain.0); + self.duplex.squeeze() + } +} + +#[cfg(test)] +mod tests { + use halo2::arithmetic::FieldExt; + use pasta_curves::pallas; + + use super::{permute, ConstantLength, Hash, P256Pow5T3, Spec}; + + #[test] + fn orchard_spec_equivalence() { + let message = [pallas::Base::from_u64(6), pallas::Base::from_u64(42)]; + + let spec = P256Pow5T3::::new(0); + let (round_constants, mds, _) = spec.constants(); + + let hasher = Hash::init(spec, ConstantLength(2)); + let result = hasher.hash(message.iter().cloned()); + + // The result should be equivalent to just directly applying the permutation and + // taking the first state element as the output. + let mut state = [message[0], message[1], pallas::Base::from_u128(2 << 64)]; + permute::>(&mut state, &mds, &round_constants); + assert_eq!(state[0], result); } }