Move decompose_word into from constants::util into gadget::utilities.

This helper is not used outside of the gadget.
This commit is contained in:
therealyingtong 2021-08-19 16:41:31 +08:00
parent e3aad46785
commit 76431eefad
4 changed files with 75 additions and 83 deletions

View File

@ -1,8 +1,8 @@
use super::super::{EccConfig, EccPoint, EccScalarFixed, FixedPoints};
use crate::{
circuit::gadget::utilities::{range_check, CellValue, Var},
constants::{self, util, L_ORCHARD_SCALAR, NUM_WINDOWS},
circuit::gadget::utilities::{decompose_word, range_check, CellValue, Var},
constants::{self, L_ORCHARD_SCALAR, NUM_WINDOWS},
};
use arrayvec::ArrayVec;
use halo2::{
@ -78,7 +78,7 @@ impl<Fixed: FixedPoints<pallas::Affine>> Config<Fixed> {
// Decompose scalar into `k-bit` windows
let scalar_windows: Option<Vec<u8>> = scalar.map(|scalar| {
util::decompose_word::<pallas::Scalar>(
decompose_word::<pallas::Scalar>(
scalar,
SCALAR_NUM_BITS,
constants::FIXED_BASE_WINDOW_SIZE,

View File

@ -135,6 +135,36 @@ pub fn range_check<F: FieldExt>(word: Expression<F>, range: usize) -> Expression
})
}
/// Decompose a word `alpha` into `window_num_bits` bits (little-endian)
/// For a window size of `w`, this returns [k_0, ..., k_n] where each `k_i`
/// is a `w`-bit value, and `scalar = k_0 + k_1 * w + k_n * w^n`.
///
/// # Panics
///
/// We are returning a `Vec<u8>` which means the window size is limited to
/// <= 8 bits.
pub fn decompose_word<F: PrimeFieldBits>(
word: F,
word_num_bits: usize,
window_num_bits: usize,
) -> Vec<u8> {
assert!(window_num_bits <= 8);
// Pad bits to multiple of window_num_bits
let padding = (window_num_bits - (word_num_bits % window_num_bits)) % window_num_bits;
let bits: Vec<bool> = word
.to_le_bits()
.into_iter()
.take(word_num_bits)
.chain(std::iter::repeat(false).take(padding))
.collect();
assert_eq!(bits.len(), word_num_bits + padding);
bits.chunks_exact(window_num_bits)
.map(|chunk| chunk.iter().rev().fold(0, |acc, b| (acc << 1) + (*b as u8)))
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
@ -146,7 +176,10 @@ mod tests {
plonk::{Circuit, ConstraintSystem, Error, Selector},
poly::Rotation,
};
use pasta_curves::pallas;
use pasta_curves::{arithmetic::FieldExt, pallas};
use proptest::prelude::*;
use std::convert::TryInto;
use std::iter;
#[test]
fn test_range_check() {
@ -296,4 +329,40 @@ mod tests {
&[0..50, 50..100, 100..150, 150..200, 200..255],
);
}
prop_compose! {
fn arb_scalar()(bytes in prop::array::uniform32(0u8..)) -> pallas::Scalar {
// Instead of rejecting out-of-range bytes, let's reduce them.
let mut buf = [0; 64];
buf[..32].copy_from_slice(&bytes);
pallas::Scalar::from_bytes_wide(&buf)
}
}
proptest! {
#[test]
fn test_decompose_word(
scalar in arb_scalar(),
window_num_bits in 1u8..9
) {
// Get decomposition into `window_num_bits` bits
let decomposed = decompose_word(scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize);
// Flatten bits
let bits = decomposed
.iter()
.flat_map(|window| (0..window_num_bits).map(move |mask| (window & (1 << mask)) != 0));
// Ensure this decomposition contains 256 or fewer set bits.
assert!(!bits.clone().skip(32*8).any(|b| b));
// Pad or truncate bits to 32 bytes
let bits: Vec<bool> = bits.chain(iter::repeat(false)).take(32*8).collect();
let bytes: Vec<u8> = bits.chunks_exact(8).map(|chunk| chunk.iter().rev().fold(0, |acc, b| (acc << 1) + (*b as u8))).collect();
// Check that original scalar is recovered from decomposition
assert_eq!(scalar, pallas::Scalar::from_bytes(&bytes.try_into().unwrap()).unwrap());
}
}
}

View File

@ -29,8 +29,7 @@ use halo2::{
poly::Rotation,
};
use super::{copy, range_check, CellValue, Var};
use crate::constants::util::decompose_word;
use super::{copy, decompose_word, range_check, CellValue, Var};
use pasta_curves::arithmetic::FieldExt;
use std::marker::PhantomData;

View File

@ -1,36 +1,6 @@
use ff::{Field, PrimeFieldBits};
use ff::Field;
use halo2::arithmetic::{CurveAffine, FieldExt};
/// Decompose a word `alpha` into `window_num_bits` bits (little-endian)
/// For a window size of `w`, this returns [k_0, ..., k_n] where each `k_i`
/// is a `w`-bit value, and `scalar = k_0 + k_1 * w + k_n * w^n`.
///
/// # Panics
///
/// We are returning a `Vec<u8>` which means the window size is limited to
/// <= 8 bits.
pub fn decompose_word<F: PrimeFieldBits>(
word: F,
word_num_bits: usize,
window_num_bits: usize,
) -> Vec<u8> {
assert!(window_num_bits <= 8);
// Pad bits to multiple of window_num_bits
let padding = (window_num_bits - (word_num_bits % window_num_bits)) % window_num_bits;
let bits: Vec<bool> = word
.to_le_bits()
.into_iter()
.take(word_num_bits)
.chain(std::iter::repeat(false).take(padding))
.collect();
assert_eq!(bits.len(), word_num_bits + padding);
bits.chunks_exact(window_num_bits)
.map(|chunk| chunk.iter().rev().fold(0, |acc, b| (acc << 1) + (*b as u8)))
.collect()
}
/// Evaluate y = f(x) given the coefficients of f(x)
pub fn evaluate<C: CurveAffine>(x: u8, coeffs: &[C::Base]) -> C::Base {
let x = C::Base::from_u64(x as u64);
@ -60,49 +30,3 @@ pub(crate) fn gen_const_array_with_default<Output: Copy, const LEN: usize>(
}
ret
}
#[cfg(test)]
mod tests {
use super::decompose_word;
use ff::PrimeField;
use pasta_curves::{arithmetic::FieldExt, pallas};
use proptest::prelude::*;
use std::convert::TryInto;
use std::iter;
prop_compose! {
fn arb_scalar()(bytes in prop::array::uniform32(0u8..)) -> pallas::Scalar {
// Instead of rejecting out-of-range bytes, let's reduce them.
let mut buf = [0; 64];
buf[..32].copy_from_slice(&bytes);
pallas::Scalar::from_bytes_wide(&buf)
}
}
proptest! {
#[test]
fn test_decompose_word(
scalar in arb_scalar(),
window_num_bits in 1u8..9
) {
// Get decomposition into `window_num_bits` bits
let decomposed = decompose_word(scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize);
// Flatten bits
let bits = decomposed
.iter()
.flat_map(|window| (0..window_num_bits).map(move |mask| (window & (1 << mask)) != 0));
// Ensure this decomposition contains 256 or fewer set bits.
assert!(!bits.clone().skip(32*8).any(|b| b));
// Pad or truncate bits to 32 bytes
let bits: Vec<bool> = bits.chain(iter::repeat(false)).take(32*8).collect();
let bytes: Vec<u8> = bits.chunks_exact(8).map(|chunk| chunk.iter().rev().fold(0, |acc, b| (acc << 1) + (*b as u8))).collect();
// Check that original scalar is recovered from decomposition
assert_eq!(scalar, pallas::Scalar::from_bytes(&bytes.try_into().unwrap()).unwrap());
}
}
}