Merge pull request #41 from zcash/poseidon-primitive

Poseidon primitive
This commit is contained in:
str4d 2021-03-26 07:36:45 +13:00 committed by GitHub
commit ee2bfa7f43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 2267 additions and 0 deletions

View File

@ -42,6 +42,7 @@ rev = "f8ff124a52d86e122e0705e8e9272f2099fe4c46"
[dev-dependencies]
criterion = "0.3"
hex = "0.4"
[lib]
bench = false

View File

@ -4,5 +4,6 @@
// - EphemeralPublicKey
// - EphemeralSecretKey
pub(crate) mod poseidon;
pub mod redpallas;
pub(crate) mod sinsemilla;

363
src/primitives/poseidon.rs Normal file
View File

@ -0,0 +1,363 @@
use std::iter;
use std::marker::PhantomData;
use halo2::arithmetic::FieldExt;
pub(crate) mod grain;
pub(crate) mod mds;
mod nullifier;
pub use nullifier::OrchardNullifier;
use grain::SboxType;
/// A specification for a Poseidon permutation.
pub trait Spec<F: FieldExt> {
/// The type used to hold permutation state, or equivalent-length constant values.
///
/// This must be an array of length [`Spec::width`], that defaults to all-zeroes.
type State: Default + AsRef<[F]> + AsMut<[F]>;
/// The type used to hold duplex sponge state.
///
/// This must be an array of length equal to the rate of the duplex sponge (allowing
/// for a capacity consistent with this specification's security level), that defaults
/// to `[None; RATE]`.
type Rate: Default + AsRef<[Option<F>]> + AsMut<[Option<F>]>;
/// The width of this specification.
fn width() -> usize;
/// The number of full rounds for this specification.
///
/// This must be an even number.
fn full_rounds() -> usize;
/// The number of partial rounds for this specification.
fn partial_rounds() -> usize;
/// The S-box for this specification.
fn sbox(val: F) -> F;
/// Side-loaded index of the first correct and secure MDS that will be generated by
/// the reference implementation.
///
/// This is used by the default implementation of [`Spec::constants`]. If you are
/// hard-coding the constants, you may leave this unimplemented.
fn secure_mds(&self) -> usize;
/// Generates `(round_constants, mds, mds^-1)` corresponding to this specification.
fn constants(&self) -> (Vec<Self::State>, Vec<Self::State>, Vec<Self::State>) {
let t = Self::width();
let r_f = Self::full_rounds();
let r_p = Self::partial_rounds();
let mut grain = grain::Grain::new(SboxType::Pow, t as u16, r_f as u16, r_p as u16);
let round_constants = (0..(r_f + r_p))
.map(|_| {
let mut rc_row = Self::State::default();
for (rc, value) in rc_row
.as_mut()
.iter_mut()
.zip((0..t).map(|_| grain.next_field_element()))
{
*rc = value;
}
rc_row
})
.collect();
let (mds, mds_inv) = mds::generate_mds(&mut grain, t, self.secure_mds());
(
round_constants,
mds.into_iter()
.map(|row| {
let mut mds_row = Self::State::default();
for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) {
*entry = value;
}
mds_row
})
.collect(),
mds_inv
.into_iter()
.map(|row| {
let mut mds_row = Self::State::default();
for (entry, value) in mds_row.as_mut().iter_mut().zip(row.into_iter()) {
*entry = value;
}
mds_row
})
.collect(),
)
}
}
/// Runs the Poseidon permutation on the given state.
fn permute<F: FieldExt, S: Spec<F>>(
state: &mut S::State,
mds: &[S::State],
round_constants: &[S::State],
) {
let r_f = S::full_rounds() / 2;
let r_p = S::partial_rounds();
let apply_mds = |state: &mut S::State| {
let mut new_state = S::State::default();
// Matrix multiplication
#[allow(clippy::needless_range_loop)]
for i in 0..S::width() {
for j in 0..S::width() {
new_state.as_mut()[i] += mds[i].as_ref()[j] * state.as_ref()[j];
}
}
*state = new_state;
};
let full_round = |state: &mut S::State, rcs: &S::State| {
for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) {
*word = S::sbox(*word + rc);
}
apply_mds(state);
};
let part_round = |state: &mut S::State, rcs: &S::State| {
for (word, rc) in state.as_mut().iter_mut().zip(rcs.as_ref().iter()) {
*word += rc;
}
// In a partial round, the S-box is only applied to the first state word.
state.as_mut()[0] = S::sbox(state.as_ref()[0]);
apply_mds(state);
};
iter::empty()
.chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f))
.chain(iter::repeat(&part_round as &dyn Fn(&mut S::State, &S::State)).take(r_p))
.chain(iter::repeat(&full_round as &dyn Fn(&mut S::State, &S::State)).take(r_f))
.zip(round_constants.iter())
.fold(state, |state, (round, rcs)| {
round(state, rcs);
state
});
}
fn poseidon_duplex<F: FieldExt, S: Spec<F>>(
state: &mut S::State,
input: &S::Rate,
pad_and_add: &dyn Fn(&mut S::State, &S::Rate),
mds_matrix: &[S::State],
round_constants: &[S::State],
) -> S::Rate {
pad_and_add(state, input);
permute::<F, S>(state, mds_matrix, round_constants);
let mut output = S::Rate::default();
for (word, value) in output.as_mut().iter_mut().zip(state.as_ref().iter()) {
*word = Some(*value);
}
output
}
enum SpongeState<F: FieldExt, S: Spec<F>> {
Absorbing(S::Rate),
Squeezing(S::Rate),
}
impl<F: FieldExt, S: Spec<F>> SpongeState<F, S> {
fn absorb(val: F) -> Self {
let mut input = S::Rate::default();
input.as_mut()[0] = Some(val);
SpongeState::Absorbing(input)
}
}
/// A Poseidon duplex sponge.
pub struct Duplex<F: FieldExt, S: Spec<F>> {
sponge: SpongeState<F, S>,
state: S::State,
pad_and_add: Box<dyn Fn(&mut S::State, &S::Rate)>,
mds_matrix: Vec<S::State>,
round_constants: Vec<S::State>,
_marker: PhantomData<S>,
}
impl<F: FieldExt, S: Spec<F>> Duplex<F, S> {
/// Constructs a new duplex sponge for the given Poseidon specification.
pub fn new(
spec: S,
initial_capacity_element: F,
pad_and_add: Box<dyn Fn(&mut S::State, &S::Rate)>,
) -> Self {
let (round_constants, mds_matrix, _) = spec.constants();
let input = S::Rate::default();
let mut state = S::State::default();
state.as_mut()[input.as_ref().len()] = initial_capacity_element;
Duplex {
sponge: SpongeState::Absorbing(input),
state,
pad_and_add,
mds_matrix,
round_constants,
_marker: PhantomData::default(),
}
}
/// Absorbs an element into the sponge.
pub fn absorb(&mut self, value: F) {
match self.sponge {
SpongeState::Absorbing(ref mut input) => {
for entry in input.as_mut().iter_mut() {
if entry.is_none() {
*entry = Some(value);
return;
}
}
// We've already absorbed as many elements as we can
let _ = poseidon_duplex::<F, S>(
&mut self.state,
&input,
&self.pad_and_add,
&self.mds_matrix,
&self.round_constants,
);
self.sponge = SpongeState::absorb(value);
}
SpongeState::Squeezing(_) => {
// Drop the remaining output elements
self.sponge = SpongeState::absorb(value);
}
}
}
/// Squeezes an element from the sponge.
pub fn squeeze(&mut self) -> F {
loop {
match self.sponge {
SpongeState::Absorbing(ref input) => {
self.sponge = SpongeState::Squeezing(poseidon_duplex::<F, S>(
&mut self.state,
&input,
&self.pad_and_add,
&self.mds_matrix,
&self.round_constants,
));
}
SpongeState::Squeezing(ref mut output) => {
for entry in output.as_mut().iter_mut() {
if let Some(e) = entry.take() {
return e;
}
}
// We've already squeezed out all available elements
self.sponge = SpongeState::Absorbing(S::Rate::default());
}
}
}
}
}
/// A domain in which a Poseidon hash function is being used.
pub trait Domain<F: FieldExt, S: Spec<F>>: Copy {
/// The initial capacity element, encoding this domain.
fn initial_capacity_element(&self) -> F;
/// Returns a function that will update the given state with the given input to a
/// duplex permutation round, applying padding according to this domain specification.
fn pad_and_add(&self) -> Box<dyn Fn(&mut S::State, &S::Rate)>;
}
/// A Poseidon hash function used with constant input length.
///
/// Domain specified in section 4.2 of https://eprint.iacr.org/2019/458.pdf
#[derive(Clone, Copy, Debug)]
pub struct ConstantLength(pub usize);
impl<F: FieldExt, S: Spec<F>> Domain<F, S> for ConstantLength {
fn initial_capacity_element(&self) -> F {
// Capacity value is $length \cdot 2^64 + (o-1)$ where o is the output length.
// We hard-code an output length of 1.
F::from_u128((self.0 as u128) << 64)
}
fn pad_and_add(&self) -> Box<dyn Fn(&mut S::State, &S::Rate)> {
Box::new(|state, input| {
// `Iterator::zip` short-circuits when one iterator completes, so this will only
// mutate the rate portion of the state.
for (word, value) in state.as_mut().iter_mut().zip(input.as_ref().iter()) {
// For constant-input-length hashing, padding consists of the field
// elements being zero, so we don't add anything to the state.
if let Some(value) = value {
*word += value;
}
}
})
}
}
/// A Poseidon hash function, built around a duplex sponge.
pub struct Hash<F: FieldExt, S: Spec<F>, D: Domain<F, S>> {
duplex: Duplex<F, S>,
domain: D,
}
impl<F: FieldExt, S: Spec<F>, D: Domain<F, S>> Hash<F, S, D> {
/// Initializes a new hasher.
pub fn init(spec: S, domain: D) -> Self {
Hash {
duplex: Duplex::new(
spec,
domain.initial_capacity_element(),
domain.pad_and_add(),
),
domain,
}
}
}
impl<F: FieldExt, S: Spec<F>> Hash<F, S, ConstantLength> {
/// Hashes the given input.
///
/// # Panics
///
/// Panics if the message length is not the correct length.
pub fn hash(mut self, message: impl Iterator<Item = F>) -> F {
let mut length = 0;
for (i, value) in message.enumerate() {
length = i + 1;
self.duplex.absorb(value);
}
assert_eq!(length, self.domain.0);
self.duplex.squeeze()
}
}
#[cfg(test)]
mod tests {
use halo2::arithmetic::FieldExt;
use pasta_curves::pallas;
use super::{permute, ConstantLength, Hash, OrchardNullifier, Spec};
#[test]
fn orchard_spec_equivalence() {
let message = [pallas::Base::from_u64(6), pallas::Base::from_u64(42)];
let (round_constants, mds, _) = OrchardNullifier.constants();
let hasher = Hash::init(OrchardNullifier, ConstantLength(2));
let result = hasher.hash(message.iter().cloned());
// The result should be equivalent to just directly applying the permutation and
// taking the first state element as the output.
let mut state = [message[0], message[1], pallas::Base::from_u128(2 << 64)];
permute::<_, OrchardNullifier>(&mut state, &mds, &round_constants);
assert_eq!(state[0], result);
}
}

View File

@ -0,0 +1,193 @@
//! The Grain LFSR in self-shrinking mode, as used by Poseidon.
use std::marker::PhantomData;
use bitvec::prelude::*;
use halo2::arithmetic::FieldExt;
const STATE: usize = 80;
#[derive(Debug, Clone, Copy)]
pub(super) enum FieldType {
/// GF(2^n)
#[allow(dead_code)]
Binary,
/// GF(p)
PrimeOrder,
}
impl FieldType {
fn tag(&self) -> u8 {
match self {
FieldType::Binary => 0,
FieldType::PrimeOrder => 1,
}
}
}
#[derive(Debug, Clone, Copy)]
pub(super) enum SboxType {
/// x^alpha
Pow,
/// x^(-1)
#[allow(dead_code)]
Inv,
}
impl SboxType {
fn tag(&self) -> u8 {
match self {
SboxType::Pow => 0,
SboxType::Inv => 1,
}
}
}
pub(super) struct Grain<F: FieldExt> {
state: bitarr!(for 80, in Msb0, u8),
next_bit: usize,
_field: PhantomData<F>,
}
impl<F: FieldExt> Grain<F> {
pub(super) fn new(sbox: SboxType, t: u16, r_f: u16, r_p: u16) -> Self {
// Initialize the LFSR state.
let mut state = bitarr![Msb0, u8; 1; STATE];
let mut set_bits = |offset: usize, len, value| {
// Poseidon reference impl sets initial state bits in MSB order.
for i in 0..len {
*state.get_mut(offset + len - 1 - i).unwrap() = (value >> i) & 1 != 0;
}
};
set_bits(0, 2, FieldType::PrimeOrder.tag() as u16);
set_bits(2, 4, sbox.tag() as u16);
set_bits(6, 12, F::NUM_BITS as u16);
set_bits(18, 12, t);
set_bits(30, 10, r_f);
set_bits(40, 10, r_p);
let mut grain = Grain {
state,
next_bit: STATE,
_field: PhantomData::default(),
};
// Discard the first 160 bits.
for _ in 0..20 {
grain.load_next_8_bits();
grain.next_bit = STATE;
}
grain
}
fn load_next_8_bits(&mut self) {
let mut new_bits = 0u8;
for i in 0..8 {
new_bits |= ((self.state[i + 62]
^ self.state[i + 51]
^ self.state[i + 38]
^ self.state[i + 23]
^ self.state[i + 13]
^ self.state[i]) as u8)
<< i;
}
self.state.rotate_left(8);
self.next_bit -= 8;
for i in 0..8 {
*self.state.get_mut(self.next_bit + i).unwrap() = (new_bits >> i) & 1 != 0;
}
}
fn get_next_bit(&mut self) -> bool {
if self.next_bit == STATE {
self.load_next_8_bits();
}
let ret = self.state[self.next_bit];
self.next_bit += 1;
ret
}
/// Returns the next field element from this Grain instantiation.
pub(super) fn next_field_element(&mut self) -> F {
// Loop until we get an element in the field.
loop {
let mut bytes = F::Repr::default();
// Poseidon reference impl interprets the bits as a repr in MSB order, because
// it's easy to do that in Python. Meanwhile, our field elements all use LSB
// order. There's little motivation to diverge from the reference impl; these
// are all constants, so we aren't introducing big-endianness into the rest of
// the circuit (assuming unkeyed Poseidon, but we probably wouldn't want to
// implement Grain inside a circuit, so we'd use a different round constant
// derivation function there).
let view = bytes.as_mut();
for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() {
// If we diverged from the reference impl and interpreted the bits in LSB
// order, we would remove this line.
let i = F::NUM_BITS as usize - 1 - i;
view[i / 8] |= if bit { 1 << (i % 8) } else { 0 };
}
if let Some(f) = F::from_repr(bytes) {
break f;
}
}
}
/// Returns the next field element from this Grain instantiation, without using
/// rejection sampling.
pub(super) fn next_field_element_without_rejection(&mut self) -> F {
let mut bytes = [0u8; 64];
// Poseidon reference impl interprets the bits as a repr in MSB order, because
// it's easy to do that in Python. Additionally, it does not use rejection
// sampling in cases where the constants don't specifically need to be uniformly
// random for security. We do not provide APIs that take a field-element-sized
// array and reduce it modulo the field order, because those are unsafe APIs to
// offer generally (accidentally using them can lead to divergence in consensus
// systems due to not rejecting canonical forms).
//
// Given that we don't want to diverge from the reference implementation, we hack
// around this restriction by serializing the bits into a 64-byte array and then
// calling F::from_bytes_wide. PLEASE DO NOT COPY THIS INTO YOUR OWN CODE!
let view = bytes.as_mut();
for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() {
// If we diverged from the reference impl and interpreted the bits in LSB
// order, we would remove this line.
let i = F::NUM_BITS as usize - 1 - i;
view[i / 8] |= if bit { 1 << (i % 8) } else { 0 };
}
F::from_bytes_wide(&bytes)
}
}
impl<F: FieldExt> Iterator for Grain<F> {
type Item = bool;
fn next(&mut self) -> Option<Self::Item> {
// Evaluate bits in pairs:
// - If the first bit is a 1, output the second bit.
// - If the first bit is a 0, discard the second bit.
while !self.get_next_bit() {
self.get_next_bit();
}
Some(self.get_next_bit())
}
}
#[cfg(test)]
mod tests {
use pasta_curves::Fp;
use super::{Grain, SboxType};
#[test]
fn grain() {
let mut grain = Grain::<Fp>::new(SboxType::Pow, 3, 8, 56);
let _f = grain.next_field_element();
}
}

View File

@ -0,0 +1,124 @@
use halo2::arithmetic::FieldExt;
use super::grain::Grain;
pub(super) fn generate_mds<F: FieldExt>(
grain: &mut Grain<F>,
width: usize,
mut select: usize,
) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
let (xs, ys, mds) = loop {
// Generate two [F; width] arrays of unique field elements.
let (xs, ys) = loop {
let mut vals: Vec<_> = (0..2 * width)
.map(|_| grain.next_field_element_without_rejection())
.collect();
// Check that we have unique field elements.
let mut unique = vals.clone();
unique.sort_unstable();
unique.dedup();
if vals.len() == unique.len() {
let rhs = vals.split_off(width);
break (vals, rhs);
}
};
// We need to ensure that the MDS is secure. Instead of checking the MDS against
// the relevant algorithms directly, we witness a fixed number of MDS matrices
// that we need to sample from the given Grain state before obtaining a secure
// matrix. This can be determined out-of-band via the reference implementation in
// Sage.
if select != 0 {
select -= 1;
continue;
}
// Generate a Cauchy matrix, with elements a_ij in the form:
// a_ij = 1/(x_i + y_j); x_i + y_j != 0
//
// It would be much easier to use the alternate definition:
// a_ij = 1/(x_i - y_j); x_i - y_j != 0
//
// These are clearly equivalent on `y <- -y`, but it is easier to work with the
// negative formulation, because ensuring that xs ys is unique implies that
// x_i - y_j != 0 by construction (whereas the positive case does not hold). It
// also makes computation of the matrix inverse simpler below (the theorem used
// was formulated for the negative definition).
//
// However, the Poseidon paper and reference impl use the positive formulation,
// and we want to rely on the reference impl for MDS security, so we use the same
// formulation.
let mut mds = vec![vec![F::zero(); width]; width];
#[allow(clippy::needless_range_loop)]
for i in 0..width {
for j in 0..width {
let sum = xs[i] + ys[j];
// We leverage the secure MDS selection counter to also check this.
assert!(!sum.is_zero());
mds[i][j] = sum.invert().unwrap();
}
}
break (xs, ys, mds);
};
// Compute the inverse. All square Cauchy matrices have a non-zero determinant and
// thus are invertible. The inverse for a Cauchy matrix of the form:
//
// a_ij = 1/(x_i - y_j); x_i - y_j != 0
//
// has elements b_ij given by:
//
// b_ij = (x_j - y_i) A_j(y_i) B_i(x_j) (Schechter 1959, Theorem 1)
//
// where A_i(x) and B_i(x) are the Lagrange polynomials for xs and ys respectively.
//
// We adapt this to the positive Cauchy formulation by negating ys.
let mut mds_inv = vec![vec![F::zero(); width]; width];
let l = |xs: &[F], j, x: F| {
let x_j = xs[j];
xs.iter().enumerate().fold(F::one(), |acc, (m, x_m)| {
if m == j {
acc
} else {
// We can invert freely; by construction, the elements of xs are distinct.
acc * (x - x_m) * (x_j - x_m).invert().unwrap()
}
})
};
let neg_ys: Vec<_> = ys.iter().map(|y| -*y).collect();
for i in 0..width {
for j in 0..width {
mds_inv[i][j] = (xs[j] - neg_ys[i]) * l(&xs, j, neg_ys[i]) * l(&neg_ys, i, xs[j]);
}
}
(mds, mds_inv)
}
#[cfg(test)]
mod tests {
use pasta_curves::Fp;
use super::{generate_mds, Grain};
#[test]
fn poseidon_mds() {
let width = 3;
let mut grain = Grain::new(super::super::grain::SboxType::Pow, width as u16, 8, 56);
let (mds, mds_inv) = generate_mds::<Fp>(&mut grain, width, 0);
// Verify that MDS * MDS^-1 = I.
#[allow(clippy::needless_range_loop)]
for i in 0..width {
for j in 0..width {
let expected = if i == j { Fp::one() } else { Fp::zero() };
assert_eq!(
(0..width).fold(Fp::zero(), |acc, k| acc + (mds[i][k] * mds_inv[k][j])),
expected
);
}
}
}
}

File diff suppressed because it is too large Load Diff