diff --git a/Cargo.toml b/Cargo.toml index b8f32a58..17999251 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,9 @@ rand = "0.8" nonempty = "0.6" subtle = "2.3" +# Developer tooling dependencies +plotters = { version = "0.3.0", optional = true } + [dependencies.halo2] git = "https://github.com/zcash/halo2.git" rev = "0448584333c1e262e4a7dbaefa6fdd896bdaaefb" @@ -52,6 +55,7 @@ proptest = "1.0.0" bench = false [features] +dev-graph = ["halo2/dev-graph", "plotters"] test-dependencies = ["proptest"] [[bench]] diff --git a/src/circuit/gadget/poseidon.rs b/src/circuit/gadget/poseidon.rs index 1ebb35b0..3ea50a74 100644 --- a/src/circuit/gadget/poseidon.rs +++ b/src/circuit/gadget/poseidon.rs @@ -8,6 +8,9 @@ use halo2::{ plonk::Error, }; +mod pow5t3; +pub use pow5t3::{Pow5T3Chip, Pow5T3Config}; + /// The set of circuit instructions required to use the [`Poseidon`] gadget. pub trait PoseidonInstructions: Chip { /// Variable representing the state over which the Poseidon permutation operates. diff --git a/src/circuit/gadget/poseidon/pow5t3.rs b/src/circuit/gadget/poseidon/pow5t3.rs new file mode 100644 index 00000000..026c570a --- /dev/null +++ b/src/circuit/gadget/poseidon/pow5t3.rs @@ -0,0 +1,570 @@ +use halo2::{ + arithmetic::FieldExt, + circuit::{Cell, Chip, Layouter, Region}, + plonk::{Advice, Column, ConstraintSystem, Error, Expression, Fixed, Permutation, Selector}, + poly::Rotation, +}; + +use super::PoseidonInstructions; +use crate::primitives::poseidon::{Mds, Spec}; + +const WIDTH: usize = 3; + +/// Configuration for an [`Pow5T3Chip`]. +#[derive(Clone, Debug)] +pub struct Pow5T3Config { + state: [Column; WIDTH], + state_permutation: Permutation, + partial_sbox: Column, + rc_a: [Column; WIDTH], + rc_b: [Column; WIDTH], + s_full: Selector, + s_partial: Selector, + s_final: Selector, + + half_full_rounds: usize, + half_partial_rounds: usize, + alpha: [u64; 4], + round_constants: Vec<[F; WIDTH]>, + m_reg: Mds, + m_inv: Mds, +} + +/// A Poseidon chip using an $x^5$ S-Box, with a width of 3, suitable for a 2:1 reduction. +#[derive(Debug)] +pub struct Pow5T3Chip { + config: Pow5T3Config, +} + +impl Pow5T3Chip { + /// Configures this chip for use in a circuit. + // + // TODO: Does the rate need to be hard-coded here, or only the width? It probably + // needs to be known wherever we implement the hashing gadget, but it isn't strictly + // necessary for the permutation. + pub fn configure>( + meta: &mut ConstraintSystem, + spec: S, + state: [Column; WIDTH], + ) -> Pow5T3Config { + // Generate constants for the Poseidon permutation. + // This gadget requires R_F and R_P to be even. + assert!(S::full_rounds() & 1 == 0); + assert!(S::partial_rounds() & 1 == 0); + let half_full_rounds = S::full_rounds() / 2; + let half_partial_rounds = S::partial_rounds() / 2; + let (round_constants, m_reg, m_inv) = spec.constants(); + + let state_permutation = + Permutation::new(meta, &[state[0].into(), state[1].into(), state[2].into()]); + + let partial_sbox = meta.advice_column(); + + let rc_a = [ + meta.fixed_column(), + meta.fixed_column(), + meta.fixed_column(), + ]; + let rc_b = [ + meta.fixed_column(), + meta.fixed_column(), + meta.fixed_column(), + ]; + + let s_full = meta.selector(); + let s_partial = meta.selector(); + let s_final = meta.selector(); + + let alpha = [5, 0, 0, 0]; + let pow_5 = |v: Expression| { + let v2 = v.clone() * v.clone(); + v2.clone() * v2 * v + }; + + meta.create_gate("full round", |meta| { + let cur_0 = meta.query_advice(state[0], Rotation::cur()); + let cur_1 = meta.query_advice(state[1], Rotation::cur()); + let cur_2 = meta.query_advice(state[2], Rotation::cur()); + let next = [ + meta.query_advice(state[0], Rotation::next()), + meta.query_advice(state[1], Rotation::next()), + meta.query_advice(state[2], Rotation::next()), + ]; + + let rc_0 = meta.query_fixed(rc_a[0], Rotation::cur()); + let rc_1 = meta.query_fixed(rc_a[1], Rotation::cur()); + let rc_2 = meta.query_fixed(rc_a[2], Rotation::cur()); + + let s_full = meta.query_selector(s_full, Rotation::cur()); + + let full_round = |next_idx: usize| { + s_full.clone() + * (pow_5(cur_0.clone() + rc_0.clone()) * m_reg[next_idx][0] + + pow_5(cur_1.clone() + rc_1.clone()) * m_reg[next_idx][1] + + pow_5(cur_2.clone() + rc_2.clone()) * m_reg[next_idx][2] + - next[next_idx].clone()) + }; + + vec![full_round(0), full_round(1), full_round(2)] + }); + + meta.create_gate("partial round", |meta| { + let cur_0 = meta.query_advice(state[0], Rotation::cur()); + let cur_1 = meta.query_advice(state[1], Rotation::cur()); + let cur_2 = meta.query_advice(state[2], Rotation::cur()); + let mid_0 = meta.query_advice(partial_sbox, Rotation::cur()); + let next_0 = meta.query_advice(state[0], Rotation::next()); + let next_1 = meta.query_advice(state[1], Rotation::next()); + let next_2 = meta.query_advice(state[2], Rotation::next()); + + let rc_a0 = meta.query_fixed(rc_a[0], Rotation::cur()); + let rc_a1 = meta.query_fixed(rc_a[1], Rotation::cur()); + let rc_a2 = meta.query_fixed(rc_a[2], Rotation::cur()); + let rc_b0 = meta.query_fixed(rc_b[0], Rotation::cur()); + let rc_b1 = meta.query_fixed(rc_b[1], Rotation::cur()); + let rc_b2 = meta.query_fixed(rc_b[2], Rotation::cur()); + + let s_partial = meta.query_selector(s_partial, Rotation::cur()); + + let partial_round_linear = |idx: usize, rc_b: Expression| { + s_partial.clone() + * (mid_0.clone() * m_reg[idx][0] + + (cur_1.clone() + rc_a1.clone()) * m_reg[idx][1] + + (cur_2.clone() + rc_a2.clone()) * m_reg[idx][2] + + rc_b + - (next_0.clone() * m_inv[idx][0] + + next_1.clone() * m_inv[idx][1] + + next_2.clone() * m_inv[idx][2])) + }; + + vec![ + s_partial.clone() * (pow_5(cur_0 + rc_a0) - mid_0.clone()), + s_partial.clone() + * (pow_5( + mid_0.clone() * m_reg[0][0] + + (cur_1.clone() + rc_a1.clone()) * m_reg[0][1] + + (cur_2.clone() + rc_a2.clone()) * m_reg[0][2] + + rc_b0, + ) - (next_0.clone() * m_inv[0][0] + + next_1.clone() * m_inv[0][1] + + next_2.clone() * m_inv[0][2])), + partial_round_linear(1, rc_b1), + partial_round_linear(2, rc_b2), + ] + }); + + meta.create_gate("final full round", |meta| { + let cur = [ + meta.query_advice(state[0], Rotation::cur()), + meta.query_advice(state[1], Rotation::cur()), + meta.query_advice(state[2], Rotation::cur()), + ]; + let next = [ + meta.query_advice(state[0], Rotation::next()), + meta.query_advice(state[1], Rotation::next()), + meta.query_advice(state[2], Rotation::next()), + ]; + let rc = [ + meta.query_fixed(rc_a[0], Rotation::cur()), + meta.query_fixed(rc_a[1], Rotation::cur()), + meta.query_fixed(rc_a[2], Rotation::cur()), + ]; + let s_final = meta.query_selector(s_final, Rotation::cur()); + + let final_full_round = |idx: usize| { + s_final.clone() * (pow_5(cur[idx].clone() + rc[idx].clone()) - next[idx].clone()) + }; + vec![ + final_full_round(0), + final_full_round(1), + final_full_round(2), + ] + }); + + Pow5T3Config { + state, + state_permutation, + partial_sbox, + rc_a, + rc_b, + s_full, + s_partial, + s_final, + half_full_rounds, + half_partial_rounds, + alpha, + round_constants, + m_reg, + m_inv, + } + } + + fn construct(config: Pow5T3Config) -> Self { + Pow5T3Chip { config } + } +} + +impl Chip for Pow5T3Chip { + type Config = Pow5T3Config; + type Loaded = (); + + fn config(&self) -> &Self::Config { + &self.config + } + + fn loaded(&self) -> &Self::Loaded { + &() + } +} + +impl PoseidonInstructions for Pow5T3Chip { + type State = Pow5T3State; + + fn permute( + &self, + layouter: &mut impl Layouter, + initial_state: &Self::State, + ) -> Result { + let config = self.config(); + + layouter.assign_region( + || "permute state", + |mut region| { + // Load the initial state into this region. + let state = Pow5T3State::load(&mut region, &config, initial_state)?; + + let state = (0..config.half_full_rounds).fold(Ok(state), |res, r| { + res.and_then(|state| state.full_round(&mut region, &config, r, r)) + })?; + + let state = (0..config.half_partial_rounds).fold(Ok(state), |res, r| { + res.and_then(|state| { + state.partial_round( + &mut region, + &config, + config.half_full_rounds + 2 * r, + config.half_full_rounds + r, + ) + }) + })?; + + (0..config.half_full_rounds).fold(Ok(state), |res, r| { + res.and_then(|state| { + if r < config.half_full_rounds - 1 { + state.full_round( + &mut region, + &config, + config.half_full_rounds + 2 * config.half_partial_rounds + r, + config.half_full_rounds + config.half_partial_rounds + r, + ) + } else { + state.final_round( + &mut region, + &config, + config.half_full_rounds + 2 * config.half_partial_rounds + r, + config.half_full_rounds + config.half_partial_rounds + r, + ) + } + }) + }) + }, + ) + } +} + +#[derive(Debug)] +struct StateWord { + var: Cell, + value: Option, +} + +#[derive(Debug)] +pub struct Pow5T3State([StateWord; WIDTH]); + +impl Pow5T3State { + fn full_round( + self, + region: &mut Region, + config: &Pow5T3Config, + round: usize, + offset: usize, + ) -> Result { + Self::round(region, config, round, offset, config.s_full, |_| { + let q_0 = self.0[0] + .value + .map(|v| v + config.round_constants[round][0]); + let q_1 = self.0[1] + .value + .map(|v| v + config.round_constants[round][1]); + let q_2 = self.0[2] + .value + .map(|v| v + config.round_constants[round][2]); + + let r_0 = q_0.map(|v| v.pow(&config.alpha)); + let r_1 = q_1.map(|v| v.pow(&config.alpha)); + let r_2 = q_2.map(|v| v.pow(&config.alpha)); + + let m = &config.m_reg; + let r = r_0.and_then(|r_0| r_1.and_then(|r_1| r_2.map(|r_2| [r_0, r_1, r_2]))); + + Ok(( + round + 1, + [ + r.map(|r| m[0][0] * r[0] + m[0][1] * r[1] + m[0][2] * r[2]), + r.map(|r| m[1][0] * r[0] + m[1][1] * r[1] + m[1][2] * r[2]), + r.map(|r| m[2][0] * r[0] + m[2][1] * r[1] + m[2][2] * r[2]), + ], + )) + }) + } + + fn partial_round( + self, + region: &mut Region, + config: &Pow5T3Config, + round: usize, + offset: usize, + ) -> Result { + Self::round(region, config, round, offset, config.s_partial, |region| { + let m = &config.m_reg; + + let p = self.0[0].value.and_then(|p_0| { + self.0[1] + .value + .and_then(|p_1| self.0[2].value.map(|p_2| [p_0, p_1, p_2])) + }); + + let r = p.map(|p| { + [ + (p[0] + config.round_constants[round][0]).pow(&config.alpha), + p[1] + config.round_constants[round][1], + p[2] + config.round_constants[round][2], + ] + }); + + region.assign_advice( + || format!("round_{} partial_sbox", round), + config.partial_sbox, + offset, + || r.map(|r| r[0]).ok_or(Error::SynthesisError), + )?; + + let p_mid = r.map(|r| { + [ + m[0][0] * r[0] + m[0][1] * r[1] + m[0][2] * r[2], + m[1][0] * r[0] + m[1][1] * r[1] + m[1][2] * r[2], + m[2][0] * r[0] + m[2][1] * r[1] + m[2][2] * r[2], + ] + }); + + // Load the second round constants. + let mut load_round_constant = |i: usize| { + region.assign_fixed( + || format!("round_{} rc_{}", round + 1, i), + config.rc_b[i], + offset, + || Ok(config.round_constants[round + 1][i]), + ) + }; + for i in 0..WIDTH { + load_round_constant(i)?; + } + + let r_mid = p_mid.map(|p| { + [ + (p[0] + config.round_constants[round + 1][0]).pow(&config.alpha), + p[1] + config.round_constants[round + 1][1], + p[2] + config.round_constants[round + 1][2], + ] + }); + + Ok(( + round + 2, + [ + r_mid.map(|r| m[0][0] * r[0] + m[0][1] * r[1] + m[0][2] * r[2]), + r_mid.map(|r| m[1][0] * r[0] + m[1][1] * r[1] + m[1][2] * r[2]), + r_mid.map(|r| m[2][0] * r[0] + m[2][1] * r[1] + m[2][2] * r[2]), + ], + )) + }) + } + + fn final_round( + self, + region: &mut Region, + config: &Pow5T3Config, + round: usize, + offset: usize, + ) -> Result { + Self::round(region, config, round, offset, config.s_final, |_| { + let mut new_state = self + .0 + .iter() + .zip(config.round_constants[round].iter()) + .map(|(word, rc)| word.value.map(|v| (v + rc).pow(&config.alpha))); + + Ok(( + round + 1, + [ + new_state.next().unwrap(), + new_state.next().unwrap(), + new_state.next().unwrap(), + ], + )) + }) + } + + fn load( + region: &mut Region, + config: &Pow5T3Config, + initial_state: &Self, + ) -> Result { + let mut load_state_word = |i: usize| { + let value = initial_state.0[i].value; + let var = region.assign_advice( + || format!("load state_{}", i), + config.state[i], + 0, + || value.ok_or(Error::SynthesisError), + )?; + region.constrain_equal(&config.state_permutation, initial_state.0[i].var, var)?; + Ok(StateWord { var, value }) + }; + + Ok(Pow5T3State([ + load_state_word(0)?, + load_state_word(1)?, + load_state_word(2)?, + ])) + } + + fn round( + region: &mut Region, + config: &Pow5T3Config, + round: usize, + offset: usize, + round_gate: Selector, + round_fn: impl FnOnce(&mut Region) -> Result<(usize, [Option; WIDTH]), Error>, + ) -> Result { + // Enable the required gate. + round_gate.enable(region, offset)?; + + // Load the round constants. + let mut load_round_constant = |i: usize| { + region.assign_fixed( + || format!("round_{} rc_{}", round, i), + config.rc_a[i], + offset, + || Ok(config.round_constants[round][i]), + ) + }; + for i in 0..WIDTH { + load_round_constant(i)?; + } + + // Compute the next round's state. + let (next_round, next_state) = round_fn(region)?; + + let mut next_state_word = |i: usize| { + let value = next_state[i]; + let var = region.assign_advice( + || format!("round_{} state_{}", next_round, i), + config.state[i], + offset + 1, + || value.ok_or(Error::SynthesisError), + )?; + Ok(StateWord { var, value }) + }; + + Ok(Pow5T3State([ + next_state_word(0)?, + next_state_word(1)?, + next_state_word(2)?, + ])) + } +} + +#[cfg(test)] +mod tests { + use halo2::{ + circuit::{layouter, Layouter}, + dev::MockProver, + pasta::Fp, + plonk::{Assignment, Circuit, ConstraintSystem, Error}, + }; + + use super::{PoseidonInstructions, Pow5T3Chip, Pow5T3Config, Pow5T3State, StateWord}; + use crate::primitives::poseidon::OrchardNullifier; + + struct MyCircuit {} + + impl Circuit for MyCircuit { + type Config = Pow5T3Config; + + fn configure(meta: &mut ConstraintSystem) -> Pow5T3Config { + let state = [ + meta.advice_column(), + meta.advice_column(), + meta.advice_column(), + ]; + + Pow5T3Chip::configure(meta, OrchardNullifier, state) + } + + fn synthesize( + &self, + cs: &mut impl Assignment, + config: Pow5T3Config, + ) -> Result<(), Error> { + let mut layouter = layouter::SingleChipLayouter::new(cs)?; + + let initial_state = layouter.assign_region( + || "prepare initial state", + |mut region| { + let mut state_word = |i: usize| { + let value = Some(Fp::from(i as u64)); + let var = region.assign_advice( + || format!("load state_{}", i), + config.state[i], + 0, + || value.ok_or(Error::SynthesisError), + )?; + Ok(StateWord { var, value }) + }; + + Ok(Pow5T3State([ + state_word(0)?, + state_word(1)?, + state_word(2)?, + ])) + }, + )?; + + let chip = Pow5T3Chip::construct(config); + chip.permute(&mut layouter, &initial_state).map(|_| ()) + } + } + + #[test] + fn poseidon() { + let k = 6; + let circuit = MyCircuit {}; + let prover = MockProver::run(k, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())) + } + + #[cfg(feature = "dev-graph")] + #[test] + fn print_poseidon_chip() { + use plotters::prelude::*; + + let root = BitMapBackend::new("poseidon-chip-layout.png", (1024, 768)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root + .titled("Poseidon Chip Layout", ("sans-serif", 60)) + .unwrap(); + + let circuit = MyCircuit {}; + halo2::dev::circuit_layout(&circuit, &root).unwrap(); + } +} diff --git a/src/primitives/poseidon.rs b/src/primitives/poseidon.rs index e1954545..426fd939 100644 --- a/src/primitives/poseidon.rs +++ b/src/primitives/poseidon.rs @@ -68,7 +68,7 @@ pub trait Spec { } /// Runs the Poseidon permutation on the given state. -fn permute, const T: usize, const RATE: usize>( +pub(crate) fn permute, const T: usize, const RATE: usize>( state: &mut State, mds: &Mds, round_constants: &[[F; T]],