gadget::utilities: Add decompose_running_sum helper.

This decomposes a field element into K-bit windows using a
running sum. Each step of the running sum is range-constrained.
In strict mode, the final output of the running sum is constrained
to be zero.

This helper asserts K <= 3.
This commit is contained in:
therealyingtong 2021-07-09 20:32:35 +08:00
parent f3c9b6cedc
commit ee062bae3d
5 changed files with 384 additions and 14 deletions

View File

@ -6,7 +6,7 @@ use crate::{
bitrange_subset, copy, lookup_range_check::LookupRangeCheckConfig, range_check, CellValue,
Var,
},
constants::{self, util::decompose_scalar_fixed, T_P},
constants::{self, util::decompose_word, T_P},
primitives::sinsemilla,
};
use halo2::{
@ -428,7 +428,7 @@ impl Config {
// Decompose base field element into 3-bit words.
let words: Vec<Option<u8>> = {
let words = base_field_elem.value().map(|base_field_elem| {
decompose_scalar_fixed::<pallas::Base>(
decompose_word::<pallas::Base>(
base_field_elem,
constants::L_ORCHARD_BASE,
constants::FIXED_BASE_WINDOW_SIZE,

View File

@ -61,7 +61,7 @@ impl Config {
// Decompose scalar into `k-bit` windows
let scalar_windows: Option<Vec<u8>> = scalar.map(|scalar| {
util::decompose_scalar_fixed::<pallas::Scalar>(
util::decompose_word::<pallas::Scalar>(
scalar,
SCALAR_NUM_BITS,
constants::FIXED_BASE_WINDOW_SIZE,

View File

@ -7,6 +7,7 @@ use pasta_curves::arithmetic::FieldExt;
use std::{array, convert::TryInto, ops::Range};
pub(crate) mod cond_swap;
pub(crate) mod decompose_running_sum;
pub(crate) mod enable_flag;
pub(crate) mod lookup_range_check;
pub(crate) mod plonk;

View File

@ -0,0 +1,369 @@
//! Decomposes an n-bit field element alpha into W windows, each window
//! being a K-bit word, using a running sum z.
//! We constrain K <= 3 for this helper.
//! alpha = w_0 + (2^K) w_1 + (2^2K) w_2 + ... + (2^(W-1)K) w_{W-1}
//! z_0 is initialized as alpha. Each successive z_{i+1} is computed as
//! z_{i+1} = (z_{i} - k_i) / (2^k).
//! z_W is constrained to be zero.
//! The difference between each interstitial running sum output is constrained
//! to be K bits, i.e.
//! range_check(k_i, 2^K),
//! where range_check(word, range)
//! = word * (1 - word) * (2 - word) * ... * ((range - 1) - word)
//! is an expression of degree range.
//!
//! Given that the range_check constraint will be toggled by a selector, in
//! practice we will have a selector * range_check(word, range) expression
//! of degree range + 1.
//!
//! This means that 2^K has to be at most degree_bound - 1 in order for
//! the range check constraint to stay within the degree bound.
use ff::PrimeFieldBits;
use halo2::{
circuit::Region,
plonk::{Advice, Column, ConstraintSystem, Error, Permutation, Selector},
poly::Rotation,
};
use super::{copy, range_check, CellValue, Var};
use crate::constants::util::decompose_word;
use pasta_curves::arithmetic::FieldExt;
use std::marker::PhantomData;
/// The running sum [z_0, z_1, ..., z_W], where z_0 = alpha, and z_W = zero.
pub struct RunningSum<F: FieldExt + PrimeFieldBits>(Vec<CellValue<F>>);
impl<F: FieldExt + PrimeFieldBits> std::ops::Deref for RunningSum<F> {
type Target = Vec<CellValue<F>>;
fn deref(&self) -> &Vec<CellValue<F>> {
&self.0
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RunningSumConfig<
F: FieldExt + PrimeFieldBits,
const WORD_NUM_BITS: usize,
const WINDOW_NUM_BITS: usize,
const NUM_WINDOWS: usize,
> {
q_range_check: Selector,
q_final_z: Selector,
pub z: Column<Advice>,
perm: Permutation,
_marker: PhantomData<F>,
}
impl<
F: FieldExt + PrimeFieldBits,
const WORD_NUM_BITS: usize,
const WINDOW_NUM_BITS: usize,
const NUM_WINDOWS: usize,
> RunningSumConfig<F, WORD_NUM_BITS, WINDOW_NUM_BITS, NUM_WINDOWS>
{
/// `perm` MUST include the advice column `z`.
///
/// # Panics
///
/// Panics if WINDOW_NUM_BITS > 3.
/// Panics if there are too many windows for the given word size.
pub fn configure(
meta: &mut ConstraintSystem<F>,
q_range_check: Selector,
z: Column<Advice>,
perm: Permutation,
) -> Self {
assert!(WINDOW_NUM_BITS <= 3);
assert!(WINDOW_NUM_BITS * NUM_WINDOWS < WORD_NUM_BITS + WINDOW_NUM_BITS);
let config = Self {
q_range_check,
q_final_z: meta.selector(),
z,
perm,
_marker: PhantomData,
};
meta.create_gate("range check", |meta| {
let q_range_check = meta.query_selector(config.q_range_check);
let z_cur = meta.query_advice(config.z, Rotation::cur());
let z_next = meta.query_advice(config.z, Rotation::next());
// z_i = 2^{K}⋅z_{i + 1} + k_i
// => k_i = z_i - 2^{K}⋅z_{i + 1}
let word = z_cur - z_next * F::from_u64(1 << WINDOW_NUM_BITS);
vec![q_range_check * range_check(word, 1 << WINDOW_NUM_BITS)]
});
meta.create_gate("final z = 0", |meta| {
let q_final_z = meta.query_selector(config.q_final_z);
let z_final = meta.query_advice(config.z, Rotation::cur());
vec![q_final_z * z_final]
});
config
}
/// Decompose a field element alpha that is witnessed in this helper.
///
/// `strict` = true constrains the final running sum to be zero, i.e.
/// constrains alpha to be within WINDOW_NUM_BITS * NUM_WINDOWS bits.
pub fn witness_decompose(
&self,
region: &mut Region<'_, F>,
offset: usize,
alpha: Option<F>,
strict: bool,
) -> Result<(CellValue<F>, RunningSum<F>), Error> {
let z_0 = {
let cell = region.assign_advice(
|| "Witness alpha",
self.z,
offset,
|| alpha.ok_or(Error::SynthesisError),
)?;
CellValue::new(cell, alpha)
};
self.decompose(region, offset, z_0, strict)
}
/// Decompose an existing variable alpha that is copied into this helper.
///
/// `strict` = true constrains the final running sum to be zero, i.e.
/// constrains alpha to be within WINDOW_NUM_BITS * NUM_WINDOWS bits.
pub fn copy_decompose(
&self,
region: &mut Region<'_, F>,
offset: usize,
alpha: CellValue<F>,
strict: bool,
) -> Result<(CellValue<F>, RunningSum<F>), Error> {
let z_0 = copy(region, || "Copy alpha", self.z, offset, &alpha, &self.perm)?;
self.decompose(region, offset, z_0, strict)
}
fn decompose(
&self,
region: &mut Region<'_, F>,
offset: usize,
z_0: CellValue<F>,
strict: bool,
) -> Result<(CellValue<F>, RunningSum<F>), Error> {
// Enable selectors
{
for idx in 0..NUM_WINDOWS {
self.q_range_check.enable(region, offset + idx)?;
}
if strict {
// Constrain the final running sum output to be zero.
self.q_final_z.enable(region, offset + NUM_WINDOWS)?;
}
}
// Decompose base field element into K-bit words.
let words: Vec<Option<u8>> = {
let words = z_0
.value()
.map(|word| decompose_word::<F>(word, WORD_NUM_BITS, WINDOW_NUM_BITS));
if let Some(words) = words {
words.into_iter().map(Some).collect()
} else {
vec![None; NUM_WINDOWS]
}
};
// Initialize empty vector to store running sum values [z_0, ..., z_W].
let mut zs: Vec<CellValue<F>> = Vec::with_capacity(NUM_WINDOWS);
let mut z = z_0;
// Assign running sum `z_i`, i = 0..=n, where z_{i+1} = (z_i - a_i) / (2^K).
// Outside of this helper, z_0 = alpha must have already been loaded into the
// `z` column at `offset`.
let offset = offset + 1;
let two_pow_k_inv = F::from_u64(1 << WINDOW_NUM_BITS as u64).invert().unwrap();
for (idx, word) in words.iter().enumerate() {
// z_next = (z_cur - word) / (2^K)
let z_next = {
let word = word.map(|word| F::from_u64(word as u64));
let z_next_val = z
.value()
.zip(word)
.map(|(z_cur_val, word)| (z_cur_val - word) * two_pow_k_inv);
let cell = region.assign_advice(
|| format!("z_{:?}", idx + 1),
self.z,
offset + idx,
|| z_next_val.ok_or(Error::SynthesisError),
)?;
CellValue::new(cell, z_next_val)
};
// Update `z`.
z = z_next;
zs.push(z);
}
Ok((z_0, RunningSum(zs)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::constants::{
FIXED_BASE_WINDOW_SIZE, L_ORCHARD_BASE, L_VALUE, NUM_WINDOWS, NUM_WINDOWS_SHORT,
};
use halo2::{
circuit::{Layouter, SimpleFloorPlanner},
dev::{MockProver, VerifyFailure},
plonk::{Circuit, ConstraintSystem, Error},
};
use pasta_curves::{arithmetic::FieldExt, pallas};
#[test]
fn test_running_sum() {
struct MyCircuit<
F: FieldExt + PrimeFieldBits,
const WORD_NUM_BITS: usize,
const WINDOW_NUM_BITS: usize,
const NUM_WINDOWS: usize,
> {
alpha: Option<F>,
strict: bool,
}
impl<
F: FieldExt + PrimeFieldBits,
const WORD_NUM_BITS: usize,
const WINDOW_NUM_BITS: usize,
const NUM_WINDOWS: usize,
> Circuit<F> for MyCircuit<F, WORD_NUM_BITS, WINDOW_NUM_BITS, NUM_WINDOWS>
{
type Config = RunningSumConfig<F, WORD_NUM_BITS, WINDOW_NUM_BITS, NUM_WINDOWS>;
type FloorPlanner = SimpleFloorPlanner;
fn without_witnesses(&self) -> Self {
Self {
alpha: None,
strict: self.strict
}
}
fn configure(meta: &mut ConstraintSystem<F>) -> Self::Config {
let z = meta.advice_column();
let q_range_check = meta.selector();
let perm = meta.permutation(&[z.into()]);
RunningSumConfig::<F, WORD_NUM_BITS, WINDOW_NUM_BITS, NUM_WINDOWS>::configure(
meta,
q_range_check,
z,
perm,
)
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter.assign_region(
|| "decompose",
|mut region| {
let offset = 0;
let (alpha, _zs) = config.witness_decompose(
&mut region,
offset,
self.alpha,
self.strict,
)?;
let offset = offset + NUM_WINDOWS + 1;
config.copy_decompose(&mut region, offset, alpha, self.strict)?;
Ok(())
},
)
}
}
// Random base field element
{
let alpha = pallas::Base::rand();
// Strict full decomposition should pass.
let circuit: MyCircuit<
pallas::Base,
L_ORCHARD_BASE,
FIXED_BASE_WINDOW_SIZE,
NUM_WINDOWS,
> = MyCircuit {
alpha: Some(alpha),
strict: true,
};
let prover = MockProver::<pallas::Base>::run(8, &circuit, vec![]).unwrap();
assert_eq!(prover.verify(), Ok(()));
// Strict partial decomposition should fail.
let circuit: MyCircuit<
pallas::Base,
L_ORCHARD_BASE,
FIXED_BASE_WINDOW_SIZE,
NUM_WINDOWS_SHORT,
> = MyCircuit {
alpha: Some(alpha),
strict: true,
};
let prover = MockProver::<pallas::Base>::run(8, &circuit, vec![]).unwrap();
assert_eq!(
prover.verify(),
Err(vec![
VerifyFailure::Constraint {
constraint: ((1, "final z = 0").into(), 0, "").into(),
row: 22
},
VerifyFailure::Constraint {
constraint: ((1, "final z = 0").into(), 0, "").into(),
row: 45
}
])
);
// Non-strict partial decomposition should pass.
let circuit: MyCircuit<
pallas::Base,
L_ORCHARD_BASE,
FIXED_BASE_WINDOW_SIZE,
NUM_WINDOWS_SHORT,
> = MyCircuit {
alpha: Some(alpha),
strict: false,
};
let prover = MockProver::<pallas::Base>::run(8, &circuit, vec![]).unwrap();
assert_eq!(prover.verify(), Ok(()));
}
// Random 64-bit word
{
let alpha = pallas::Base::from_u64(rand::random());
// Strict full decomposition should pass.
let circuit: MyCircuit<
pallas::Base,
L_VALUE,
FIXED_BASE_WINDOW_SIZE,
NUM_WINDOWS_SHORT,
> = MyCircuit {
alpha: Some(alpha),
strict: true,
};
let prover = MockProver::<pallas::Base>::run(8, &circuit, vec![]).unwrap();
assert_eq!(prover.verify(), Ok(()));
}
}
}

View File

@ -1,7 +1,7 @@
use ff::PrimeFieldBits;
use halo2::arithmetic::{CurveAffine, FieldExt};
/// Decompose a scalar into `window_num_bits` bits (little-endian)
/// Decompose a word `alpha` into `window_num_bits` bits (little-endian)
/// For a window size of `w`, this returns [k_0, ..., k_n] where each `k_i`
/// is a `w`-bit value, and `scalar = k_0 + k_1 * w + k_n * w^n`.
///
@ -9,22 +9,22 @@ use halo2::arithmetic::{CurveAffine, FieldExt};
///
/// We are returning a `Vec<u8>` which means the window size is limited to
/// <= 8 bits.
pub fn decompose_scalar_fixed<F: PrimeFieldBits>(
scalar: F,
scalar_num_bits: usize,
pub fn decompose_word<F: PrimeFieldBits>(
word: F,
word_num_bits: usize,
window_num_bits: usize,
) -> Vec<u8> {
assert!(window_num_bits <= 8);
// Pad bits to multiple of window_num_bits
let padding = (window_num_bits - (scalar_num_bits % window_num_bits)) % window_num_bits;
let bits: Vec<bool> = scalar
let padding = (window_num_bits - (word_num_bits % window_num_bits)) % window_num_bits;
let bits: Vec<bool> = word
.to_le_bits()
.into_iter()
.take(scalar_num_bits)
.take(word_num_bits)
.chain(std::iter::repeat(false).take(padding))
.collect();
assert_eq!(bits.len(), scalar_num_bits + padding);
assert_eq!(bits.len(), word_num_bits + padding);
bits.chunks_exact(window_num_bits)
.map(|chunk| chunk.iter().rev().fold(0, |acc, b| (acc << 1) + (*b as u8)))
@ -54,7 +54,7 @@ pub fn gen_const_array<Output: Copy + Default, const LEN: usize>(
#[cfg(test)]
mod tests {
use super::decompose_scalar_fixed;
use super::decompose_word;
use ff::PrimeField;
use pasta_curves::{arithmetic::FieldExt, pallas};
use proptest::prelude::*;
@ -72,12 +72,12 @@ mod tests {
proptest! {
#[test]
fn test_decompose_scalar_fixed(
fn test_decompose_word(
scalar in arb_scalar(),
window_num_bits in 1u8..9
) {
// Get decomposition into `window_num_bits` bits
let decomposed = decompose_scalar_fixed(scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize);
let decomposed = decompose_word(scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize);
// Flatten bits
let bits = decomposed