mul_fixed::short: Copy (magnitude, sign) instead of witnessing Scalar.

In the Orchard circuit, the short signed scalar is v_old - v_new,
which will be witnessed as two cells: a 64-bit magnitude, and a
sign that is +/- 1.
This commit is contained in:
therealyingtong 2021-07-10 23:56:24 +08:00
parent 426f954b1d
commit a8bd2d6abf
4 changed files with 189 additions and 195 deletions

View File

@ -102,7 +102,7 @@ pub trait EccInstructions<C: CurveAffine>: Chip<C::Base> + UtilitiesInstructions
fn mul_fixed_short(
&self,
layouter: &mut impl Layouter<C::Base>,
scalar: Option<C::Scalar>,
magnitude_sign: (Self::Var, Self::Var),
base: &Self::FixedPointsShort,
) -> Result<(Self::Point, Self::ScalarFixedShort), Error>;
@ -271,6 +271,7 @@ impl<C: CurveAffine, EccChip> FixedPoint<C, EccChip>
where
EccChip: EccInstructions<C> + Clone + Debug + Eq,
{
#[allow(clippy::type_complexity)]
/// Returns `[by] self`.
pub fn mul(
&self,
@ -329,14 +330,15 @@ impl<C: CurveAffine, EccChip> FixedPointShort<C, EccChip>
where
EccChip: EccInstructions<C> + Clone + Debug + Eq,
{
#[allow(clippy::type_complexity)]
/// Returns `[by] self`.
pub fn mul(
&self,
mut layouter: impl Layouter<C::Base>,
by: Option<C::Scalar>,
magnitude_sign: (EccChip::Var, EccChip::Var),
) -> Result<(Point<C, EccChip>, ScalarFixedShort<C, EccChip>), Error> {
self.chip
.mul_fixed_short(&mut layouter, by, &self.inner)
.mul_fixed_short(&mut layouter, magnitude_sign, &self.inner)
.map(|(point, scalar)| {
(
Point {

View File

@ -294,9 +294,9 @@ pub struct EccScalarFixed {
/// k_21 must be a single bit, i.e. 0 or 1.
#[derive(Clone, Debug)]
pub struct EccScalarFixedShort {
magnitude: Option<pallas::Scalar>,
magnitude: CellValue<pallas::Base>,
sign: CellValue<pallas::Base>,
windows: ArrayVec<CellValue<pallas::Base>, { constants::NUM_WINDOWS_SHORT }>,
running_sum: ArrayVec<CellValue<pallas::Base>, { constants::NUM_WINDOWS_SHORT }>,
}
/// A base field element used for fixed-base scalar multiplication.
@ -420,13 +420,13 @@ impl EccInstructions<pallas::Affine> for EccChip {
fn mul_fixed_short(
&self,
layouter: &mut impl Layouter<pallas::Base>,
scalar: Option<pallas::Scalar>,
magnitude_sign: (CellValue<pallas::Base>, CellValue<pallas::Base>),
base: &Self::FixedPointsShort,
) -> Result<(Self::Point, Self::ScalarFixedShort), Error> {
let config: mul_fixed::short::Config = self.config().into();
config.assign(
layouter.namespace(|| format!("short fixed-base mul of {:?}", base)),
scalar,
magnitude_sign,
base,
)
}

View File

@ -510,35 +510,34 @@ impl From<&EccBaseFieldElemFixed> for ScalarFixed {
}
impl ScalarFixed {
fn windows(&self) -> &[CellValue<pallas::Base>] {
match self {
ScalarFixed::FullWidth(scalar) => &scalar.windows,
ScalarFixed::Short(scalar) => &scalar.windows,
_ => unreachable!("The base field element is not witnessed as windows."),
}
}
// The scalar decomposition was done in the base field. For computation
// outside the circuit, we now convert them back into the scalar field.
fn windows_field(&self) -> Vec<Option<pallas::Scalar>> {
let running_sum_to_windows = |zs: Vec<CellValue<pallas::Base>>| {
(0..(zs.len() - 1))
.map(|idx| {
let z_cur = zs[idx].value();
let z_next = zs[idx + 1].value();
let word = z_cur
.zip(z_next)
.map(|(z_cur, z_next)| z_cur - z_next * *H_BASE);
word.map(|word| pallas::Scalar::from_bytes(&word.to_bytes()).unwrap())
})
.collect::<Vec<_>>()
};
match self {
Self::BaseFieldElem(scalar) => {
let mut zs = vec![scalar.base_field_elem];
zs.extend_from_slice(&scalar.running_sum);
(0..(zs.len() - 1))
.map(|idx| {
let z_cur = zs[idx].value();
let z_next = zs[idx + 1].value();
let word = z_cur
.zip(z_next)
.map(|(z_cur, z_next)| z_cur - z_next * *H_BASE);
word.map(|word| pallas::Scalar::from_bytes(&word.to_bytes()).unwrap())
})
.collect::<Vec<_>>()
running_sum_to_windows(zs)
}
_ => self
.windows()
Self::Short(scalar) => {
let mut zs = vec![scalar.magnitude];
zs.extend_from_slice(&scalar.running_sum);
running_sum_to_windows(zs)
}
Self::FullWidth(scalar) => scalar
.windows
.iter()
.map(|bits| {
bits.value()

View File

@ -1,7 +1,10 @@
use std::array;
use std::{array, convert::TryInto};
use super::super::{copy, CellValue, EccConfig, EccPoint, EccScalarFixedShort, Var};
use crate::constants::{ValueCommitV, L_VALUE, NUM_WINDOWS_SHORT};
use super::super::{EccConfig, EccPoint, EccScalarFixedShort};
use crate::{
circuit::gadget::utilities::{copy, decompose_running_sum::RunningSumConfig, CellValue, Var},
constants::{self, ValueCommitV, FIXED_BASE_WINDOW_SIZE, L_VALUE, NUM_WINDOWS_SHORT},
};
use halo2::{
circuit::{Layouter, Region},
@ -13,7 +16,13 @@ use pasta_curves::{arithmetic::FieldExt, pallas};
pub struct Config {
// Selector used for fixed-base scalar mul with short signed exponent.
q_mul_fixed_short: Selector,
q_scalar_fixed_short: Selector,
q_mul_fixed_running_sum: Selector,
running_sum_config: RunningSumConfig<
pallas::Base,
{ L_VALUE },
{ FIXED_BASE_WINDOW_SIZE },
{ NUM_WINDOWS_SHORT },
>,
super_config: super::Config<NUM_WINDOWS_SHORT>,
}
@ -21,40 +30,45 @@ impl From<&EccConfig> for Config {
fn from(config: &EccConfig) -> Self {
Self {
q_mul_fixed_short: config.q_mul_fixed_short,
q_scalar_fixed_short: config.q_scalar_fixed_short,
q_mul_fixed_running_sum: config.q_mul_fixed_running_sum,
running_sum_config: config.running_sum_short_config.clone(),
super_config: config.into(),
}
}
}
impl Config {
// We reuse the constraints in the `mul_fixed` gate so exclude them here.
// Here, we add some new constraints specific to the short signed case.
pub(crate) fn create_gate(&self, meta: &mut ConstraintSystem<pallas::Base>) {
// Check that sign is either 1 or -1.
// Check that last window is either 0 or 1.
meta.create_gate("Check sign and last window", |meta| {
let q_scalar_fixed_short = meta.query_selector(self.q_scalar_fixed_short);
let last_window = meta.query_advice(self.super_config.window, Rotation::prev());
let sign = meta.query_advice(self.super_config.window, Rotation::cur());
// Check that each window uses the correct y_p and interpolated x_p.
meta.create_gate("Coordinates check", |meta| {
let q_mul_fixed_running_sum = meta.query_selector(self.q_mul_fixed_running_sum);
let one = Expression::Constant(pallas::Base::one());
let z_cur = meta.query_advice(self.super_config.window, Rotation::cur());
let z_next = meta.query_advice(self.super_config.window, Rotation::next());
let last_window_check = last_window.clone() * (one.clone() - last_window);
let sign_check = sign.clone() * sign - one;
// z_{i+1} = (z_i - a_i) / 2^3
// => a_i = z_i - z_{i+1} * 2^3
let word = z_cur - z_next * pallas::Base::from_u64(constants::H as u64);
vec![
q_scalar_fixed_short.clone() * last_window_check,
q_scalar_fixed_short * sign_check,
]
self.super_config
.coords_check(meta, q_mul_fixed_running_sum, word)
});
meta.create_gate("Short fixed-base mul gate", |meta| {
let q_mul_fixed_short = meta.query_selector(self.q_mul_fixed_short);
let y_p = meta.query_advice(self.super_config.y_p, Rotation::cur());
let y_a = meta.query_advice(self.super_config.add_config.y_qr, Rotation::cur());
// z_21
let last_window = meta.query_advice(self.super_config.u, Rotation::cur());
let sign = meta.query_advice(self.super_config.window, Rotation::cur());
let one = Expression::Constant(pallas::Base::one());
// Check that last window is either 0 or 1.
let last_window_check = last_window.clone() * (one.clone() - last_window);
// Check that sign is either 1 or -1.
let sign_check = sign.clone() * sign.clone() - one;
// `(x_a, y_a)` is the result of `[m]B`, where `m` is the magnitude.
// We conditionally negate this result using `y_p = y_a * s`, where `s` is the sign.
@ -66,66 +80,40 @@ impl Config {
// Check that the correct sign is witnessed s.t. sign * y_p = y_a
let negation_check = sign * y_p - y_a;
array::IntoIter::new([y_check, negation_check])
.map(move |poly| q_mul_fixed_short.clone() * poly)
array::IntoIter::new([
("last_window_check", last_window_check),
("sign_check", sign_check),
("y_check", y_check),
("negation_check", negation_check),
])
.map(move |(name, poly)| (name, q_mul_fixed_short.clone() * poly))
});
}
fn witness(
fn decompose(
&self,
region: &mut Region<'_, pallas::Base>,
offset: usize,
value: Option<pallas::Scalar>,
magnitude_sign: (CellValue<pallas::Base>, CellValue<pallas::Base>),
) -> Result<EccScalarFixedShort, Error> {
// Enable `q_scalar_fixed_short`
self.q_scalar_fixed_short
.enable(region, offset + NUM_WINDOWS_SHORT)?;
let (magnitude, sign) = magnitude_sign;
// Compute the scalar's sign and magnitude
let sign = value.map(|value| {
// t = (p - 1)/2
let t = (pallas::Scalar::zero() - &pallas::Scalar::one()) * &pallas::Scalar::TWO_INV;
if value > t {
-pallas::Scalar::one()
} else {
pallas::Scalar::one()
}
});
let magnitude = sign.zip(value).map(|(sign, value)| sign * &value);
// Decompose magnitude into `k`-bit windows
let windows = self
.super_config
.decompose_scalar_fixed::<L_VALUE>(magnitude, offset, region)?;
// Assign the sign and enable `q_scalar_fixed_short`
let sign = sign.map(|sign| {
assert!(sign == pallas::Scalar::one() || sign == -pallas::Scalar::one());
if sign == pallas::Scalar::one() {
pallas::Base::one()
} else {
-pallas::Base::one()
}
});
let sign_cell = region.assign_advice(
|| "sign",
self.super_config.window,
offset + NUM_WINDOWS_SHORT,
|| sign.ok_or(Error::SynthesisError),
)?;
// Decompose magnitude
let (magnitude, running_sum) = self
.running_sum_config
.copy_decompose(region, offset, magnitude, true)?;
Ok(EccScalarFixedShort {
magnitude,
sign: CellValue::<pallas::Base>::new(sign_cell, sign),
windows,
sign,
running_sum: (*running_sum).as_slice().try_into().unwrap(),
})
}
pub fn assign(
&self,
mut layouter: impl Layouter<pallas::Base>,
scalar: Option<pallas::Scalar>,
magnitude_sign: (CellValue<pallas::Base>, CellValue<pallas::Base>),
base: &ValueCommitV,
) -> Result<(EccPoint, EccScalarFixedShort), Error> {
let (scalar, acc, mul_b) = layouter.assign_region(
@ -133,8 +121,8 @@ impl Config {
|mut region| {
let offset = 0;
// Copy the scalar decomposition
let scalar = self.witness(&mut region, offset, scalar)?;
// Decompose the scalar
let scalar = self.decompose(&mut region, offset, magnitude_sign)?;
let (acc, mul_b) = self.super_config.assign_region_inner(
&mut region,
@ -148,6 +136,7 @@ impl Config {
},
)?;
// Last window
let result = layouter.assign_region(
|| "Short fixed-base mul (most significant word)",
|mut region| {
@ -163,7 +152,7 @@ impl Config {
// Increase offset by 1 after complete addition
let offset = offset + 1;
// Assign sign to `window` column
// Copy sign to `window` column
let sign = copy(
&mut region,
|| "sign",
@ -173,6 +162,17 @@ impl Config {
&self.super_config.perm,
)?;
// Copy last window to `u` column
let z_21 = scalar.running_sum[20];
copy(
&mut region,
|| "last_window",
self.super_config.u,
offset,
&z_21,
&self.super_config.perm,
)?;
// Conditionally negate `y`-coordinate
let y_val = if let Some(sign) = sign.value() {
if sign == -pallas::Base::one() {
@ -209,19 +209,25 @@ impl Config {
let base: super::OrchardFixedBases = base.clone().into();
let scalar = scalar
.magnitude
.zip(scalar.sign.value())
.map(|(magnitude, sign)| {
let sign = if sign == pallas::Base::one() {
pallas::Scalar::one()
} else if sign == -pallas::Base::one() {
-pallas::Scalar::one()
} else {
panic!("Sign should be 1 or -1.")
};
magnitude * sign
});
let scalar =
scalar
.magnitude
.value()
.zip(scalar.sign.value())
.map(|(magnitude, sign)| {
// Move magnitude from base field into scalar field (which always fits
// for Pallas).
let magnitude = pallas::Scalar::from_bytes(&magnitude.to_bytes()).unwrap();
let sign = if sign == pallas::Base::one() {
pallas::Scalar::one()
} else if sign == -pallas::Base::one() {
-pallas::Scalar::one()
} else {
panic!("Sign should be 1 or -1.")
};
magnitude * sign
});
let real_mul = scalar.map(|scalar| base.generator() * scalar);
let result = result.point();
@ -237,10 +243,16 @@ impl Config {
#[cfg(test)]
pub mod tests {
use group::Curve;
use halo2::{circuit::Layouter, plonk::Error};
use halo2::{
circuit::{Chip, Layouter},
plonk::Error,
};
use pasta_curves::{arithmetic::FieldExt, pallas};
use crate::circuit::gadget::ecc::{chip::EccChip, FixedPointShort, Point};
use crate::circuit::gadget::{
ecc::{chip::EccChip, FixedPointShort, Point},
utilities::{CellValue, UtilitiesInstructions},
};
use crate::constants::load::ValueCommitV;
#[allow(clippy::op_ref)]
@ -253,6 +265,20 @@ pub mod tests {
let base_val = value_commit_v.generator;
let value_commit_v = FixedPointShort::from_inner(chip.clone(), value_commit_v);
fn load_magnitude_sign(
chip: EccChip,
mut layouter: impl Layouter<pallas::Base>,
magnitude: pallas::Base,
sign: pallas::Base,
) -> Result<(CellValue<pallas::Base>, CellValue<pallas::Base>), Error> {
let column = chip.config().advices[0];
let magnitude =
chip.load_private(layouter.namespace(|| "magnitude"), column, Some(magnitude))?;
let sign = chip.load_private(layouter.namespace(|| "sign"), column, Some(sign))?;
Ok((magnitude, sign))
}
fn constrain_equal(
chip: EccChip,
mut layouter: impl Layouter<pallas::Base>,
@ -268,96 +294,63 @@ pub mod tests {
result.constrain_equal(layouter.namespace(|| "constrain result"), &expected)
}
// [0]B should return (0,0) since it uses complete addition
// on the last step.
{
let scalar_fixed_short = pallas::Scalar::zero();
let (result, _) = value_commit_v.mul(
layouter.namespace(|| "mul by zero"),
Some(scalar_fixed_short),
)?;
let mut random_sign = pallas::Base::one();
if rand::random::<bool>() {
random_sign = -random_sign;
}
let magnitude_signs = [
("mul by +zero", pallas::Base::zero(), pallas::Base::one()),
("mul by -zero", pallas::Base::zero(), -pallas::Base::one()),
(
"random [a]B",
pallas::Base::from_u64(rand::random::<u64>()),
random_sign,
),
(
"[2^64 - 1]B",
pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64),
pallas::Base::one(),
),
(
"-[2^64 - 1]B",
pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64),
-pallas::Base::one(),
),
// There is a single canonical sequence of window values for which a doubling occurs on the last step:
// 1333333333333333333334 in octal.
// [0xB6DB_6DB6_DB6D_B6DC] B
(
"mul_with_double",
pallas::Base::from_u64(0xB6DB_6DB6_DB6D_B6DCu64),
pallas::Base::one(),
),
];
for (name, magnitude, sign) in magnitude_signs.iter() {
let (result, _) = {
let magnitude_sign = load_magnitude_sign(
chip.clone(),
layouter.namespace(|| *name),
*magnitude,
*sign,
)?;
value_commit_v.mul(layouter.namespace(|| *name), magnitude_sign)?
};
// Move from base field into scalar field
let scalar = {
let magnitude = pallas::Scalar::from_bytes(&magnitude.to_bytes()).unwrap();
let sign = if *sign == pallas::Base::one() {
pallas::Scalar::one()
} else {
-pallas::Scalar::one()
};
magnitude * sign
};
constrain_equal(
chip.clone(),
layouter.namespace(|| "mul by zero"),
layouter.namespace(|| *name),
base_val,
scalar_fixed_short,
result,
)?;
}
// Random [a]B
{
let scalar_fixed_short = pallas::Scalar::from_u64(rand::random::<u64>());
let mut sign = pallas::Scalar::one();
if rand::random::<bool>() {
sign = -sign;
}
let scalar_fixed_short = sign * &scalar_fixed_short;
let (result, _) = value_commit_v.mul(
layouter.namespace(|| "random short scalar"),
Some(scalar_fixed_short),
)?;
constrain_equal(
chip.clone(),
layouter.namespace(|| "random [a]B"),
base_val,
scalar_fixed_short,
result,
)?;
}
// [2^64 - 1]B
{
let scalar_fixed_short = pallas::Scalar::from_u64(0xFFFF_FFFF_FFFF_FFFFu64);
let (result, _) = value_commit_v.mul(
layouter.namespace(|| "[2^64 - 1]B"),
Some(scalar_fixed_short),
)?;
constrain_equal(
chip.clone(),
layouter.namespace(|| "[2^64 - 1]B"),
base_val,
scalar_fixed_short,
result,
)?;
}
// [-(2^64 - 1)]B
{
let scalar_fixed_short = -pallas::Scalar::from_u64(0xFFFF_FFFF_FFFF_FFFFu64);
let (result, _) = value_commit_v.mul(
layouter.namespace(|| "-[2^64 - 1]B"),
Some(scalar_fixed_short),
)?;
constrain_equal(
chip.clone(),
layouter.namespace(|| "[-2^64 - 1]B"),
base_val,
scalar_fixed_short,
result,
)?;
}
// There is a single canonical sequence of window values for which a doubling occurs on the last step:
// 1333333333333333333334 in octal.
// [0xB6DB_6DB6_DB6D_B6DC] B
{
let scalar_fixed_short = pallas::Scalar::from_u64(0xB6DB_6DB6_DB6D_B6DCu64);
let (result, _) = value_commit_v.mul(
layouter.namespace(|| "mul with double"),
Some(scalar_fixed_short),
)?;
constrain_equal(
chip,
layouter.namespace(|| "mul with double"),
base_val,
scalar_fixed_short,
scalar,
result,
)?;
}