Poseidon specification and constants

This commit is contained in:
Jack Grigg 2021-02-08 22:56:32 +00:00 committed by Jack Grigg
parent 3911fb3202
commit 84907c50e1
4 changed files with 346 additions and 0 deletions

View File

@ -4,5 +4,6 @@
// - EphemeralPublicKey
// - EphemeralSecretKey
pub(crate) mod poseidon;
pub mod redpallas;
pub(crate) mod sinsemilla;

View File

@ -0,0 +1,84 @@
use std::marker::PhantomData;
use halo2::arithmetic::FieldExt;
pub(crate) mod grain;
pub(crate) mod mds;
use grain::SboxType;
/// A specification for a Poseidon permutation.
pub trait PoseidonSpec<F: FieldExt> {
/// The arity of this specification.
fn arity(&self) -> usize;
/// The number of full rounds for this specification.
fn full_rounds(&self) -> usize;
/// The number of partial rounds for this specification.
fn partial_rounds(&self) -> usize;
/// Generates `(round_constants, mds, mds^-1)` corresponding to this specification.
fn constants(&self) -> (Vec<Vec<F>>, Vec<Vec<F>>, Vec<Vec<F>>);
}
/// A little-endian Poseidon specification.
#[derive(Debug)]
pub struct LsbPoseidon<F: FieldExt> {
sbox: SboxType,
/// The arity of the Poseidon permutation.
t: u16,
/// The number of full rounds.
r_f: u16,
/// The number of partial rounds.
r_p: u16,
/// The index of the first secure MDS matrix that will be generated for the given
/// parameters.
secure_mds: usize,
_field: PhantomData<F>,
}
impl<F: FieldExt> LsbPoseidon<F> {
/// Creates a new Poseidon specification for a field, using the `x^\alpha` S-box.
pub fn with_pow_sbox(
arity: usize,
full_rounds: usize,
partial_rounds: usize,
secure_mds: usize,
) -> Self {
LsbPoseidon {
sbox: SboxType::Pow,
t: arity as u16,
r_f: full_rounds as u16,
r_p: partial_rounds as u16,
secure_mds,
_field: PhantomData::default(),
}
}
}
impl<F: FieldExt> PoseidonSpec<F> for LsbPoseidon<F> {
fn arity(&self) -> usize {
self.t as usize
}
fn full_rounds(&self) -> usize {
self.r_f as usize
}
fn partial_rounds(&self) -> usize {
self.r_p as usize
}
fn constants(&self) -> (Vec<Vec<F>>, Vec<Vec<F>>, Vec<Vec<F>>) {
let mut grain = grain::Grain::new(self.sbox, self.t, self.r_f, self.r_p);
let round_constants = (0..(self.r_f + self.r_p))
.map(|_| (0..self.t).map(|_| grain.next_field_element()).collect())
.collect();
let (mds, mds_inv) = mds::generate_mds(&mut grain, self.t as usize, self.secure_mds);
(round_constants, mds, mds_inv)
}
}

View File

@ -0,0 +1,155 @@
//! The Grain LFSR in self-shrinking mode, as used by Poseidon.
use std::marker::PhantomData;
use bitvec::prelude::*;
use halo2::arithmetic::FieldExt;
const STATE: usize = 80;
#[derive(Debug, Clone, Copy)]
pub(super) enum FieldType {
/// GF(2^n)
#[allow(dead_code)]
Binary,
/// GF(p)
PrimeOrder,
}
impl FieldType {
fn tag(&self) -> u8 {
match self {
FieldType::Binary => 0,
FieldType::PrimeOrder => 1,
}
}
}
#[derive(Debug, Clone, Copy)]
pub(super) enum SboxType {
/// x^alpha
Pow,
/// x^(-1)
#[allow(dead_code)]
Inv,
}
impl SboxType {
fn tag(&self) -> u8 {
match self {
SboxType::Pow => 0,
SboxType::Inv => 1,
}
}
}
pub(super) struct Grain<F: FieldExt> {
state: bitarr!(for 80, in Msb0, u8),
next_bit: usize,
_field: PhantomData<F>,
}
impl<F: FieldExt> Grain<F> {
pub(super) fn new(sbox: SboxType, t: u16, r_f: u16, r_p: u16) -> Self {
// Initialize the LFSR state.
let mut state = bitarr![Msb0, u8; 1; STATE];
let mut set_bits = |offset: usize, len, value| {
// Poseidon reference impl sets initial state bits in MSB order.
for i in 0..len {
*state.get_mut(offset + len - 1 - i).unwrap() = (value >> i) & 1 != 0;
}
};
set_bits(0, 2, FieldType::PrimeOrder.tag() as u16);
set_bits(2, 4, sbox.tag() as u16);
set_bits(6, 12, F::NUM_BITS as u16);
set_bits(18, 12, t);
set_bits(30, 10, r_f);
set_bits(40, 10, r_p);
let mut grain = Grain {
state,
next_bit: STATE,
_field: PhantomData::default(),
};
// Discard the first 160 bits.
for _ in 0..20 {
grain.load_next_8_bits();
grain.next_bit = STATE;
}
grain
}
fn load_next_8_bits(&mut self) {
let mut new_bits = 0u8;
for i in 0..8 {
new_bits |= ((self.state[i + 62]
^ self.state[i + 51]
^ self.state[i + 38]
^ self.state[i + 23]
^ self.state[i + 13]
^ self.state[i]) as u8)
<< i;
}
self.state.rotate_left(8);
self.next_bit -= 8;
for i in 0..8 {
*self.state.get_mut(self.next_bit + i).unwrap() = (new_bits >> i) & 1 != 0;
}
}
fn get_next_bit(&mut self) -> bool {
if self.next_bit == STATE {
self.load_next_8_bits();
}
let ret = self.state[self.next_bit];
self.next_bit += 1;
ret
}
/// Returns the next field element from this Grain instantiation.
pub(super) fn next_field_element(&mut self) -> F {
// Loop until we get an element in the field.
loop {
let mut bytes = F::Repr::default();
// Fill the repr with bits in little-endian order.
let view = bytes.as_mut();
for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() {
view[i / 8] |= if bit { 1 << (i % 8) } else { 0 };
}
if let Some(f) = F::from_repr(bytes) {
break f;
}
}
}
}
impl<F: FieldExt> Iterator for Grain<F> {
type Item = bool;
fn next(&mut self) -> Option<Self::Item> {
// Evaluate bits in pairs:
// - If the first bit is a 1, output the second bit.
// - If the first bit is a 0, discard the second bit.
while !self.get_next_bit() {
self.get_next_bit();
}
Some(self.get_next_bit())
}
}
#[cfg(test)]
mod tests {
use pasta_curves::Fp;
use super::{Grain, SboxType};
#[test]
fn grain() {
let mut grain = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
let _f = grain.next_field_element();
}
}

View File

@ -0,0 +1,106 @@
use halo2::arithmetic::FieldExt;
use super::grain::Grain;
pub(super) fn generate_mds<F: FieldExt>(
grain: &mut Grain<F>,
arity: usize,
mut select: usize,
) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
let (xs, ys, mds) = loop {
// Generate two [F; arity] arrays of unique field elements.
let (xs, ys) = loop {
let mut vals: Vec<_> = (0..2 * arity).map(|_| grain.next_field_element()).collect();
// Check that we have unique field elements.
let mut unique = vals.clone();
unique.sort_unstable();
unique.dedup();
if vals.len() == unique.len() {
let rhs = vals.split_off(arity);
break (vals, rhs);
}
};
// We need to ensure that the MDS is secure. Instead of checking the MDS against
// the relevant algorithms directly, we witness a fixed number of MDS matrices
// that we need to sample from the given Grain state before obtaining a secure
// matrix. This can be determined out-of-band via the reference implementation in
// Sage.
if select != 0 {
select -= 1;
continue;
}
// Generate a Cauchy matrix, with elements a_ij in the form:
// a_ij = 1/(x_i - y_j); x_i - y_j != 0
//
// The Poseidon paper uses the alternate definition:
// a_ij = 1/(x_i + y_j); x_i + y_j != 0
//
// These are clearly equivalent on `y <= -y`, but it is easier to work with the
// negative formulation, because ensuring that xs ys is unique implies that
// x_i - y_j != 0 by construction (whereas the positive case does not hold). It
// also makes computation of the matrix inverse simpler below (the theorem used
// was formulated for the negative definition).
let mut mds = vec![vec![F::zero(); arity]; arity];
for i in 0..arity {
for j in 0..arity {
mds[i][j] = (xs[i] - ys[j]).invert().unwrap();
}
}
break (xs, ys, mds);
};
// Compute the inverse. All square Cauchy matrices have a non-zero determinant and
// thus are invertible. The inverse has elements b_ij given by:
//
// b_ij = (x_j - y_i) A_j(y_i) B_i(x_j) (Schechter 1959, Theorem 1)
//
// where A_i(x) and B_i(x) are the Lagrange polynomials for xs and ys respectively.
let mut mds_inv = vec![vec![F::zero(); arity]; arity];
let l = |xs: &[F], j, x: F| {
let x_j = xs[j];
xs.iter().enumerate().fold(F::one(), |acc, (m, x_m)| {
if m == j {
acc
} else {
// We can invert freely; by construction, the elements of xs are distinct.
acc * (x - x_m) * (x_j - x_m).invert().unwrap()
}
})
};
for i in 0..arity {
for j in 0..arity {
mds_inv[i][j] = (xs[j] - ys[i]) * l(&xs, j, ys[i]) * l(&ys, i, xs[j]);
}
}
(mds, mds_inv)
}
#[cfg(test)]
mod tests {
use pasta_curves::Fp;
use super::{generate_mds, Grain};
#[test]
fn poseidon_mds() {
let arity = 3;
let mut grain = Grain::new(super::super::grain::SboxType::Pow, arity as u16, 8, 56);
let (mds, mds_inv) = generate_mds::<Fp>(&mut grain, arity, 0);
// Verify that MDS * MDS^-1 = I.
for i in 0..arity {
for j in 0..arity {
let expected = if i == j { Fp::one() } else { Fp::zero() };
assert_eq!(
(0..arity).fold(Fp::zero(), |acc, k| acc + (mds[i][k] * mds_inv[k][j])),
expected
);
}
}
}
}