Migrate to halo2 version with `AssignedCell`

We change `CellValue` into a typedef of `AssignedCell` to simplify the
migration in this commit.

The migration from `CellValue` to `AssignedCell` requires several other
changes:

- `<CellValue as Var>::value()` returned `Option<F>`, whereas
  `AssignedCell::<F, F>::value()` returns `Option<&F>`. This means we
  need to dereference, use `Option::cloned`, or alter functions to take
  `&F` arguments.
- `StateWord` in the Poseidon chip has been changed to a newtype around
  `AssignedCell` (the chip was written before `CellValue` existed).
This commit is contained in:
Jack Grigg 2021-12-02 00:10:00 +00:00
parent 5cb838f1a2
commit 9b41a06363
22 changed files with 122 additions and 139 deletions

View File

@ -90,4 +90,4 @@ debug = true
[patch.crates-io] [patch.crates-io]
zcash_note_encryption = { git = "https://github.com/zcash/librustzcash.git", rev = "35e75420657599fdc701cb45704878eb3fa2e59a" } zcash_note_encryption = { git = "https://github.com/zcash/librustzcash.git", rev = "35e75420657599fdc701cb45704878eb3fa2e59a" }
incrementalmerkletree = { git = "https://github.com/zcash/incrementalmerkletree.git", rev = "b7bd6246122a6e9ace8edb51553fbf5228906cbb" } incrementalmerkletree = { git = "https://github.com/zcash/incrementalmerkletree.git", rev = "b7bd6246122a6e9ace8edb51553fbf5228906cbb" }
halo2 = { git = "https://github.com/zcash/halo2.git", rev = "8bfc58b7c76ae83ba5a9ed7ecdfe0ddfd40ed571" } halo2 = { git = "https://github.com/zcash/halo2.git", rev = "afd7bc5469674cd08eae1634225fd02706a36a4f" }

View File

@ -121,7 +121,7 @@ where
0, 0,
|| self.output.ok_or(Error::Synthesis), || self.output.ok_or(Error::Synthesis),
)?; )?;
region.constrain_equal(output.cell(), expected_var) region.constrain_equal(output.cell(), expected_var.cell())
}, },
) )
} }

View File

@ -54,7 +54,7 @@ impl EccPoint {
if x.is_zero_vartime() && y.is_zero_vartime() { if x.is_zero_vartime() && y.is_zero_vartime() {
Some(pallas::Affine::identity()) Some(pallas::Affine::identity())
} else { } else {
Some(pallas::Affine::from_xy(x, y).unwrap()) Some(pallas::Affine::from_xy(*x, *y).unwrap())
} }
} }
_ => None, _ => None,
@ -104,7 +104,7 @@ impl NonIdentityEccPoint {
match (self.x.value(), self.y.value()) { match (self.x.value(), self.y.value()) {
(Some(x), Some(y)) => { (Some(x), Some(y)) => {
assert!(!x.is_zero_vartime() && !y.is_zero_vartime()); assert!(!x.is_zero_vartime() && !y.is_zero_vartime());
Some(pallas::Affine::from_xy(x, y).unwrap()) Some(pallas::Affine::from_xy(*x, *y).unwrap())
} }
_ => None, _ => None,
} }

View File

@ -248,7 +248,7 @@ impl Config {
let gamma = x_q; let gamma = x_q;
let delta = y_q + y_p; let delta = y_q + y_p;
let mut inverses = [alpha, beta, gamma, delta]; let mut inverses = [alpha, *beta, *gamma, delta];
inverses.batch_invert(); inverses.batch_invert();
inverses inverses
}); });
@ -329,11 +329,11 @@ impl Config {
{ {
if x_p.is_zero_vartime() { if x_p.is_zero_vartime() {
// 0 + Q = Q // 0 + Q = Q
(x_q, y_q) (*x_q, *y_q)
} else if x_q.is_zero_vartime() { } else if x_q.is_zero_vartime() {
// P + 0 = P // P + 0 = P
(x_p, y_p) (*x_p, *y_p)
} else if (x_q == x_p) && (y_q == -y_p) { } else if (x_q == x_p) && (*y_q == -y_p) {
// P + (-P) maps to (0,0) // P + (-P) maps to (0,0)
(pallas::Base::zero(), pallas::Base::zero()) (pallas::Base::zero(), pallas::Base::zero())
} else { } else {

View File

@ -364,7 +364,7 @@ impl Config {
// If `lsb` is 0, return `Acc + (-P)`. If `lsb` is 1, simply return `Acc + 0`. // If `lsb` is 0, return `Acc + (-P)`. If `lsb` is 1, simply return `Acc + 0`.
let x = if let Some(lsb) = lsb { let x = if let Some(lsb) = lsb {
if !lsb { if !lsb {
base.x.value() base.x.value().cloned()
} else { } else {
Some(pallas::Base::zero()) Some(pallas::Base::zero())
} }
@ -441,7 +441,7 @@ impl<F: FieldExt> Deref for Z<F> {
} }
} }
fn decompose_for_scalar_mul(scalar: Option<pallas::Base>) -> Vec<Option<bool>> { fn decompose_for_scalar_mul(scalar: Option<&pallas::Base>) -> Vec<Option<bool>> {
let bitstring = scalar.map(|scalar| { let bitstring = scalar.map(|scalar| {
// We use `k = scalar + t_q` in the double-and-add algorithm, where // We use `k = scalar + t_q` in the double-and-add algorithm, where
// the scalar field `F_q = 2^254 + t_q`. // the scalar field `F_q = 2^254 + t_q`.

View File

@ -161,8 +161,10 @@ impl Config {
)?; )?;
// If the bit is set, use `y`; if the bit is not set, use `-y` // If the bit is set, use `y`; if the bit is not set, use `-y`
let y_p = base_y let y_p =
base_y
.value() .value()
.cloned()
.zip(k.as_ref()) .zip(k.as_ref())
.map(|(base_y, k)| if !k { -base_y } else { base_y }); .map(|(base_y, k)| if !k { -base_y } else { base_y });

View File

@ -190,8 +190,8 @@ impl<const NUM_BITS: usize> Config<NUM_BITS> {
assert_eq!(bits.len(), NUM_BITS); assert_eq!(bits.len(), NUM_BITS);
// Handle exceptional cases // Handle exceptional cases
let (x_p, y_p) = (base.x.value(), base.y.value()); let (x_p, y_p) = (base.x.value().cloned(), base.y.value().cloned());
let (x_a, y_a) = (acc.0.value(), acc.1.value()); let (x_a, y_a) = (acc.0.value().cloned(), acc.1.value().cloned());
if let (Some(x_a), Some(y_a), Some(x_p), Some(y_p)) = (x_a, y_a, x_p, y_p) { if let (Some(x_a), Some(y_a), Some(x_p), Some(y_p)) = (x_a, y_a, x_p, y_p) {
// A is point at infinity // A is point at infinity
@ -229,7 +229,7 @@ impl<const NUM_BITS: usize> Config<NUM_BITS> {
let x_a = copy(region, || "starting x_a", self.x_a, offset + 1, &acc.0)?; let x_a = copy(region, || "starting x_a", self.x_a, offset + 1, &acc.0)?;
let y_a = copy(region, || "starting y_a", self.lambda1, offset, &acc.1)?; let y_a = copy(region, || "starting y_a", self.lambda1, offset, &acc.1)?;
(x_a, y_a.value(), z) (x_a, y_a.value().cloned(), z)
}; };
// Increase offset by 1; we used row 0 for initializing `z`. // Increase offset by 1; we used row 0 for initializing `z`.
@ -313,7 +313,7 @@ impl<const NUM_BITS: usize> Config<NUM_BITS> {
.zip(x_r) .zip(x_r)
.map(|((lambda2, x_a), x_r)| lambda2.square() - x_a - x_r); .map(|((lambda2, x_a), x_r)| lambda2.square() - x_a - x_r);
y_a = lambda2 y_a = lambda2
.zip(x_a.value()) .zip(x_a.value().cloned())
.zip(x_a_new) .zip(x_a_new)
.zip(y_a) .zip(y_a)
.map(|(((lambda2, x_a), x_a_new), y_a)| lambda2 * (x_a - x_a_new) - y_a); .map(|(((lambda2, x_a), x_a_new), y_a)| lambda2 * (x_a - x_a_new) - y_a);

View File

@ -4,7 +4,6 @@ use super::H_BASE;
use crate::{ use crate::{
circuit::gadget::utilities::{ circuit::gadget::utilities::{
bitrange_subset, copy, lookup_range_check::LookupRangeCheckConfig, range_check, CellValue, bitrange_subset, copy, lookup_range_check::LookupRangeCheckConfig, range_check, CellValue,
Var,
}, },
constants::{self, T_P}, constants::{self, T_P},
primitives::sinsemilla, primitives::sinsemilla,

View File

@ -85,7 +85,7 @@ impl Config {
// Decompose scalar into `k-bit` windows // Decompose scalar into `k-bit` windows
let scalar_windows: Option<Vec<u8>> = scalar.map(|scalar| { let scalar_windows: Option<Vec<u8>> = scalar.map(|scalar| {
util::decompose_word::<pallas::Scalar>( util::decompose_word::<pallas::Scalar>(
scalar, &scalar,
SCALAR_NUM_BITS, SCALAR_NUM_BITS,
constants::FIXED_BASE_WINDOW_SIZE, constants::FIXED_BASE_WINDOW_SIZE,
) )

View File

@ -160,10 +160,10 @@ impl Config {
// Conditionally negate `y`-coordinate // Conditionally negate `y`-coordinate
let y_val = if let Some(sign) = sign.value() { let y_val = if let Some(sign) = sign.value() {
if sign == -pallas::Base::one() { if sign == &-pallas::Base::one() {
magnitude_mul.y.value().map(|y: pallas::Base| -y) magnitude_mul.y.value().cloned().map(|y: pallas::Base| -y)
} else { } else {
magnitude_mul.y.value() magnitude_mul.y.value().cloned()
} }
} else { } else {
None None
@ -199,7 +199,7 @@ impl Config {
if let (Some(magnitude), Some(sign)) = (scalar.magnitude.value(), scalar.sign.value()) { if let (Some(magnitude), Some(sign)) = (scalar.magnitude.value(), scalar.sign.value()) {
let magnitude_is_valid = let magnitude_is_valid =
magnitude <= pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64); magnitude <= &pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64);
let sign_is_valid = sign * sign == pallas::Base::one(); let sign_is_valid = sign * sign == pallas::Base::one();
if magnitude_is_valid && sign_is_valid { if magnitude_is_valid && sign_is_valid {
let base: super::OrchardFixedBases = base.clone().into(); let base: super::OrchardFixedBases = base.clone().into();
@ -211,7 +211,7 @@ impl Config {
let magnitude = let magnitude =
pallas::Scalar::from_bytes(&magnitude.to_bytes()).unwrap(); pallas::Scalar::from_bytes(&magnitude.to_bytes()).unwrap();
let sign = if sign == pallas::Base::one() { let sign = if sign == &pallas::Base::one() {
pallas::Scalar::one() pallas::Scalar::one()
} else { } else {
-pallas::Scalar::one() -pallas::Scalar::one()

View File

@ -3,13 +3,13 @@ use std::iter;
use halo2::{ use halo2::{
arithmetic::FieldExt, arithmetic::FieldExt,
circuit::{Cell, Chip, Layouter, Region}, circuit::{AssignedCell, Cell, Chip, Layouter, Region},
plonk::{Advice, Column, ConstraintSystem, Error, Expression, Fixed, Selector}, plonk::{Advice, Column, ConstraintSystem, Error, Expression, Fixed, Selector},
poly::Rotation, poly::Rotation,
}; };
use super::{PoseidonDuplexInstructions, PoseidonInstructions}; use super::{PoseidonDuplexInstructions, PoseidonInstructions};
use crate::circuit::gadget::utilities::{CellValue, Var}; use crate::circuit::gadget::utilities::Var;
use crate::primitives::poseidon::{Domain, Mds, Spec, SpongeState, State}; use crate::primitives::poseidon::{Domain, Mds, Spec, SpongeState, State};
/// Configuration for a [`Pow5Chip`]. /// Configuration for a [`Pow5Chip`].
@ -288,10 +288,7 @@ impl<F: FieldExt, S: Spec<F, WIDTH, RATE>, const WIDTH: usize, const RATE: usize
0, 0,
value, value,
)?; )?;
state.push(StateWord { state.push(StateWord(var));
var,
value: Some(value),
});
Ok(()) Ok(())
}; };
@ -323,15 +320,15 @@ impl<F: FieldExt, S: Spec<F, WIDTH, RATE>, const WIDTH: usize, const RATE: usize
// Load the initial state into this region. // Load the initial state into this region.
let load_state_word = |i: usize| { let load_state_word = |i: usize| {
let value = initial_state[i].value; let value = initial_state[i].0.value().cloned();
let var = region.assign_advice( let var = region.assign_advice(
|| format!("load state_{}", i), || format!("load state_{}", i),
config.state[i], config.state[i],
0, 0,
|| value.ok_or(Error::Synthesis), || value.ok_or(Error::Synthesis),
)?; )?;
region.constrain_equal(initial_state[i].var, var)?; region.constrain_equal(initial_state[i].0.cell(), var.cell())?;
Ok(StateWord { var, value }) Ok(StateWord(var))
}; };
let initial_state: Result<Vec<_>, Error> = let initial_state: Result<Vec<_>, Error> =
(0..WIDTH).map(load_state_word).collect(); (0..WIDTH).map(load_state_word).collect();
@ -342,7 +339,7 @@ impl<F: FieldExt, S: Spec<F, WIDTH, RATE>, const WIDTH: usize, const RATE: usize
// Load the input and padding into this region. // Load the input and padding into this region.
let load_input_word = |i: usize| { let load_input_word = |i: usize| {
let (constraint_var, value) = match (input[i].clone(), padding_values[i]) { let (constraint_var, value) = match (input[i].clone(), padding_values[i]) {
(Some(word), None) => (word.var, word.value), (Some(word), None) => (word.0.cell(), word.0.value().cloned()),
(None, Some(padding_value)) => { (None, Some(padding_value)) => {
let padding_var = region.assign_fixed( let padding_var = region.assign_fixed(
|| format!("load pad_{}", i), || format!("load pad_{}", i),
@ -350,7 +347,7 @@ impl<F: FieldExt, S: Spec<F, WIDTH, RATE>, const WIDTH: usize, const RATE: usize
1, 1,
|| Ok(padding_value), || Ok(padding_value),
)?; )?;
(padding_var, Some(padding_value)) (padding_var.cell(), Some(padding_value))
} }
_ => panic!("Input and padding don't match"), _ => panic!("Input and padding don't match"),
}; };
@ -360,30 +357,31 @@ impl<F: FieldExt, S: Spec<F, WIDTH, RATE>, const WIDTH: usize, const RATE: usize
1, 1,
|| value.ok_or(Error::Synthesis), || value.ok_or(Error::Synthesis),
)?; )?;
region.constrain_equal(constraint_var, var)?; region.constrain_equal(constraint_var, var.cell())?;
Ok(StateWord { var, value }) Ok(StateWord(var))
}; };
let input: Result<Vec<_>, Error> = (0..RATE).map(load_input_word).collect(); let input: Result<Vec<_>, Error> = (0..RATE).map(load_input_word).collect();
let input = input?; let input = input?;
// Constrain the output. // Constrain the output.
let constrain_output_word = |i: usize| { let constrain_output_word = |i: usize| {
let value = initial_state[i].value.and_then(|initial_word| { let value = initial_state[i].0.value().and_then(|initial_word| {
input input
.get(i) .get(i)
.map(|word| word.value) .map(|word| word.0.value().cloned())
// The capacity element is never altered by the input. // The capacity element is never altered by the input.
.unwrap_or_else(|| Some(F::zero())) .unwrap_or_else(|| Some(F::zero()))
.map(|input_word| initial_word + input_word) .map(|input_word| *initial_word + input_word)
}); });
let var = region.assign_advice( region
.assign_advice(
|| format!("load output_{}", i), || format!("load output_{}", i),
config.state[i], config.state[i],
2, 2,
|| value.ok_or(Error::Synthesis), || value.ok_or(Error::Synthesis),
)?; )
Ok(StateWord { var, value }) .map(StateWord)
}; };
let output: Result<Vec<_>, Error> = (0..WIDTH).map(constrain_output_word).collect(); let output: Result<Vec<_>, Error> = (0..WIDTH).map(constrain_output_word).collect();
@ -404,34 +402,31 @@ impl<F: FieldExt, S: Spec<F, WIDTH, RATE>, const WIDTH: usize, const RATE: usize
/// A word in the Poseidon state. /// A word in the Poseidon state.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct StateWord<F: FieldExt> { pub struct StateWord<F: FieldExt>(AssignedCell<F, F>);
var: Cell,
value: Option<F>,
}
impl<F: FieldExt> From<StateWord<F>> for CellValue<F> { impl<F: FieldExt> From<StateWord<F>> for AssignedCell<F, F> {
fn from(state_word: StateWord<F>) -> CellValue<F> { fn from(state_word: StateWord<F>) -> AssignedCell<F, F> {
CellValue::new(state_word.var, state_word.value) state_word.0
} }
} }
impl<F: FieldExt> From<CellValue<F>> for StateWord<F> { impl<F: FieldExt> From<AssignedCell<F, F>> for StateWord<F> {
fn from(cell_value: CellValue<F>) -> StateWord<F> { fn from(cell_value: AssignedCell<F, F>) -> StateWord<F> {
StateWord::new(cell_value.cell(), cell_value.value()) StateWord(cell_value)
} }
} }
impl<F: FieldExt> Var<F> for StateWord<F> { impl<F: FieldExt> Var<F> for StateWord<F> {
fn new(var: Cell, value: Option<F>) -> Self { fn new(var: AssignedCell<F, F>, value: Option<F>) -> Self {
Self { var, value } Self(var)
} }
fn cell(&self) -> Cell { fn cell(&self) -> Cell {
self.var self.0.cell()
} }
fn value(&self) -> Option<F> { fn value(&self) -> Option<F> {
self.value self.0.value().cloned()
} }
} }
@ -447,11 +442,11 @@ impl<F: FieldExt, const WIDTH: usize> Pow5State<F, WIDTH> {
offset: usize, offset: usize,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
Self::round(region, config, round, offset, config.s_full, |_| { Self::round(region, config, round, offset, config.s_full, |_| {
let q = self let q = self.0.iter().enumerate().map(|(idx, word)| {
.0 word.0
.iter() .value()
.enumerate() .map(|v| *v + config.round_constants[round][idx])
.map(|(idx, word)| word.value.map(|v| v + config.round_constants[round][idx])); });
let r: Option<Vec<F>> = q.map(|q| q.map(|q| q.pow(&config.alpha))).collect(); let r: Option<Vec<F>> = q.map(|q| q.map(|q| q.pow(&config.alpha))).collect();
let m = &config.m_reg; let m = &config.m_reg;
let state = m.iter().map(|m_i| { let state = m.iter().map(|m_i| {
@ -475,7 +470,7 @@ impl<F: FieldExt, const WIDTH: usize> Pow5State<F, WIDTH> {
) -> Result<Self, Error> { ) -> Result<Self, Error> {
Self::round(region, config, round, offset, config.s_partial, |region| { Self::round(region, config, round, offset, config.s_partial, |region| {
let m = &config.m_reg; let m = &config.m_reg;
let p: Option<Vec<_>> = self.0.iter().map(|word| word.value).collect(); let p: Option<Vec<_>> = self.0.iter().map(|word| word.0.value().cloned()).collect();
let r: Option<Vec<_>> = p.map(|p| { let r: Option<Vec<_>> = p.map(|p| {
let r_0 = (p[0] + config.round_constants[round][0]).pow(&config.alpha); let r_0 = (p[0] + config.round_constants[round][0]).pow(&config.alpha);
@ -547,15 +542,15 @@ impl<F: FieldExt, const WIDTH: usize> Pow5State<F, WIDTH> {
initial_state: &State<StateWord<F>, WIDTH>, initial_state: &State<StateWord<F>, WIDTH>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let load_state_word = |i: usize| { let load_state_word = |i: usize| {
let value = initial_state[i].value; let value = initial_state[i].0.value().cloned();
let var = region.assign_advice( let var = region.assign_advice(
|| format!("load state_{}", i), || format!("load state_{}", i),
config.state[i], config.state[i],
0, 0,
|| value.ok_or(Error::Synthesis), || value.ok_or(Error::Synthesis),
)?; )?;
region.constrain_equal(initial_state[i].var, var)?; region.constrain_equal(initial_state[i].0.cell(), var.cell())?;
Ok(StateWord { var, value }) Ok(StateWord(var))
}; };
let state: Result<Vec<_>, _> = (0..WIDTH).map(load_state_word).collect(); let state: Result<Vec<_>, _> = (0..WIDTH).map(load_state_word).collect();
@ -597,7 +592,7 @@ impl<F: FieldExt, const WIDTH: usize> Pow5State<F, WIDTH> {
offset + 1, offset + 1,
|| value.ok_or(Error::Synthesis), || value.ok_or(Error::Synthesis),
)?; )?;
Ok(StateWord { var, value }) Ok(StateWord(var))
}; };
let next_state: Result<Vec<_>, _> = (0..WIDTH).map(next_state_word).collect(); let next_state: Result<Vec<_>, _> = (0..WIDTH).map(next_state_word).collect();
@ -674,7 +669,7 @@ mod tests {
0, 0,
|| value.ok_or(Error::Synthesis), || value.ok_or(Error::Synthesis),
)?; )?;
Ok(StateWord { var, value }) Ok(StateWord(var))
}; };
let state: Result<Vec<_>, Error> = (0..WIDTH).map(state_word).collect(); let state: Result<Vec<_>, Error> = (0..WIDTH).map(state_word).collect();
@ -713,7 +708,7 @@ mod tests {
0, 0,
|| Ok(expected_final_state[i]), || Ok(expected_final_state[i]),
)?; )?;
region.constrain_equal(final_state[i].var, var) region.constrain_equal(final_state[i].0.cell(), var.cell())
}; };
for i in 0..(WIDTH) { for i in 0..(WIDTH) {
@ -821,7 +816,7 @@ mod tests {
0, 0,
|| self.output.ok_or(Error::Synthesis), || self.output.ok_or(Error::Synthesis),
)?; )?;
region.constrain_equal(output.cell(), expected_var) region.constrain_equal(output.cell(), expected_var.cell())
}, },
) )
} }

View File

@ -142,7 +142,7 @@ impl SinsemillaChip {
.chunks(K) .chunks(K)
.fold(Q.to_curve(), |acc, chunk| (acc + S(chunk)) + acc); .fold(Q.to_curve(), |acc, chunk| (acc + S(chunk)) + acc);
let actual_point = let actual_point =
pallas::Affine::from_xy(x_a.value().unwrap(), y_a.value().unwrap()).unwrap(); pallas::Affine::from_xy(*x_a.value().unwrap(), *y_a.value().unwrap()).unwrap();
assert_eq!(expected_point.to_affine(), actual_point); assert_eq!(expected_point.to_affine(), actual_point);
} }
} }
@ -270,7 +270,7 @@ impl SinsemillaChip {
offset, offset,
|| piece.field_elem().ok_or(Error::Synthesis), || piece.field_elem().ok_or(Error::Synthesis),
)?; )?;
region.constrain_equal(piece.cell(), cell)?; region.constrain_equal(piece.cell(), cell.cell())?;
zs.push(CellValue::new(cell, piece.field_elem())); zs.push(CellValue::new(cell, piece.field_elem()));
// Assign cumulative sum such that for 0 <= i < n, // Assign cumulative sum such that for 0 <= i < n,

View File

@ -8,7 +8,7 @@ use pasta_curves::{arithmetic::FieldExt, pallas};
use crate::{ use crate::{
circuit::gadget::{ circuit::gadget::{
ecc::{chip::EccChip, X}, ecc::{chip::EccChip, X},
utilities::{bitrange_subset, bool_check, copy, CellValue, Var}, utilities::{bitrange_subset, bool_check, copy, CellValue},
}, },
constants::T_P, constants::T_P,
}; };
@ -641,7 +641,7 @@ mod tests {
ecc::chip::{EccChip, EccConfig}, ecc::chip::{EccChip, EccConfig},
sinsemilla::chip::SinsemillaChip, sinsemilla::chip::SinsemillaChip,
utilities::{ utilities::{
lookup_range_check::LookupRangeCheckConfig, CellValue, UtilitiesInstructions, Var, lookup_range_check::LookupRangeCheckConfig, CellValue, UtilitiesInstructions,
}, },
}, },
constants::{COMMIT_IVK_PERSONALIZATION, L_ORCHARD_BASE, T_Q}, constants::{COMMIT_IVK_PERSONALIZATION, L_ORCHARD_BASE, T_Q},
@ -803,7 +803,7 @@ mod tests {
.unwrap() .unwrap()
}; };
assert_eq!(expected_ivk, ivk.inner().value().unwrap()); assert_eq!(&expected_ivk, ivk.inner().value().unwrap());
Ok(()) Ok(())
} }

View File

@ -139,7 +139,7 @@ pub mod tests {
use crate::{ use crate::{
circuit::gadget::{ circuit::gadget::{
sinsemilla::chip::{SinsemillaChip, SinsemillaHashDomains}, sinsemilla::chip::{SinsemillaChip, SinsemillaHashDomains},
utilities::{lookup_range_check::LookupRangeCheckConfig, UtilitiesInstructions, Var}, utilities::{lookup_range_check::LookupRangeCheckConfig, UtilitiesInstructions},
}, },
constants::MERKLE_DEPTH_ORCHARD, constants::MERKLE_DEPTH_ORCHARD,
note::commitment::ExtractedNoteCommitment, note::commitment::ExtractedNoteCommitment,
@ -266,7 +266,7 @@ pub mod tests {
}; };
// Check the computed final root against the expected final root. // Check the computed final root against the expected final root.
assert_eq!(computed_final_root.value().unwrap(), final_root.inner()); assert_eq!(computed_final_root.value().unwrap(), &final_root.inner());
} }
Ok(()) Ok(())

View File

@ -15,7 +15,7 @@ use crate::{
circuit::gadget::utilities::{ circuit::gadget::utilities::{
bitrange_subset, bitrange_subset,
cond_swap::{CondSwapChip, CondSwapConfig, CondSwapInstructions}, cond_swap::{CondSwapChip, CondSwapConfig, CondSwapInstructions},
copy, CellValue, UtilitiesInstructions, Var, copy, CellValue, UtilitiesInstructions,
}, },
constants::{L_ORCHARD_BASE, MERKLE_DEPTH_ORCHARD}, constants::{L_ORCHARD_BASE, MERKLE_DEPTH_ORCHARD},
primitives::sinsemilla, primitives::sinsemilla,
@ -185,7 +185,7 @@ impl MerkleInstructions<pallas::Affine, MERKLE_DEPTH_ORCHARD, { sinsemilla::K },
let a = { let a = {
let a = { let a = {
// a_0 = l // a_0 = l
let a_0 = bitrange_subset(pallas::Base::from_u64(l as u64), 0..10); let a_0 = bitrange_subset(&pallas::Base::from_u64(l as u64), 0..10);
// a_1 = (bits 0..=239 of `left`) // a_1 = (bits 0..=239 of `left`)
let a_1 = left.value().map(|value| bitrange_subset(value, 0..240)); let a_1 = left.value().map(|value| bitrange_subset(value, 0..240));

View File

@ -40,7 +40,7 @@ pub struct MessagePiece<F: FieldExt, const K: usize> {
} }
impl<F: FieldExt + PrimeFieldBits, const K: usize> MessagePiece<F, K> { impl<F: FieldExt + PrimeFieldBits, const K: usize> MessagePiece<F, K> {
pub fn new(cell: Cell, field_elem: Option<F>, num_words: usize) -> Self { pub fn new(cell: CellValue<F>, field_elem: Option<F>, num_words: usize) -> Self {
assert!(num_words * K < F::NUM_BITS as usize); assert!(num_words * K < F::NUM_BITS as usize);
let cell_value = CellValue::new(cell, field_elem); let cell_value = CellValue::new(cell, field_elem);
Self { Self {
@ -58,7 +58,7 @@ impl<F: FieldExt + PrimeFieldBits, const K: usize> MessagePiece<F, K> {
} }
pub fn field_elem(&self) -> Option<F> { pub fn field_elem(&self) -> Option<F> {
self.cell_value.value() self.cell_value.value().cloned()
} }
pub fn cell_value(&self) -> CellValue<F> { pub fn cell_value(&self) -> CellValue<F> {

View File

@ -530,8 +530,10 @@ impl NoteCommitConfig {
psi: CellValue<pallas::Base>, psi: CellValue<pallas::Base>,
rcm: Option<pallas::Scalar>, rcm: Option<pallas::Scalar>,
) -> Result<Point<pallas::Affine, EccChip>, Error> { ) -> Result<Point<pallas::Affine, EccChip>, Error> {
let (gd_x, gd_y) = (g_d.x().value(), g_d.y().value()); let (gd_x, gd_y) = (g_d.x(), g_d.y());
let (pkd_x, pkd_y) = (pk_d.x().value(), pk_d.y().value()); let (pkd_x, pkd_y) = (pk_d.x(), pk_d.y());
let (gd_x, gd_y) = (gd_x.value(), gd_y.value());
let (pkd_x, pkd_y) = (pkd_x.value(), pkd_y.value());
let value_val = value.value(); let value_val = value.value();
let rho_val = rho.value(); let rho_val = rho.value();
let psi_val = psi.value(); let psi_val = psi.value();

View File

@ -2,7 +2,7 @@
use ff::PrimeFieldBits; use ff::PrimeFieldBits;
use halo2::{ use halo2::{
circuit::{Cell, Layouter, Region}, circuit::{AssignedCell, Cell, Layouter, Region},
plonk::{Advice, Column, Error, Expression}, plonk::{Advice, Column, Error, Expression},
}; };
use pasta_curves::arithmetic::FieldExt; use pasta_curves::arithmetic::FieldExt;
@ -13,16 +13,12 @@ pub(crate) mod decompose_running_sum;
pub(crate) mod lookup_range_check; pub(crate) mod lookup_range_check;
/// A variable representing a field element. /// A variable representing a field element.
#[derive(Clone, Debug)] pub type CellValue<F> = AssignedCell<F, F>;
pub struct CellValue<F: FieldExt> {
cell: Cell,
value: Option<F>,
}
/// Trait for a variable in the circuit. /// Trait for a variable in the circuit.
pub trait Var<F: FieldExt>: Clone + std::fmt::Debug { pub trait Var<F: FieldExt>: Clone + std::fmt::Debug {
/// Construct a new variable. /// Construct a new variable.
fn new(cell: Cell, value: Option<F>) -> Self; fn new(cell: AssignedCell<F, F>, value: Option<F>) -> Self;
/// The cell at which this variable was allocated. /// The cell at which this variable was allocated.
fn cell(&self) -> Cell; fn cell(&self) -> Cell;
@ -32,16 +28,16 @@ pub trait Var<F: FieldExt>: Clone + std::fmt::Debug {
} }
impl<F: FieldExt> Var<F> for CellValue<F> { impl<F: FieldExt> Var<F> for CellValue<F> {
fn new(cell: Cell, value: Option<F>) -> Self { fn new(cell: AssignedCell<F, F>, _value: Option<F>) -> Self {
Self { cell, value } cell
} }
fn cell(&self) -> Cell { fn cell(&self) -> Cell {
self.cell self.cell()
} }
fn value(&self) -> Option<F> { fn value(&self) -> Option<F> {
self.value self.value().cloned()
} }
} }
@ -90,13 +86,9 @@ where
A: Fn() -> AR, A: Fn() -> AR,
AR: Into<String>, AR: Into<String>,
{ {
let cell = region.assign_advice(annotation, column, offset, || { // Temporarily implement `copy()` in terms of `AssignedCell::copy_advice`.
copy.value.ok_or(Error::Synthesis) // We will remove this in a subsequent commit.
})?; copy.copy_advice(annotation, region, column, offset)
region.constrain_equal(cell, copy.cell)?;
Ok(CellValue::new(cell, copy.value))
} }
pub(crate) fn transpose_option_array<T: Copy + std::fmt::Debug, const LEN: usize>( pub(crate) fn transpose_option_array<T: Copy + std::fmt::Debug, const LEN: usize>(
@ -126,7 +118,7 @@ pub fn ternary<F: FieldExt>(a: Expression<F>, b: Expression<F>, c: Expression<F>
/// Takes a specified subsequence of the little-endian bit representation of a field element. /// Takes a specified subsequence of the little-endian bit representation of a field element.
/// The bits are numbered from 0 for the LSB. /// The bits are numbered from 0 for the LSB.
pub fn bitrange_subset<F: FieldExt + PrimeFieldBits>(field_elem: F, bitrange: Range<usize>) -> F { pub fn bitrange_subset<F: FieldExt + PrimeFieldBits>(field_elem: &F, bitrange: Range<usize>) -> F {
assert!(bitrange.end <= F::NUM_BITS as usize); assert!(bitrange.end <= F::NUM_BITS as usize);
let bits: Vec<bool> = field_elem let bits: Vec<bool> = field_elem
@ -251,7 +243,7 @@ mod tests {
{ {
let field_elem = pallas::Base::rand(); let field_elem = pallas::Base::rand();
let bitrange = 0..(pallas::Base::NUM_BITS as usize); let bitrange = 0..(pallas::Base::NUM_BITS as usize);
let subset = bitrange_subset(field_elem, bitrange); let subset = bitrange_subset(&field_elem, bitrange);
assert_eq!(field_elem, subset); assert_eq!(field_elem, subset);
} }
@ -259,7 +251,7 @@ mod tests {
{ {
let field_elem = pallas::Base::rand(); let field_elem = pallas::Base::rand();
let bitrange = 0..0; let bitrange = 0..0;
let subset = bitrange_subset(field_elem, bitrange); let subset = bitrange_subset(&field_elem, bitrange);
assert_eq!(pallas::Base::zero(), subset); assert_eq!(pallas::Base::zero(), subset);
} }
@ -286,7 +278,7 @@ mod tests {
let subsets = ranges let subsets = ranges
.iter() .iter()
.map(|range| bitrange_subset(field_elem, range.clone())) .map(|range| bitrange_subset(&field_elem, range.clone()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut sum = subsets[0]; let mut sum = subsets[0];

View File

@ -98,39 +98,33 @@ impl<F: FieldExt> CondSwapInstructions<F> for CondSwapChip<F> {
// Conditionally swap a // Conditionally swap a
let a_swapped = { let a_swapped = {
let a_swapped = a let a_swapped = a
.value .value()
.zip(b.value) .zip(b.value())
.zip(swap) .zip(swap)
.map(|((a, b), swap)| if swap { b } else { a }); .map(|((a, b), swap)| if swap { b } else { a })
let a_swapped_cell = region.assign_advice( .cloned();
region.assign_advice(
|| "a_swapped", || "a_swapped",
config.a_swapped, config.a_swapped,
0, 0,
|| a_swapped.ok_or(Error::Synthesis), || a_swapped.ok_or(Error::Synthesis),
)?; )?
CellValue {
cell: a_swapped_cell,
value: a_swapped,
}
}; };
// Conditionally swap b // Conditionally swap b
let b_swapped = { let b_swapped = {
let b_swapped = a let b_swapped = a
.value .value()
.zip(b.value) .zip(b.value())
.zip(swap) .zip(swap)
.map(|((a, b), swap)| if swap { a } else { b }); .map(|((a, b), swap)| if swap { a } else { b })
let b_swapped_cell = region.assign_advice( .cloned();
region.assign_advice(
|| "b_swapped", || "b_swapped",
config.b_swapped, config.b_swapped,
0, 0,
|| b_swapped.ok_or(Error::Synthesis), || b_swapped.ok_or(Error::Synthesis),
)?; )?
CellValue {
cell: b_swapped_cell,
value: b_swapped,
}
}; };
// Return swapped pair // Return swapped pair
@ -261,12 +255,12 @@ mod tests {
if let Some(swap) = self.swap { if let Some(swap) = self.swap {
if swap { if swap {
// Check that `a` and `b` have been swapped // Check that `a` and `b` have been swapped
assert_eq!(swapped_pair.0.value.unwrap(), self.b.unwrap()); assert_eq!(swapped_pair.0.value().unwrap(), &self.b.unwrap());
assert_eq!(swapped_pair.1.value.unwrap(), a.value.unwrap()); assert_eq!(swapped_pair.1.value().unwrap(), a.value().unwrap());
} else { } else {
// Check that `a` and `b` have not been swapped // Check that `a` and `b` have not been swapped
assert_eq!(swapped_pair.0.value.unwrap(), a.value.unwrap()); assert_eq!(swapped_pair.0.value().unwrap(), a.value().unwrap());
assert_eq!(swapped_pair.1.value.unwrap(), self.b.unwrap()); assert_eq!(swapped_pair.1.value().unwrap(), &self.b.unwrap());
} }
} }

View File

@ -193,7 +193,7 @@ impl<F: FieldExt + PrimeFieldBits, const WINDOW_NUM_BITS: usize>
let z_next_val = z let z_next_val = z
.value() .value()
.zip(word) .zip(word)
.map(|(z_cur_val, word)| (z_cur_val - word) * two_pow_k_inv); .map(|(z_cur_val, word)| (*z_cur_val - word) * two_pow_k_inv);
let cell = region.assign_advice( let cell = region.assign_advice(
|| format!("z_{:?}", i + 1), || format!("z_{:?}", i + 1),
self.z, self.z,

View File

@ -244,7 +244,7 @@ impl<F: FieldExt + PrimeFieldBits, const K: usize> LookupRangeCheckConfig<F, K>
let z_val = z let z_val = z
.value() .value()
.zip(*word) .zip(*word)
.map(|(z, word)| (z - word) * inv_two_pow_k); .map(|(z, word)| (*z - word) * inv_two_pow_k);
// Assign z_next // Assign z_next
let z_cell = region.assign_advice( let z_cell = region.assign_advice(
@ -344,7 +344,7 @@ impl<F: FieldExt + PrimeFieldBits, const K: usize> LookupRangeCheckConfig<F, K>
// Assign shifted `element * 2^{K - num_bits}` // Assign shifted `element * 2^{K - num_bits}`
let shifted = element.value().map(|element| { let shifted = element.value().map(|element| {
let shift = F::from_u64(1 << (K - num_bits)); let shift = F::from_u64(1 << (K - num_bits));
element * shift *element * shift
}); });
region.assign_advice( region.assign_advice(
@ -369,7 +369,6 @@ impl<F: FieldExt + PrimeFieldBits, const K: usize> LookupRangeCheckConfig<F, K>
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::Var;
use super::LookupRangeCheckConfig; use super::LookupRangeCheckConfig;
use crate::primitives::sinsemilla::{INV_TWO_POW_K, K}; use crate::primitives::sinsemilla::{INV_TWO_POW_K, K};
@ -468,7 +467,7 @@ mod tests {
for (expected_z, z) in expected_zs.into_iter().zip(zs.iter()) { for (expected_z, z) in expected_zs.into_iter().zip(zs.iter()) {
if let Some(z) = z.value() { if let Some(z) = z.value() {
assert_eq!(expected_z, z); assert_eq!(&expected_z, z);
} }
} }
} }

View File

@ -10,7 +10,7 @@ use halo2::arithmetic::{CurveAffine, FieldExt};
/// We are returning a `Vec<u8>` which means the window size is limited to /// We are returning a `Vec<u8>` which means the window size is limited to
/// <= 8 bits. /// <= 8 bits.
pub fn decompose_word<F: PrimeFieldBits>( pub fn decompose_word<F: PrimeFieldBits>(
word: F, word: &F,
word_num_bits: usize, word_num_bits: usize,
window_num_bits: usize, window_num_bits: usize,
) -> Vec<u8> { ) -> Vec<u8> {
@ -86,7 +86,7 @@ mod tests {
window_num_bits in 1u8..9 window_num_bits in 1u8..9
) { ) {
// Get decomposition into `window_num_bits` bits // Get decomposition into `window_num_bits` bits
let decomposed = decompose_word(scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize); let decomposed = decompose_word(&scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize);
// Flatten bits // Flatten bits
let bits = decomposed let bits = decomposed