mirror of https://github.com/zcash/orchard.git
Add poseidon::Spec::State associated type
We reuse this type for the per-round round constants, and rows of the MDS, to provide some type-level same-length guarantees. Once we can use const generics, these will all be replaced by [F; Spec::ARITY].
This commit is contained in:
parent
5c8e9beea7
commit
6548666e37
|
@ -13,6 +13,11 @@ use grain::SboxType;
|
|||
|
||||
/// A specification for a Poseidon permutation.
|
||||
pub trait Spec<F: FieldExt> {
|
||||
/// The type used to hold permutation state, or equivalent-length constant values.
|
||||
///
|
||||
/// This must be an array of length [`Spec::arity`], that defaults to all-zeroes.
|
||||
type State: Default + AsRef<[F]> + AsMut<[F]>;
|
||||
|
||||
/// The arity of this specification.
|
||||
fn arity() -> usize;
|
||||
|
||||
|
@ -33,7 +38,7 @@ pub trait Spec<F: FieldExt> {
|
|||
fn secure_mds(&self) -> usize;
|
||||
|
||||
/// Generates `(round_constants, mds, mds^-1)` corresponding to this specification.
|
||||
fn constants(&self) -> (Vec<Vec<F>>, Vec<Vec<F>>, Vec<Vec<F>>) {
|
||||
fn constants(&self) -> (Vec<Self::State>, Vec<Self::State>, Vec<Self::State>) {
|
||||
let t = Self::arity();
|
||||
let r_f = Self::full_rounds();
|
||||
let r_p = Self::partial_rounds();
|
||||
|
@ -41,12 +46,43 @@ pub trait Spec<F: FieldExt> {
|
|||
let mut grain = grain::Grain::new(SboxType::Pow, t as u16, r_f as u16, r_p as u16);
|
||||
|
||||
let round_constants = (0..(r_f + r_p))
|
||||
.map(|_| (0..t).map(|_| grain.next_field_element()).collect())
|
||||
.map(|_| {
|
||||
let mut rc_row = Self::State::default();
|
||||
for (rc, value) in rc_row
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.zip((0..t).map(|_| grain.next_field_element()))
|
||||
{
|
||||
*rc = value;
|
||||
}
|
||||
rc_row
|
||||
})
|
||||
.collect();
|
||||
|
||||
let (mds, mds_inv) = mds::generate_mds(&mut grain, t, self.secure_mds());
|
||||
|
||||
(round_constants, mds, mds_inv)
|
||||
(
|
||||
round_constants,
|
||||
mds.into_iter()
|
||||
.map(|row| {
|
||||
let mut mds_row = Self::State::default();
|
||||
for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) {
|
||||
*entry = value;
|
||||
}
|
||||
mds_row
|
||||
})
|
||||
.collect(),
|
||||
mds_inv
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let mut mds_row = Self::State::default();
|
||||
for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) {
|
||||
*entry = value;
|
||||
}
|
||||
mds_row
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,6 +103,8 @@ impl<F: FieldExt> P256Pow5T3<F> {
|
|||
}
|
||||
|
||||
impl<F: FieldExt> Spec<F> for P256Pow5T3<F> {
|
||||
type State = [F; 3];
|
||||
|
||||
fn arity() -> usize {
|
||||
3
|
||||
}
|
||||
|
@ -89,47 +127,46 @@ impl<F: FieldExt> Spec<F> for P256Pow5T3<F> {
|
|||
}
|
||||
|
||||
/// Runs the Poseidon permutation on the given state.
|
||||
fn permute<F: FieldExt, S: Spec<F>>(state: &mut [F], mds: &[Vec<F>], round_constants: &[Vec<F>]) {
|
||||
// TODO: Remove this when we can use const generics.
|
||||
assert!(state.len() == S::arity());
|
||||
|
||||
fn permute<F: FieldExt, S: Spec<F>>(
|
||||
state: &mut S::State,
|
||||
mds: &[S::State],
|
||||
round_constants: &[S::State],
|
||||
) {
|
||||
// TODO: Check what should happen for odd number of full rounds.
|
||||
let r_f = S::full_rounds() / 2;
|
||||
let r_p = S::partial_rounds();
|
||||
|
||||
let apply_mds = |state: &mut [F]| {
|
||||
let new_state: Vec<_> = mds
|
||||
.iter()
|
||||
.map(|mds_row| {
|
||||
mds_row
|
||||
.iter()
|
||||
.zip(state.iter())
|
||||
.fold(F::zero(), |acc, (mds, word)| acc + *mds * *word)
|
||||
})
|
||||
.collect();
|
||||
for (word, new_word) in state.iter_mut().zip(new_state.into_iter()) {
|
||||
*word = new_word;
|
||||
let apply_mds = |state: &mut S::State| {
|
||||
let mut new_state = S::State::default();
|
||||
// Matrix multiplication
|
||||
for i in 0..S::arity() {
|
||||
for j in 0..S::arity() {
|
||||
new_state.as_mut()[i] += mds[i].as_ref()[j] * state.as_ref()[j];
|
||||
}
|
||||
}
|
||||
*state = new_state;
|
||||
};
|
||||
|
||||
let full_round = |state: &mut [F], rcs: &[F]| {
|
||||
for (word, rc) in state.iter_mut().zip(rcs.iter()) {
|
||||
let full_round = |state: &mut S::State, rcs: &S::State| {
|
||||
for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) {
|
||||
*word = S::sbox(*word + rc);
|
||||
}
|
||||
apply_mds(state);
|
||||
};
|
||||
|
||||
let part_round = |state: &mut [F], rcs: &[F]| {
|
||||
for (word, rc) in state.iter_mut().zip(rcs.iter()) {
|
||||
let part_round = |state: &mut S::State, rcs: &S::State| {
|
||||
for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) {
|
||||
*word += rc;
|
||||
}
|
||||
state[0] = S::sbox(state[0]);
|
||||
// In a partial round, the S-box is only applied to the first state word.
|
||||
state.as_mut()[0] = S::sbox(state.as_ref()[0]);
|
||||
apply_mds(state);
|
||||
};
|
||||
|
||||
iter::empty()
|
||||
.chain(iter::repeat(&full_round as &dyn Fn(&mut [F], &[F])).take(r_f))
|
||||
.chain(iter::repeat(&part_round as &dyn Fn(&mut [F], &[F])).take(r_p))
|
||||
.chain(iter::repeat(&full_round as &dyn Fn(&mut [F], &[F])).take(r_f))
|
||||
.chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f))
|
||||
.chain(iter::repeat(&part_round as &dyn Fn(&mut S::State, &S::State)).take(r_p))
|
||||
.chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f))
|
||||
.zip(round_constants.iter())
|
||||
.fold(state, |state, (round, rcs)| {
|
||||
round(state, rcs);
|
||||
|
@ -156,24 +193,25 @@ enum SpongeState<F: FieldExt> {
|
|||
/// A Poseidon duplex sponge.
|
||||
pub struct Duplex<F: FieldExt, S: Spec<F>> {
|
||||
sponge: Option<SpongeState<F>>,
|
||||
state: Vec<F>,
|
||||
state: S::State,
|
||||
rate: usize,
|
||||
mds_matrix: Vec<Vec<F>>,
|
||||
round_constants: Vec<Vec<F>>,
|
||||
mds_matrix: Vec<S::State>,
|
||||
round_constants: Vec<S::State>,
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<F: FieldExt, S: Spec<F>> Duplex<F, S> {
|
||||
/// Constructs a new duplex sponge with the given rate.
|
||||
pub fn new(spec: S, rate: usize) -> Self {
|
||||
// The sponge capacity must be at least 1.
|
||||
// TODO: Construct the capacity from the specification's security level.
|
||||
assert!(rate < S::arity());
|
||||
|
||||
let state = vec![F::zero(); S::arity()];
|
||||
let (round_constants, mds_matrix, _) = spec.constants();
|
||||
|
||||
Duplex {
|
||||
sponge: Some(SpongeState::Absorbing(vec![])),
|
||||
state,
|
||||
state: S::State::default(),
|
||||
rate,
|
||||
mds_matrix,
|
||||
round_constants,
|
||||
|
@ -182,11 +220,11 @@ impl<F: FieldExt, S: Spec<F>> Duplex<F, S> {
|
|||
}
|
||||
|
||||
fn process(&mut self, input: &[F]) -> Vec<F> {
|
||||
pad_and_add(&mut self.state[..self.rate], input);
|
||||
pad_and_add(&mut self.state.as_mut()[..self.rate], input);
|
||||
|
||||
permute::<F, S>(&mut self.state, &self.mds_matrix, &self.round_constants);
|
||||
|
||||
self.state[..self.rate].to_vec()
|
||||
self.state.as_ref()[..self.rate].to_vec()
|
||||
}
|
||||
|
||||
/// Absorbs an element into the sponge.
|
||||
|
|
|
@ -429,12 +429,18 @@ fn test_vectors() {
|
|||
|
||||
for (actual, expected) in round_constants
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|f| {
|
||||
let mut bytes = f.to_bytes();
|
||||
bytes.reverse();
|
||||
format!("0x{}", hex::encode(&bytes))
|
||||
.map(|round| {
|
||||
round
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let mut bytes = f.to_bytes();
|
||||
bytes.reverse();
|
||||
format!("0x{}", hex::encode(&bytes))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.flatten()
|
||||
.zip(ROUND_CONSTANTS.iter())
|
||||
{
|
||||
assert_eq!(&actual, expected);
|
||||
|
@ -442,12 +448,17 @@ fn test_vectors() {
|
|||
|
||||
for (actual, expected) in mds
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|f| {
|
||||
let mut bytes = f.to_bytes();
|
||||
bytes.reverse();
|
||||
format!("0x{}", hex::encode(&bytes))
|
||||
.map(|row| {
|
||||
row.as_ref()
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let mut bytes = f.to_bytes();
|
||||
bytes.reverse();
|
||||
format!("0x{}", hex::encode(&bytes))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.flatten()
|
||||
.zip(MDS.iter().flatten())
|
||||
{
|
||||
assert_eq!(&actual, expected);
|
||||
|
|
Loading…
Reference in New Issue