Add poseidon::Spec::State associated type

We reuse this type for the per-round round constants, and rows of the
MDS, to provide some type-level same-length guarantees. Once we can use
const generics, these will all be replaced by [F; Spec::ARITY].
This commit is contained in:
Jack Grigg 2021-03-12 06:16:49 +13:00
parent 5c8e9beea7
commit 6548666e37
2 changed files with 93 additions and 44 deletions

View File

@ -13,6 +13,11 @@ use grain::SboxType;
/// A specification for a Poseidon permutation.
pub trait Spec<F: FieldExt> {
/// The type used to hold permutation state, or equivalent-length constant values.
///
/// This must be an array of length [`Spec::arity`], that defaults to all-zeroes.
type State: Default + AsRef<[F]> + AsMut<[F]>;
/// The arity of this specification.
fn arity() -> usize;
@ -33,7 +38,7 @@ pub trait Spec<F: FieldExt> {
fn secure_mds(&self) -> usize;
/// Generates `(round_constants, mds, mds^-1)` corresponding to this specification.
fn constants(&self) -> (Vec<Vec<F>>, Vec<Vec<F>>, Vec<Vec<F>>) {
fn constants(&self) -> (Vec<Self::State>, Vec<Self::State>, Vec<Self::State>) {
let t = Self::arity();
let r_f = Self::full_rounds();
let r_p = Self::partial_rounds();
@ -41,12 +46,43 @@ pub trait Spec<F: FieldExt> {
let mut grain = grain::Grain::new(SboxType::Pow, t as u16, r_f as u16, r_p as u16);
let round_constants = (0..(r_f + r_p))
.map(|_| (0..t).map(|_| grain.next_field_element()).collect())
.map(|_| {
let mut rc_row = Self::State::default();
for (rc, value) in rc_row
.as_mut()
.iter_mut()
.zip((0..t).map(|_| grain.next_field_element()))
{
*rc = value;
}
rc_row
})
.collect();
let (mds, mds_inv) = mds::generate_mds(&mut grain, t, self.secure_mds());
(round_constants, mds, mds_inv)
(
round_constants,
mds.into_iter()
.map(|row| {
let mut mds_row = Self::State::default();
for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) {
*entry = value;
}
mds_row
})
.collect(),
mds_inv
.into_iter()
.map(|row| {
let mut mds_row = Self::State::default();
for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) {
*entry = value;
}
mds_row
})
.collect(),
)
}
}
@ -67,6 +103,8 @@ impl<F: FieldExt> P256Pow5T3<F> {
}
impl<F: FieldExt> Spec<F> for P256Pow5T3<F> {
type State = [F; 3];
fn arity() -> usize {
3
}
@ -89,47 +127,46 @@ impl<F: FieldExt> Spec<F> for P256Pow5T3<F> {
}
/// Runs the Poseidon permutation on the given state.
fn permute<F: FieldExt, S: Spec<F>>(state: &mut [F], mds: &[Vec<F>], round_constants: &[Vec<F>]) {
// TODO: Remove this when we can use const generics.
assert!(state.len() == S::arity());
fn permute<F: FieldExt, S: Spec<F>>(
state: &mut S::State,
mds: &[S::State],
round_constants: &[S::State],
) {
// TODO: Check what should happen for odd number of full rounds.
let r_f = S::full_rounds() / 2;
let r_p = S::partial_rounds();
let apply_mds = |state: &mut [F]| {
let new_state: Vec<_> = mds
.iter()
.map(|mds_row| {
mds_row
.iter()
.zip(state.iter())
.fold(F::zero(), |acc, (mds, word)| acc + *mds * *word)
})
.collect();
for (word, new_word) in state.iter_mut().zip(new_state.into_iter()) {
*word = new_word;
let apply_mds = |state: &mut S::State| {
let mut new_state = S::State::default();
// Matrix multiplication
for i in 0..S::arity() {
for j in 0..S::arity() {
new_state.as_mut()[i] += mds[i].as_ref()[j] * state.as_ref()[j];
}
}
*state = new_state;
};
let full_round = |state: &mut [F], rcs: &[F]| {
for (word, rc) in state.iter_mut().zip(rcs.iter()) {
let full_round = |state: &mut S::State, rcs: &S::State| {
for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) {
*word = S::sbox(*word + rc);
}
apply_mds(state);
};
let part_round = |state: &mut [F], rcs: &[F]| {
for (word, rc) in state.iter_mut().zip(rcs.iter()) {
let part_round = |state: &mut S::State, rcs: &S::State| {
for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) {
*word += rc;
}
state[0] = S::sbox(state[0]);
// In a partial round, the S-box is only applied to the first state word.
state.as_mut()[0] = S::sbox(state.as_ref()[0]);
apply_mds(state);
};
iter::empty()
.chain(iter::repeat(&full_round as &dyn Fn(&mut [F], &[F])).take(r_f))
.chain(iter::repeat(&part_round as &dyn Fn(&mut [F], &[F])).take(r_p))
.chain(iter::repeat(&full_round as &dyn Fn(&mut [F], &[F])).take(r_f))
.chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f))
.chain(iter::repeat(&part_round as &dyn Fn(&mut S::State, &S::State)).take(r_p))
.chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f))
.zip(round_constants.iter())
.fold(state, |state, (round, rcs)| {
round(state, rcs);
@ -156,24 +193,25 @@ enum SpongeState<F: FieldExt> {
/// A Poseidon duplex sponge.
pub struct Duplex<F: FieldExt, S: Spec<F>> {
sponge: Option<SpongeState<F>>,
state: Vec<F>,
state: S::State,
rate: usize,
mds_matrix: Vec<Vec<F>>,
round_constants: Vec<Vec<F>>,
mds_matrix: Vec<S::State>,
round_constants: Vec<S::State>,
_marker: PhantomData<S>,
}
impl<F: FieldExt, S: Spec<F>> Duplex<F, S> {
/// Constructs a new duplex sponge with the given rate.
pub fn new(spec: S, rate: usize) -> Self {
// The sponge capacity must be at least 1.
// TODO: Construct the capacity from the specification's security level.
assert!(rate < S::arity());
let state = vec![F::zero(); S::arity()];
let (round_constants, mds_matrix, _) = spec.constants();
Duplex {
sponge: Some(SpongeState::Absorbing(vec![])),
state,
state: S::State::default(),
rate,
mds_matrix,
round_constants,
@ -182,11 +220,11 @@ impl<F: FieldExt, S: Spec<F>> Duplex<F, S> {
}
fn process(&mut self, input: &[F]) -> Vec<F> {
pad_and_add(&mut self.state[..self.rate], input);
pad_and_add(&mut self.state.as_mut()[..self.rate], input);
permute::<F, S>(&mut self.state, &self.mds_matrix, &self.round_constants);
self.state[..self.rate].to_vec()
self.state.as_ref()[..self.rate].to_vec()
}
/// Absorbs an element into the sponge.

View File

@ -429,12 +429,18 @@ fn test_vectors() {
for (actual, expected) in round_constants
.into_iter()
.flatten()
.map(|f| {
let mut bytes = f.to_bytes();
bytes.reverse();
format!("0x{}", hex::encode(&bytes))
.map(|round| {
round
.as_ref()
.iter()
.map(|f| {
let mut bytes = f.to_bytes();
bytes.reverse();
format!("0x{}", hex::encode(&bytes))
})
.collect::<Vec<_>>()
})
.flatten()
.zip(ROUND_CONSTANTS.iter())
{
assert_eq!(&actual, expected);
@ -442,12 +448,17 @@ fn test_vectors() {
for (actual, expected) in mds
.into_iter()
.flatten()
.map(|f| {
let mut bytes = f.to_bytes();
bytes.reverse();
format!("0x{}", hex::encode(&bytes))
.map(|row| {
row.as_ref()
.iter()
.map(|f| {
let mut bytes = f.to_bytes();
bytes.reverse();
format!("0x{}", hex::encode(&bytes))
})
.collect::<Vec<_>>()
})
.flatten()
.zip(MDS.iter().flatten())
{
assert_eq!(&actual, expected);