From a8bd2d6abf878e019ea5589abcedb07dc495257a Mon Sep 17 00:00:00 2001 From: therealyingtong Date: Sat, 10 Jul 2021 23:56:24 +0800 Subject: [PATCH] mul_fixed::short: Copy (magnitude, sign) instead of witnessing Scalar. In the Orchard circuit, the short signed scalar is v_old - v_new, which will be witnessed as two cells: a 64-bit magnitude, and a sign that is +/- 1. --- src/circuit/gadget/ecc.rs | 8 +- src/circuit/gadget/ecc/chip.rs | 8 +- src/circuit/gadget/ecc/chip/mul_fixed.rs | 41 ++- .../gadget/ecc/chip/mul_fixed/short.rs | 327 +++++++++--------- 4 files changed, 189 insertions(+), 195 deletions(-) diff --git a/src/circuit/gadget/ecc.rs b/src/circuit/gadget/ecc.rs index 9441a5e2..773c4278 100644 --- a/src/circuit/gadget/ecc.rs +++ b/src/circuit/gadget/ecc.rs @@ -102,7 +102,7 @@ pub trait EccInstructions: Chip + UtilitiesInstructions fn mul_fixed_short( &self, layouter: &mut impl Layouter, - scalar: Option, + magnitude_sign: (Self::Var, Self::Var), base: &Self::FixedPointsShort, ) -> Result<(Self::Point, Self::ScalarFixedShort), Error>; @@ -271,6 +271,7 @@ impl FixedPoint where EccChip: EccInstructions + Clone + Debug + Eq, { + #[allow(clippy::type_complexity)] /// Returns `[by] self`. pub fn mul( &self, @@ -329,14 +330,15 @@ impl FixedPointShort where EccChip: EccInstructions + Clone + Debug + Eq, { + #[allow(clippy::type_complexity)] /// Returns `[by] self`. pub fn mul( &self, mut layouter: impl Layouter, - by: Option, + magnitude_sign: (EccChip::Var, EccChip::Var), ) -> Result<(Point, ScalarFixedShort), Error> { self.chip - .mul_fixed_short(&mut layouter, by, &self.inner) + .mul_fixed_short(&mut layouter, magnitude_sign, &self.inner) .map(|(point, scalar)| { ( Point { diff --git a/src/circuit/gadget/ecc/chip.rs b/src/circuit/gadget/ecc/chip.rs index c37d3301..616a28e1 100644 --- a/src/circuit/gadget/ecc/chip.rs +++ b/src/circuit/gadget/ecc/chip.rs @@ -294,9 +294,9 @@ pub struct EccScalarFixed { /// k_21 must be a single bit, i.e. 0 or 1. #[derive(Clone, Debug)] pub struct EccScalarFixedShort { - magnitude: Option, + magnitude: CellValue, sign: CellValue, - windows: ArrayVec, { constants::NUM_WINDOWS_SHORT }>, + running_sum: ArrayVec, { constants::NUM_WINDOWS_SHORT }>, } /// A base field element used for fixed-base scalar multiplication. @@ -420,13 +420,13 @@ impl EccInstructions for EccChip { fn mul_fixed_short( &self, layouter: &mut impl Layouter, - scalar: Option, + magnitude_sign: (CellValue, CellValue), base: &Self::FixedPointsShort, ) -> Result<(Self::Point, Self::ScalarFixedShort), Error> { let config: mul_fixed::short::Config = self.config().into(); config.assign( layouter.namespace(|| format!("short fixed-base mul of {:?}", base)), - scalar, + magnitude_sign, base, ) } diff --git a/src/circuit/gadget/ecc/chip/mul_fixed.rs b/src/circuit/gadget/ecc/chip/mul_fixed.rs index 0e4f4d38..57eb58a1 100644 --- a/src/circuit/gadget/ecc/chip/mul_fixed.rs +++ b/src/circuit/gadget/ecc/chip/mul_fixed.rs @@ -510,35 +510,34 @@ impl From<&EccBaseFieldElemFixed> for ScalarFixed { } impl ScalarFixed { - fn windows(&self) -> &[CellValue] { - match self { - ScalarFixed::FullWidth(scalar) => &scalar.windows, - ScalarFixed::Short(scalar) => &scalar.windows, - _ => unreachable!("The base field element is not witnessed as windows."), - } - } - // The scalar decomposition was done in the base field. For computation // outside the circuit, we now convert them back into the scalar field. fn windows_field(&self) -> Vec> { + let running_sum_to_windows = |zs: Vec>| { + (0..(zs.len() - 1)) + .map(|idx| { + let z_cur = zs[idx].value(); + let z_next = zs[idx + 1].value(); + let word = z_cur + .zip(z_next) + .map(|(z_cur, z_next)| z_cur - z_next * *H_BASE); + word.map(|word| pallas::Scalar::from_bytes(&word.to_bytes()).unwrap()) + }) + .collect::>() + }; match self { Self::BaseFieldElem(scalar) => { let mut zs = vec![scalar.base_field_elem]; zs.extend_from_slice(&scalar.running_sum); - - (0..(zs.len() - 1)) - .map(|idx| { - let z_cur = zs[idx].value(); - let z_next = zs[idx + 1].value(); - let word = z_cur - .zip(z_next) - .map(|(z_cur, z_next)| z_cur - z_next * *H_BASE); - word.map(|word| pallas::Scalar::from_bytes(&word.to_bytes()).unwrap()) - }) - .collect::>() + running_sum_to_windows(zs) } - _ => self - .windows() + Self::Short(scalar) => { + let mut zs = vec![scalar.magnitude]; + zs.extend_from_slice(&scalar.running_sum); + running_sum_to_windows(zs) + } + Self::FullWidth(scalar) => scalar + .windows .iter() .map(|bits| { bits.value() diff --git a/src/circuit/gadget/ecc/chip/mul_fixed/short.rs b/src/circuit/gadget/ecc/chip/mul_fixed/short.rs index 15a8624f..85065668 100644 --- a/src/circuit/gadget/ecc/chip/mul_fixed/short.rs +++ b/src/circuit/gadget/ecc/chip/mul_fixed/short.rs @@ -1,7 +1,10 @@ -use std::array; +use std::{array, convert::TryInto}; -use super::super::{copy, CellValue, EccConfig, EccPoint, EccScalarFixedShort, Var}; -use crate::constants::{ValueCommitV, L_VALUE, NUM_WINDOWS_SHORT}; +use super::super::{EccConfig, EccPoint, EccScalarFixedShort}; +use crate::{ + circuit::gadget::utilities::{copy, decompose_running_sum::RunningSumConfig, CellValue, Var}, + constants::{self, ValueCommitV, FIXED_BASE_WINDOW_SIZE, L_VALUE, NUM_WINDOWS_SHORT}, +}; use halo2::{ circuit::{Layouter, Region}, @@ -13,7 +16,13 @@ use pasta_curves::{arithmetic::FieldExt, pallas}; pub struct Config { // Selector used for fixed-base scalar mul with short signed exponent. q_mul_fixed_short: Selector, - q_scalar_fixed_short: Selector, + q_mul_fixed_running_sum: Selector, + running_sum_config: RunningSumConfig< + pallas::Base, + { L_VALUE }, + { FIXED_BASE_WINDOW_SIZE }, + { NUM_WINDOWS_SHORT }, + >, super_config: super::Config, } @@ -21,40 +30,45 @@ impl From<&EccConfig> for Config { fn from(config: &EccConfig) -> Self { Self { q_mul_fixed_short: config.q_mul_fixed_short, - q_scalar_fixed_short: config.q_scalar_fixed_short, + q_mul_fixed_running_sum: config.q_mul_fixed_running_sum, + running_sum_config: config.running_sum_short_config.clone(), super_config: config.into(), } } } impl Config { - // We reuse the constraints in the `mul_fixed` gate so exclude them here. - // Here, we add some new constraints specific to the short signed case. pub(crate) fn create_gate(&self, meta: &mut ConstraintSystem) { - // Check that sign is either 1 or -1. - // Check that last window is either 0 or 1. - meta.create_gate("Check sign and last window", |meta| { - let q_scalar_fixed_short = meta.query_selector(self.q_scalar_fixed_short); - let last_window = meta.query_advice(self.super_config.window, Rotation::prev()); - let sign = meta.query_advice(self.super_config.window, Rotation::cur()); + // Check that each window uses the correct y_p and interpolated x_p. + meta.create_gate("Coordinates check", |meta| { + let q_mul_fixed_running_sum = meta.query_selector(self.q_mul_fixed_running_sum); - let one = Expression::Constant(pallas::Base::one()); + let z_cur = meta.query_advice(self.super_config.window, Rotation::cur()); + let z_next = meta.query_advice(self.super_config.window, Rotation::next()); - let last_window_check = last_window.clone() * (one.clone() - last_window); - let sign_check = sign.clone() * sign - one; + // z_{i+1} = (z_i - a_i) / 2^3 + // => a_i = z_i - z_{i+1} * 2^3 + let word = z_cur - z_next * pallas::Base::from_u64(constants::H as u64); - vec![ - q_scalar_fixed_short.clone() * last_window_check, - q_scalar_fixed_short * sign_check, - ] + self.super_config + .coords_check(meta, q_mul_fixed_running_sum, word) }); meta.create_gate("Short fixed-base mul gate", |meta| { let q_mul_fixed_short = meta.query_selector(self.q_mul_fixed_short); let y_p = meta.query_advice(self.super_config.y_p, Rotation::cur()); let y_a = meta.query_advice(self.super_config.add_config.y_qr, Rotation::cur()); + // z_21 + let last_window = meta.query_advice(self.super_config.u, Rotation::cur()); let sign = meta.query_advice(self.super_config.window, Rotation::cur()); + let one = Expression::Constant(pallas::Base::one()); + + // Check that last window is either 0 or 1. + let last_window_check = last_window.clone() * (one.clone() - last_window); + // Check that sign is either 1 or -1. + let sign_check = sign.clone() * sign.clone() - one; + // `(x_a, y_a)` is the result of `[m]B`, where `m` is the magnitude. // We conditionally negate this result using `y_p = y_a * s`, where `s` is the sign. @@ -66,66 +80,40 @@ impl Config { // Check that the correct sign is witnessed s.t. sign * y_p = y_a let negation_check = sign * y_p - y_a; - array::IntoIter::new([y_check, negation_check]) - .map(move |poly| q_mul_fixed_short.clone() * poly) + array::IntoIter::new([ + ("last_window_check", last_window_check), + ("sign_check", sign_check), + ("y_check", y_check), + ("negation_check", negation_check), + ]) + .map(move |(name, poly)| (name, q_mul_fixed_short.clone() * poly)) }); } - fn witness( + fn decompose( &self, region: &mut Region<'_, pallas::Base>, offset: usize, - value: Option, + magnitude_sign: (CellValue, CellValue), ) -> Result { - // Enable `q_scalar_fixed_short` - self.q_scalar_fixed_short - .enable(region, offset + NUM_WINDOWS_SHORT)?; + let (magnitude, sign) = magnitude_sign; - // Compute the scalar's sign and magnitude - let sign = value.map(|value| { - // t = (p - 1)/2 - let t = (pallas::Scalar::zero() - &pallas::Scalar::one()) * &pallas::Scalar::TWO_INV; - if value > t { - -pallas::Scalar::one() - } else { - pallas::Scalar::one() - } - }); - - let magnitude = sign.zip(value).map(|(sign, value)| sign * &value); - - // Decompose magnitude into `k`-bit windows - let windows = self - .super_config - .decompose_scalar_fixed::(magnitude, offset, region)?; - - // Assign the sign and enable `q_scalar_fixed_short` - let sign = sign.map(|sign| { - assert!(sign == pallas::Scalar::one() || sign == -pallas::Scalar::one()); - if sign == pallas::Scalar::one() { - pallas::Base::one() - } else { - -pallas::Base::one() - } - }); - let sign_cell = region.assign_advice( - || "sign", - self.super_config.window, - offset + NUM_WINDOWS_SHORT, - || sign.ok_or(Error::SynthesisError), - )?; + // Decompose magnitude + let (magnitude, running_sum) = self + .running_sum_config + .copy_decompose(region, offset, magnitude, true)?; Ok(EccScalarFixedShort { magnitude, - sign: CellValue::::new(sign_cell, sign), - windows, + sign, + running_sum: (*running_sum).as_slice().try_into().unwrap(), }) } pub fn assign( &self, mut layouter: impl Layouter, - scalar: Option, + magnitude_sign: (CellValue, CellValue), base: &ValueCommitV, ) -> Result<(EccPoint, EccScalarFixedShort), Error> { let (scalar, acc, mul_b) = layouter.assign_region( @@ -133,8 +121,8 @@ impl Config { |mut region| { let offset = 0; - // Copy the scalar decomposition - let scalar = self.witness(&mut region, offset, scalar)?; + // Decompose the scalar + let scalar = self.decompose(&mut region, offset, magnitude_sign)?; let (acc, mul_b) = self.super_config.assign_region_inner( &mut region, @@ -148,6 +136,7 @@ impl Config { }, )?; + // Last window let result = layouter.assign_region( || "Short fixed-base mul (most significant word)", |mut region| { @@ -163,7 +152,7 @@ impl Config { // Increase offset by 1 after complete addition let offset = offset + 1; - // Assign sign to `window` column + // Copy sign to `window` column let sign = copy( &mut region, || "sign", @@ -173,6 +162,17 @@ impl Config { &self.super_config.perm, )?; + // Copy last window to `u` column + let z_21 = scalar.running_sum[20]; + copy( + &mut region, + || "last_window", + self.super_config.u, + offset, + &z_21, + &self.super_config.perm, + )?; + // Conditionally negate `y`-coordinate let y_val = if let Some(sign) = sign.value() { if sign == -pallas::Base::one() { @@ -209,19 +209,25 @@ impl Config { let base: super::OrchardFixedBases = base.clone().into(); - let scalar = scalar - .magnitude - .zip(scalar.sign.value()) - .map(|(magnitude, sign)| { - let sign = if sign == pallas::Base::one() { - pallas::Scalar::one() - } else if sign == -pallas::Base::one() { - -pallas::Scalar::one() - } else { - panic!("Sign should be 1 or -1.") - }; - magnitude * sign - }); + let scalar = + scalar + .magnitude + .value() + .zip(scalar.sign.value()) + .map(|(magnitude, sign)| { + // Move magnitude from base field into scalar field (which always fits + // for Pallas). + let magnitude = pallas::Scalar::from_bytes(&magnitude.to_bytes()).unwrap(); + + let sign = if sign == pallas::Base::one() { + pallas::Scalar::one() + } else if sign == -pallas::Base::one() { + -pallas::Scalar::one() + } else { + panic!("Sign should be 1 or -1.") + }; + magnitude * sign + }); let real_mul = scalar.map(|scalar| base.generator() * scalar); let result = result.point(); @@ -237,10 +243,16 @@ impl Config { #[cfg(test)] pub mod tests { use group::Curve; - use halo2::{circuit::Layouter, plonk::Error}; + use halo2::{ + circuit::{Chip, Layouter}, + plonk::Error, + }; use pasta_curves::{arithmetic::FieldExt, pallas}; - use crate::circuit::gadget::ecc::{chip::EccChip, FixedPointShort, Point}; + use crate::circuit::gadget::{ + ecc::{chip::EccChip, FixedPointShort, Point}, + utilities::{CellValue, UtilitiesInstructions}, + }; use crate::constants::load::ValueCommitV; #[allow(clippy::op_ref)] @@ -253,6 +265,20 @@ pub mod tests { let base_val = value_commit_v.generator; let value_commit_v = FixedPointShort::from_inner(chip.clone(), value_commit_v); + fn load_magnitude_sign( + chip: EccChip, + mut layouter: impl Layouter, + magnitude: pallas::Base, + sign: pallas::Base, + ) -> Result<(CellValue, CellValue), Error> { + let column = chip.config().advices[0]; + let magnitude = + chip.load_private(layouter.namespace(|| "magnitude"), column, Some(magnitude))?; + let sign = chip.load_private(layouter.namespace(|| "sign"), column, Some(sign))?; + + Ok((magnitude, sign)) + } + fn constrain_equal( chip: EccChip, mut layouter: impl Layouter, @@ -268,96 +294,63 @@ pub mod tests { result.constrain_equal(layouter.namespace(|| "constrain result"), &expected) } - // [0]B should return (0,0) since it uses complete addition - // on the last step. - { - let scalar_fixed_short = pallas::Scalar::zero(); - let (result, _) = value_commit_v.mul( - layouter.namespace(|| "mul by zero"), - Some(scalar_fixed_short), - )?; + let mut random_sign = pallas::Base::one(); + if rand::random::() { + random_sign = -random_sign; + } + let magnitude_signs = [ + ("mul by +zero", pallas::Base::zero(), pallas::Base::one()), + ("mul by -zero", pallas::Base::zero(), -pallas::Base::one()), + ( + "random [a]B", + pallas::Base::from_u64(rand::random::()), + random_sign, + ), + ( + "[2^64 - 1]B", + pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64), + pallas::Base::one(), + ), + ( + "-[2^64 - 1]B", + pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64), + -pallas::Base::one(), + ), + // There is a single canonical sequence of window values for which a doubling occurs on the last step: + // 1333333333333333333334 in octal. + // [0xB6DB_6DB6_DB6D_B6DC] B + ( + "mul_with_double", + pallas::Base::from_u64(0xB6DB_6DB6_DB6D_B6DCu64), + pallas::Base::one(), + ), + ]; + for (name, magnitude, sign) in magnitude_signs.iter() { + let (result, _) = { + let magnitude_sign = load_magnitude_sign( + chip.clone(), + layouter.namespace(|| *name), + *magnitude, + *sign, + )?; + value_commit_v.mul(layouter.namespace(|| *name), magnitude_sign)? + }; + // Move from base field into scalar field + let scalar = { + let magnitude = pallas::Scalar::from_bytes(&magnitude.to_bytes()).unwrap(); + let sign = if *sign == pallas::Base::one() { + pallas::Scalar::one() + } else { + -pallas::Scalar::one() + }; + magnitude * sign + }; constrain_equal( chip.clone(), - layouter.namespace(|| "mul by zero"), + layouter.namespace(|| *name), base_val, - scalar_fixed_short, - result, - )?; - } - - // Random [a]B - { - let scalar_fixed_short = pallas::Scalar::from_u64(rand::random::()); - let mut sign = pallas::Scalar::one(); - if rand::random::() { - sign = -sign; - } - let scalar_fixed_short = sign * &scalar_fixed_short; - let (result, _) = value_commit_v.mul( - layouter.namespace(|| "random short scalar"), - Some(scalar_fixed_short), - )?; - - constrain_equal( - chip.clone(), - layouter.namespace(|| "random [a]B"), - base_val, - scalar_fixed_short, - result, - )?; - } - - // [2^64 - 1]B - { - let scalar_fixed_short = pallas::Scalar::from_u64(0xFFFF_FFFF_FFFF_FFFFu64); - let (result, _) = value_commit_v.mul( - layouter.namespace(|| "[2^64 - 1]B"), - Some(scalar_fixed_short), - )?; - - constrain_equal( - chip.clone(), - layouter.namespace(|| "[2^64 - 1]B"), - base_val, - scalar_fixed_short, - result, - )?; - } - - // [-(2^64 - 1)]B - { - let scalar_fixed_short = -pallas::Scalar::from_u64(0xFFFF_FFFF_FFFF_FFFFu64); - let (result, _) = value_commit_v.mul( - layouter.namespace(|| "-[2^64 - 1]B"), - Some(scalar_fixed_short), - )?; - - constrain_equal( - chip.clone(), - layouter.namespace(|| "[-2^64 - 1]B"), - base_val, - scalar_fixed_short, - result, - )?; - } - - // There is a single canonical sequence of window values for which a doubling occurs on the last step: - // 1333333333333333333334 in octal. - // [0xB6DB_6DB6_DB6D_B6DC] B - { - let scalar_fixed_short = pallas::Scalar::from_u64(0xB6DB_6DB6_DB6D_B6DCu64); - - let (result, _) = value_commit_v.mul( - layouter.namespace(|| "mul with double"), - Some(scalar_fixed_short), - )?; - - constrain_equal( - chip, - layouter.namespace(|| "mul with double"), - base_val, - scalar_fixed_short, + scalar, result, )?; }