diff --git a/src/circuit/gadget/poseidon.rs b/src/circuit/gadget/poseidon.rs index 135dfccb..771a6188 100644 --- a/src/circuit/gadget/poseidon.rs +++ b/src/circuit/gadget/poseidon.rs @@ -3,6 +3,7 @@ use std::array; use std::convert::TryInto; use std::fmt; +use std::marker::PhantomData; use halo2::{ arithmetic::FieldExt, @@ -13,7 +14,9 @@ use halo2::{ mod pow5; pub use pow5::{Pow5Chip, Pow5Config, StateWord}; -use crate::primitives::poseidon::{ConstantLength, Domain, Spec, Sponge, SpongeRate, State}; +use crate::primitives::poseidon::{ + Absorbing, ConstantLength, Domain, Spec, SpongeMode, SpongeRate, Squeezing, State, +}; /// The set of circuit instructions required to use the Poseidon permutation. pub trait PoseidonInstructions, const T: usize, const RATE: usize>: @@ -30,10 +33,10 @@ pub trait PoseidonInstructions, const T: usize, ) -> Result, Error>; } -/// The set of circuit instructions required to use the [`Duplex`] and [`Hash`] gadgets. +/// The set of circuit instructions required to use the [`Sponge`] and [`Hash`] gadgets. /// /// [`Hash`]: self::Hash -pub trait PoseidonDuplexInstructions< +pub trait PoseidonSpongeInstructions< F: FieldExt, S: Spec, const T: usize, @@ -91,9 +94,9 @@ impl< } } -fn poseidon_duplex< +fn poseidon_sponge< F: FieldExt, - PoseidonChip: PoseidonDuplexInstructions, + PoseidonChip: PoseidonSpongeInstructions, S: Spec, D: Domain, const T: usize, @@ -110,30 +113,32 @@ fn poseidon_duplex< Ok(PoseidonChip::get_output(state)) } -/// A Poseidon duplex sponge. +/// A Poseidon sponge. #[derive(Debug)] -pub struct Duplex< +pub struct Sponge< F: FieldExt, - PoseidonChip: PoseidonDuplexInstructions, + PoseidonChip: PoseidonSpongeInstructions, S: Spec, + M: SpongeMode, D: Domain, const T: usize, const RATE: usize, > { chip: PoseidonChip, - sponge: Sponge, + mode: M, state: State, domain: D, + _marker: PhantomData, } impl< F: FieldExt, - PoseidonChip: PoseidonDuplexInstructions, + PoseidonChip: PoseidonSpongeInstructions, S: Spec, D: Domain, const T: usize, const RATE: usize, - > Duplex + > Sponge, D, T, RATE> { /// Constructs a new duplex sponge for the given Poseidon specification. pub fn new( @@ -142,9 +147,9 @@ impl< domain: D, ) -> Result { chip.initial_state(&mut layouter, &domain) - .map(|state| Duplex { + .map(|state| Sponge { chip, - sponge: Sponge::Absorbing( + mode: Absorbing( (0..RATE) .map(|_| None) .collect::>() @@ -153,6 +158,7 @@ impl< ), state, domain, + _marker: PhantomData::default(), }) } @@ -162,84 +168,97 @@ impl< mut layouter: impl Layouter, value: AssignedCell, ) -> Result<(), Error> { - match self.sponge { - Sponge::Absorbing(ref mut input) => { - for entry in input.iter_mut() { - if entry.is_none() { - *entry = Some(value.into()); - return Ok(()); - } - } - - // We've already absorbed as many elements as we can - let _ = poseidon_duplex( - &self.chip, - layouter.namespace(|| "PoseidonDuplex"), - &self.domain, - &mut self.state, - input, - )?; - self.sponge = Sponge::absorb(value.into()); - } - Sponge::Squeezing(_) => { - // Drop the remaining output elements - self.sponge = Sponge::absorb(value.into()); + for entry in self.mode.0.iter_mut() { + if entry.is_none() { + *entry = Some(value.into()); + return Ok(()); } } + // We've already absorbed as many elements as we can + let _ = poseidon_sponge( + &self.chip, + layouter.namespace(|| "PoseidonSponge"), + &self.domain, + &mut self.state, + &self.mode.0, + )?; + self.mode = Absorbing::init_with(value.into()); + Ok(()) } + /// Transitions the sponge into its squeezing state. + #[allow(clippy::type_complexity)] + pub fn finish_absorbing( + mut self, + mut layouter: impl Layouter, + ) -> Result, D, T, RATE>, Error> + { + let mode = Squeezing(poseidon_sponge( + &self.chip, + layouter.namespace(|| "PoseidonSponge"), + &self.domain, + &mut self.state, + &self.mode.0, + )?); + + Ok(Sponge { + chip: self.chip, + mode, + state: self.state, + domain: self.domain, + _marker: PhantomData::default(), + }) + } +} + +impl< + F: FieldExt, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, + > Sponge, D, T, RATE> +{ /// Squeezes an element from the sponge. pub fn squeeze(&mut self, mut layouter: impl Layouter) -> Result, Error> { loop { - match self.sponge { - Sponge::Absorbing(ref input) => { - self.sponge = Sponge::Squeezing(poseidon_duplex( - &self.chip, - layouter.namespace(|| "PoseidonDuplex"), - &self.domain, - &mut self.state, - input, - )?); - } - Sponge::Squeezing(ref mut output) => { - for entry in output.iter_mut() { - if let Some(inner) = entry.take() { - return Ok(inner.into()); - } - } - - // We've already squeezed out all available elements - self.sponge = Sponge::Absorbing( - (0..RATE) - .map(|_| None) - .collect::>() - .try_into() - .unwrap(), - ); + for entry in self.mode.0.iter_mut() { + if let Some(inner) = entry.take() { + return Ok(inner.into()); } } + + // We've already squeezed out all available elements + self.mode = Squeezing(poseidon_sponge( + &self.chip, + layouter.namespace(|| "PoseidonSponge"), + &self.domain, + &mut self.state, + &self.mode.0, + )?); } } } -/// A Poseidon hash function, built around a duplex sponge. +/// A Poseidon hash function, built around a sponge. #[derive(Debug)] pub struct Hash< F: FieldExt, - PoseidonChip: PoseidonDuplexInstructions, + PoseidonChip: PoseidonSpongeInstructions, S: Spec, D: Domain, const T: usize, const RATE: usize, > { - duplex: Duplex, + sponge: Sponge, D, T, RATE>, } impl< F: FieldExt, - PoseidonChip: PoseidonDuplexInstructions, + PoseidonChip: PoseidonSpongeInstructions, S: Spec, D: Domain, const T: usize, @@ -248,13 +267,13 @@ impl< { /// Initializes a new hasher. pub fn init(chip: PoseidonChip, layouter: impl Layouter, domain: D) -> Result { - Duplex::new(chip, layouter, domain).map(|duplex| Hash { duplex }) + Sponge::new(chip, layouter, domain).map(|sponge| Hash { sponge }) } } impl< F: FieldExt, - PoseidonChip: PoseidonDuplexInstructions, + PoseidonChip: PoseidonSpongeInstructions, S: Spec, const T: usize, const RATE: usize, @@ -268,9 +287,11 @@ impl< message: [AssignedCell; L], ) -> Result, Error> { for (i, value) in array::IntoIter::new(message).enumerate() { - self.duplex + self.sponge .absorb(layouter.namespace(|| format!("absorb_{}", i)), value)?; } - self.duplex.squeeze(layouter.namespace(|| "squeeze")) + self.sponge + .finish_absorbing(layouter.namespace(|| "finish absorbing"))? + .squeeze(layouter.namespace(|| "squeeze")) } } diff --git a/src/circuit/gadget/poseidon/pow5.rs b/src/circuit/gadget/poseidon/pow5.rs index 0c509352..1880e1af 100644 --- a/src/circuit/gadget/poseidon/pow5.rs +++ b/src/circuit/gadget/poseidon/pow5.rs @@ -8,7 +8,7 @@ use halo2::{ poly::Rotation, }; -use super::{PoseidonDuplexInstructions, PoseidonInstructions}; +use super::{PoseidonInstructions, PoseidonSpongeInstructions}; use crate::circuit::gadget::utilities::Var; use crate::primitives::poseidon::{Domain, Mds, Spec, SpongeRate, State}; @@ -269,7 +269,7 @@ impl, const WIDTH: usize, const RATE: usize } impl, const WIDTH: usize, const RATE: usize> - PoseidonDuplexInstructions for Pow5Chip + PoseidonSpongeInstructions for Pow5Chip { fn initial_state( &self, diff --git a/src/primitives/poseidon.rs b/src/primitives/poseidon.rs index df54d2c0..10933404 100644 --- a/src/primitives/poseidon.rs +++ b/src/primitives/poseidon.rs @@ -24,7 +24,7 @@ use grain::SboxType; /// The type used to hold permutation state. pub(crate) type State = [F; T]; -/// The type used to hold duplex sponge state. +/// The type used to hold sponge rate. pub(crate) type SpongeRate = [Option; RATE]; /// The type used to hold the MDS matrix and its inverse. @@ -124,7 +124,7 @@ pub(crate) fn permute, const T: usize, const RA }); } -fn poseidon_duplex, const T: usize, const RATE: usize>( +fn poseidon_sponge, const T: usize, const RATE: usize>( state: &mut State, input: &SpongeRate, pad_and_add: &dyn Fn(&mut State, &SpongeRate), @@ -142,48 +142,65 @@ fn poseidon_duplex, const T: usize, const RATE: output } -#[derive(Debug)] -pub(crate) enum Sponge { - Absorbing(SpongeRate), - Squeezing(SpongeRate), -} +/// The state of the [`Sponge`]. +// TODO: Seal this trait? +pub trait SpongeMode {} -impl Sponge { - pub(crate) fn absorb(val: F) -> Self { - let mut input: [Option; RATE] = (0..RATE) - .map(|_| None) - .collect::>() - .try_into() - .unwrap(); - input[0] = Some(val); - Sponge::Absorbing(input) +/// The absorbing state of the [`Sponge`]. +#[derive(Debug)] +pub struct Absorbing(pub(crate) SpongeRate); + +/// The squeezing state of the [`Sponge`]. +#[derive(Debug)] +pub struct Squeezing(pub(crate) SpongeRate); + +impl SpongeMode for Absorbing {} +impl SpongeMode for Squeezing {} + +impl Absorbing { + pub(crate) fn init_with(val: F) -> Self { + Self( + iter::once(Some(val)) + .chain((1..RATE).map(|_| None)) + .collect::>() + .try_into() + .unwrap(), + ) } } -/// A Poseidon duplex sponge. -pub(crate) struct Duplex, const T: usize, const RATE: usize> { - sponge: Sponge, +/// A Poseidon sponge. +pub(crate) struct Sponge< + F: FieldExt, + S: Spec, + M: SpongeMode, + const T: usize, + const RATE: usize, +> { + mode: M, state: State, pad_and_add: Box, &SpongeRate)>, mds_matrix: Mds, round_constants: Vec<[F; T]>, - _marker: PhantomData, + _marker: PhantomData<(S, M)>, } -impl, const T: usize, const RATE: usize> Duplex { - /// Constructs a new duplex sponge for the given Poseidon specification. +impl, const T: usize, const RATE: usize> + Sponge, T, RATE> +{ + /// Constructs a new sponge for the given Poseidon specification. pub(crate) fn new( initial_capacity_element: F, pad_and_add: Box, &SpongeRate)>, ) -> Self { let (round_constants, mds_matrix, _) = S::constants(); - let input = [None; RATE]; + let mode = Absorbing([None; RATE]); let mut state = [F::zero(); T]; state[RATE] = initial_capacity_element; - Duplex { - sponge: Sponge::Absorbing(input), + Sponge { + mode, state, pad_and_add, mds_matrix, @@ -194,56 +211,65 @@ impl, const T: usize, const RATE: usize> Duplex /// Absorbs an element into the sponge. pub(crate) fn absorb(&mut self, value: F) { - match self.sponge { - Sponge::Absorbing(ref mut input) => { - for entry in input.iter_mut() { - if entry.is_none() { - *entry = Some(value); - return; - } - } - - // We've already absorbed as many elements as we can - let _ = poseidon_duplex::( - &mut self.state, - input, - &self.pad_and_add, - &self.mds_matrix, - &self.round_constants, - ); - self.sponge = Sponge::absorb(value); - } - Sponge::Squeezing(_) => { - // Drop the remaining output elements - self.sponge = Sponge::absorb(value); + for entry in self.mode.0.iter_mut() { + if entry.is_none() { + *entry = Some(value); + return; } } + + // We've already absorbed as many elements as we can + let _ = poseidon_sponge::( + &mut self.state, + &self.mode.0, + &self.pad_and_add, + &self.mds_matrix, + &self.round_constants, + ); + self.mode = Absorbing::init_with(value); } + /// Transitions the sponge into its squeezing state. + pub(crate) fn finish_absorbing(mut self) -> Sponge, T, RATE> { + let mode = Squeezing(poseidon_sponge::( + &mut self.state, + &self.mode.0, + &self.pad_and_add, + &self.mds_matrix, + &self.round_constants, + )); + + Sponge { + mode, + state: self.state, + pad_and_add: self.pad_and_add, + mds_matrix: self.mds_matrix, + round_constants: self.round_constants, + _marker: PhantomData::default(), + } + } +} + +impl, const T: usize, const RATE: usize> + Sponge, T, RATE> +{ /// Squeezes an element from the sponge. pub(crate) fn squeeze(&mut self) -> F { loop { - match self.sponge { - Sponge::Absorbing(ref input) => { - self.sponge = Sponge::Squeezing(poseidon_duplex::( - &mut self.state, - input, - &self.pad_and_add, - &self.mds_matrix, - &self.round_constants, - )); - } - Sponge::Squeezing(ref mut output) => { - for entry in output.iter_mut() { - if let Some(e) = entry.take() { - return e; - } - } - - // We've already squeezed out all available elements - self.sponge = Sponge::Absorbing([None; RATE]); + for entry in self.mode.0.iter_mut() { + if let Some(e) = entry.take() { + return e; } } + + // We've already squeezed out all available elements + self.mode = Squeezing(poseidon_sponge::( + &mut self.state, + &self.mode.0, + &self.pad_and_add, + &self.mds_matrix, + &self.round_constants, + )); } } } @@ -301,7 +327,7 @@ impl Domain, @@ -309,7 +335,7 @@ pub struct Hash< const T: usize, const RATE: usize, > { - duplex: Duplex, + sponge: Sponge, T, RATE>, domain: D, } @@ -343,7 +369,7 @@ impl< /// Initializes a new hasher. pub fn init(domain: D) -> Self { Hash { - duplex: Duplex::new(domain.initial_capacity_element(), domain.pad_and_add()), + sponge: Sponge::new(domain.initial_capacity_element(), domain.pad_and_add()), domain, } } @@ -355,9 +381,9 @@ impl, const T: usize, const RATE: usize, const /// Hashes the given input. pub fn hash(mut self, message: [F; L]) -> F { for value in array::IntoIter::new(message) { - self.duplex.absorb(value); + self.sponge.absorb(value); } - self.duplex.squeeze() + self.sponge.finish_absorbing().squeeze() } }