From 9b41a063633e4f81b1cc65f4fa830bc91332c25f Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Thu, 2 Dec 2021 00:10:00 +0000 Subject: [PATCH] Migrate to halo2 version with `AssignedCell` We change `CellValue` into a typedef of `AssignedCell` to simplify the migration in this commit. The migration from `CellValue` to `AssignedCell` requires several other changes: - `::value()` returned `Option`, whereas `AssignedCell::::value()` returns `Option<&F>`. This means we need to dereference, use `Option::cloned`, or alter functions to take `&F` arguments. - `StateWord` in the Poseidon chip has been changed to a newtype around `AssignedCell` (the chip was written before `CellValue` existed). --- Cargo.toml | 2 +- benches/poseidon.rs | 2 +- src/circuit/gadget/ecc/chip.rs | 4 +- src/circuit/gadget/ecc/chip/add.rs | 8 +- src/circuit/gadget/ecc/chip/mul.rs | 4 +- src/circuit/gadget/ecc/chip/mul/complete.rs | 10 +- src/circuit/gadget/ecc/chip/mul/incomplete.rs | 8 +- .../ecc/chip/mul_fixed/base_field_elem.rs | 1 - .../gadget/ecc/chip/mul_fixed/full_width.rs | 2 +- .../gadget/ecc/chip/mul_fixed/short.rs | 10 +- src/circuit/gadget/poseidon/pow5.rs | 95 +++++++++---------- .../gadget/sinsemilla/chip/hash_to_point.rs | 4 +- src/circuit/gadget/sinsemilla/commit_ivk.rs | 6 +- src/circuit/gadget/sinsemilla/merkle.rs | 4 +- src/circuit/gadget/sinsemilla/merkle/chip.rs | 4 +- src/circuit/gadget/sinsemilla/message.rs | 4 +- src/circuit/gadget/sinsemilla/note_commit.rs | 6 +- src/circuit/gadget/utilities.rs | 36 +++---- src/circuit/gadget/utilities/cond_swap.rs | 38 ++++---- .../gadget/utilities/decompose_running_sum.rs | 2 +- .../gadget/utilities/lookup_range_check.rs | 7 +- src/constants/util.rs | 4 +- 22 files changed, 122 insertions(+), 139 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2b38917e..365804a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,4 +90,4 @@ debug = true [patch.crates-io] zcash_note_encryption = { git = "https://github.com/zcash/librustzcash.git", rev = "35e75420657599fdc701cb45704878eb3fa2e59a" } incrementalmerkletree = { git = "https://github.com/zcash/incrementalmerkletree.git", rev = "b7bd6246122a6e9ace8edb51553fbf5228906cbb" } -halo2 = { git = "https://github.com/zcash/halo2.git", rev = "8bfc58b7c76ae83ba5a9ed7ecdfe0ddfd40ed571" } +halo2 = { git = "https://github.com/zcash/halo2.git", rev = "afd7bc5469674cd08eae1634225fd02706a36a4f" } diff --git a/benches/poseidon.rs b/benches/poseidon.rs index 4f40d9a6..7591e144 100644 --- a/benches/poseidon.rs +++ b/benches/poseidon.rs @@ -121,7 +121,7 @@ where 0, || self.output.ok_or(Error::Synthesis), )?; - region.constrain_equal(output.cell(), expected_var) + region.constrain_equal(output.cell(), expected_var.cell()) }, ) } diff --git a/src/circuit/gadget/ecc/chip.rs b/src/circuit/gadget/ecc/chip.rs index e354ca66..92ad5f4c 100644 --- a/src/circuit/gadget/ecc/chip.rs +++ b/src/circuit/gadget/ecc/chip.rs @@ -54,7 +54,7 @@ impl EccPoint { if x.is_zero_vartime() && y.is_zero_vartime() { Some(pallas::Affine::identity()) } else { - Some(pallas::Affine::from_xy(x, y).unwrap()) + Some(pallas::Affine::from_xy(*x, *y).unwrap()) } } _ => None, @@ -104,7 +104,7 @@ impl NonIdentityEccPoint { match (self.x.value(), self.y.value()) { (Some(x), Some(y)) => { assert!(!x.is_zero_vartime() && !y.is_zero_vartime()); - Some(pallas::Affine::from_xy(x, y).unwrap()) + Some(pallas::Affine::from_xy(*x, *y).unwrap()) } _ => None, } diff --git a/src/circuit/gadget/ecc/chip/add.rs b/src/circuit/gadget/ecc/chip/add.rs index 023a678b..510ea6a6 100644 --- a/src/circuit/gadget/ecc/chip/add.rs +++ b/src/circuit/gadget/ecc/chip/add.rs @@ -248,7 +248,7 @@ impl Config { let gamma = x_q; let delta = y_q + y_p; - let mut inverses = [alpha, beta, gamma, delta]; + let mut inverses = [alpha, *beta, *gamma, delta]; inverses.batch_invert(); inverses }); @@ -329,11 +329,11 @@ impl Config { { if x_p.is_zero_vartime() { // 0 + Q = Q - (x_q, y_q) + (*x_q, *y_q) } else if x_q.is_zero_vartime() { // P + 0 = P - (x_p, y_p) - } else if (x_q == x_p) && (y_q == -y_p) { + (*x_p, *y_p) + } else if (x_q == x_p) && (*y_q == -y_p) { // P + (-P) maps to (0,0) (pallas::Base::zero(), pallas::Base::zero()) } else { diff --git a/src/circuit/gadget/ecc/chip/mul.rs b/src/circuit/gadget/ecc/chip/mul.rs index bf1bae58..fc99adfe 100644 --- a/src/circuit/gadget/ecc/chip/mul.rs +++ b/src/circuit/gadget/ecc/chip/mul.rs @@ -364,7 +364,7 @@ impl Config { // If `lsb` is 0, return `Acc + (-P)`. If `lsb` is 1, simply return `Acc + 0`. let x = if let Some(lsb) = lsb { if !lsb { - base.x.value() + base.x.value().cloned() } else { Some(pallas::Base::zero()) } @@ -441,7 +441,7 @@ impl Deref for Z { } } -fn decompose_for_scalar_mul(scalar: Option) -> Vec> { +fn decompose_for_scalar_mul(scalar: Option<&pallas::Base>) -> Vec> { let bitstring = scalar.map(|scalar| { // We use `k = scalar + t_q` in the double-and-add algorithm, where // the scalar field `F_q = 2^254 + t_q`. diff --git a/src/circuit/gadget/ecc/chip/mul/complete.rs b/src/circuit/gadget/ecc/chip/mul/complete.rs index 476efd80..12c1d525 100644 --- a/src/circuit/gadget/ecc/chip/mul/complete.rs +++ b/src/circuit/gadget/ecc/chip/mul/complete.rs @@ -161,10 +161,12 @@ impl Config { )?; // If the bit is set, use `y`; if the bit is not set, use `-y` - let y_p = base_y - .value() - .zip(k.as_ref()) - .map(|(base_y, k)| if !k { -base_y } else { base_y }); + let y_p = + base_y + .value() + .cloned() + .zip(k.as_ref()) + .map(|(base_y, k)| if !k { -base_y } else { base_y }); let y_p_cell = region.assign_advice( || "y_p", diff --git a/src/circuit/gadget/ecc/chip/mul/incomplete.rs b/src/circuit/gadget/ecc/chip/mul/incomplete.rs index 4ccd817d..17874ec2 100644 --- a/src/circuit/gadget/ecc/chip/mul/incomplete.rs +++ b/src/circuit/gadget/ecc/chip/mul/incomplete.rs @@ -190,8 +190,8 @@ impl Config { assert_eq!(bits.len(), NUM_BITS); // Handle exceptional cases - let (x_p, y_p) = (base.x.value(), base.y.value()); - let (x_a, y_a) = (acc.0.value(), acc.1.value()); + let (x_p, y_p) = (base.x.value().cloned(), base.y.value().cloned()); + let (x_a, y_a) = (acc.0.value().cloned(), acc.1.value().cloned()); if let (Some(x_a), Some(y_a), Some(x_p), Some(y_p)) = (x_a, y_a, x_p, y_p) { // A is point at infinity @@ -229,7 +229,7 @@ impl Config { let x_a = copy(region, || "starting x_a", self.x_a, offset + 1, &acc.0)?; let y_a = copy(region, || "starting y_a", self.lambda1, offset, &acc.1)?; - (x_a, y_a.value(), z) + (x_a, y_a.value().cloned(), z) }; // Increase offset by 1; we used row 0 for initializing `z`. @@ -313,7 +313,7 @@ impl Config { .zip(x_r) .map(|((lambda2, x_a), x_r)| lambda2.square() - x_a - x_r); y_a = lambda2 - .zip(x_a.value()) + .zip(x_a.value().cloned()) .zip(x_a_new) .zip(y_a) .map(|(((lambda2, x_a), x_a_new), y_a)| lambda2 * (x_a - x_a_new) - y_a); diff --git a/src/circuit/gadget/ecc/chip/mul_fixed/base_field_elem.rs b/src/circuit/gadget/ecc/chip/mul_fixed/base_field_elem.rs index 06283905..dc421e53 100644 --- a/src/circuit/gadget/ecc/chip/mul_fixed/base_field_elem.rs +++ b/src/circuit/gadget/ecc/chip/mul_fixed/base_field_elem.rs @@ -4,7 +4,6 @@ use super::H_BASE; use crate::{ circuit::gadget::utilities::{ bitrange_subset, copy, lookup_range_check::LookupRangeCheckConfig, range_check, CellValue, - Var, }, constants::{self, T_P}, primitives::sinsemilla, diff --git a/src/circuit/gadget/ecc/chip/mul_fixed/full_width.rs b/src/circuit/gadget/ecc/chip/mul_fixed/full_width.rs index 03e976ab..7e632ff0 100644 --- a/src/circuit/gadget/ecc/chip/mul_fixed/full_width.rs +++ b/src/circuit/gadget/ecc/chip/mul_fixed/full_width.rs @@ -85,7 +85,7 @@ impl Config { // Decompose scalar into `k-bit` windows let scalar_windows: Option> = scalar.map(|scalar| { util::decompose_word::( - scalar, + &scalar, SCALAR_NUM_BITS, constants::FIXED_BASE_WINDOW_SIZE, ) diff --git a/src/circuit/gadget/ecc/chip/mul_fixed/short.rs b/src/circuit/gadget/ecc/chip/mul_fixed/short.rs index 319511e8..7982d998 100644 --- a/src/circuit/gadget/ecc/chip/mul_fixed/short.rs +++ b/src/circuit/gadget/ecc/chip/mul_fixed/short.rs @@ -160,10 +160,10 @@ impl Config { // Conditionally negate `y`-coordinate let y_val = if let Some(sign) = sign.value() { - if sign == -pallas::Base::one() { - magnitude_mul.y.value().map(|y: pallas::Base| -y) + if sign == &-pallas::Base::one() { + magnitude_mul.y.value().cloned().map(|y: pallas::Base| -y) } else { - magnitude_mul.y.value() + magnitude_mul.y.value().cloned() } } else { None @@ -199,7 +199,7 @@ impl Config { if let (Some(magnitude), Some(sign)) = (scalar.magnitude.value(), scalar.sign.value()) { let magnitude_is_valid = - magnitude <= pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64); + magnitude <= &pallas::Base::from_u64(0xFFFF_FFFF_FFFF_FFFFu64); let sign_is_valid = sign * sign == pallas::Base::one(); if magnitude_is_valid && sign_is_valid { let base: super::OrchardFixedBases = base.clone().into(); @@ -211,7 +211,7 @@ impl Config { let magnitude = pallas::Scalar::from_bytes(&magnitude.to_bytes()).unwrap(); - let sign = if sign == pallas::Base::one() { + let sign = if sign == &pallas::Base::one() { pallas::Scalar::one() } else { -pallas::Scalar::one() diff --git a/src/circuit/gadget/poseidon/pow5.rs b/src/circuit/gadget/poseidon/pow5.rs index ba314fe3..fb8a9042 100644 --- a/src/circuit/gadget/poseidon/pow5.rs +++ b/src/circuit/gadget/poseidon/pow5.rs @@ -3,13 +3,13 @@ use std::iter; use halo2::{ arithmetic::FieldExt, - circuit::{Cell, Chip, Layouter, Region}, + circuit::{AssignedCell, Cell, Chip, Layouter, Region}, plonk::{Advice, Column, ConstraintSystem, Error, Expression, Fixed, Selector}, poly::Rotation, }; use super::{PoseidonDuplexInstructions, PoseidonInstructions}; -use crate::circuit::gadget::utilities::{CellValue, Var}; +use crate::circuit::gadget::utilities::Var; use crate::primitives::poseidon::{Domain, Mds, Spec, SpongeState, State}; /// Configuration for a [`Pow5Chip`]. @@ -288,10 +288,7 @@ impl, const WIDTH: usize, const RATE: usize 0, value, )?; - state.push(StateWord { - var, - value: Some(value), - }); + state.push(StateWord(var)); Ok(()) }; @@ -323,15 +320,15 @@ impl, const WIDTH: usize, const RATE: usize // Load the initial state into this region. let load_state_word = |i: usize| { - let value = initial_state[i].value; + let value = initial_state[i].0.value().cloned(); let var = region.assign_advice( || format!("load state_{}", i), config.state[i], 0, || value.ok_or(Error::Synthesis), )?; - region.constrain_equal(initial_state[i].var, var)?; - Ok(StateWord { var, value }) + region.constrain_equal(initial_state[i].0.cell(), var.cell())?; + Ok(StateWord(var)) }; let initial_state: Result, Error> = (0..WIDTH).map(load_state_word).collect(); @@ -342,7 +339,7 @@ impl, const WIDTH: usize, const RATE: usize // Load the input and padding into this region. let load_input_word = |i: usize| { let (constraint_var, value) = match (input[i].clone(), padding_values[i]) { - (Some(word), None) => (word.var, word.value), + (Some(word), None) => (word.0.cell(), word.0.value().cloned()), (None, Some(padding_value)) => { let padding_var = region.assign_fixed( || format!("load pad_{}", i), @@ -350,7 +347,7 @@ impl, const WIDTH: usize, const RATE: usize 1, || Ok(padding_value), )?; - (padding_var, Some(padding_value)) + (padding_var.cell(), Some(padding_value)) } _ => panic!("Input and padding don't match"), }; @@ -360,30 +357,31 @@ impl, const WIDTH: usize, const RATE: usize 1, || value.ok_or(Error::Synthesis), )?; - region.constrain_equal(constraint_var, var)?; + region.constrain_equal(constraint_var, var.cell())?; - Ok(StateWord { var, value }) + Ok(StateWord(var)) }; let input: Result, Error> = (0..RATE).map(load_input_word).collect(); let input = input?; // Constrain the output. let constrain_output_word = |i: usize| { - let value = initial_state[i].value.and_then(|initial_word| { + let value = initial_state[i].0.value().and_then(|initial_word| { input .get(i) - .map(|word| word.value) + .map(|word| word.0.value().cloned()) // The capacity element is never altered by the input. .unwrap_or_else(|| Some(F::zero())) - .map(|input_word| initial_word + input_word) + .map(|input_word| *initial_word + input_word) }); - let var = region.assign_advice( - || format!("load output_{}", i), - config.state[i], - 2, - || value.ok_or(Error::Synthesis), - )?; - Ok(StateWord { var, value }) + region + .assign_advice( + || format!("load output_{}", i), + config.state[i], + 2, + || value.ok_or(Error::Synthesis), + ) + .map(StateWord) }; let output: Result, Error> = (0..WIDTH).map(constrain_output_word).collect(); @@ -404,34 +402,31 @@ impl, const WIDTH: usize, const RATE: usize /// A word in the Poseidon state. #[derive(Clone, Debug)] -pub struct StateWord { - var: Cell, - value: Option, -} +pub struct StateWord(AssignedCell); -impl From> for CellValue { - fn from(state_word: StateWord) -> CellValue { - CellValue::new(state_word.var, state_word.value) +impl From> for AssignedCell { + fn from(state_word: StateWord) -> AssignedCell { + state_word.0 } } -impl From> for StateWord { - fn from(cell_value: CellValue) -> StateWord { - StateWord::new(cell_value.cell(), cell_value.value()) +impl From> for StateWord { + fn from(cell_value: AssignedCell) -> StateWord { + StateWord(cell_value) } } impl Var for StateWord { - fn new(var: Cell, value: Option) -> Self { - Self { var, value } + fn new(var: AssignedCell, value: Option) -> Self { + Self(var) } fn cell(&self) -> Cell { - self.var + self.0.cell() } fn value(&self) -> Option { - self.value + self.0.value().cloned() } } @@ -447,11 +442,11 @@ impl Pow5State { offset: usize, ) -> Result { Self::round(region, config, round, offset, config.s_full, |_| { - let q = self - .0 - .iter() - .enumerate() - .map(|(idx, word)| word.value.map(|v| v + config.round_constants[round][idx])); + let q = self.0.iter().enumerate().map(|(idx, word)| { + word.0 + .value() + .map(|v| *v + config.round_constants[round][idx]) + }); let r: Option> = q.map(|q| q.map(|q| q.pow(&config.alpha))).collect(); let m = &config.m_reg; let state = m.iter().map(|m_i| { @@ -475,7 +470,7 @@ impl Pow5State { ) -> Result { Self::round(region, config, round, offset, config.s_partial, |region| { let m = &config.m_reg; - let p: Option> = self.0.iter().map(|word| word.value).collect(); + let p: Option> = self.0.iter().map(|word| word.0.value().cloned()).collect(); let r: Option> = p.map(|p| { let r_0 = (p[0] + config.round_constants[round][0]).pow(&config.alpha); @@ -547,15 +542,15 @@ impl Pow5State { initial_state: &State, WIDTH>, ) -> Result { let load_state_word = |i: usize| { - let value = initial_state[i].value; + let value = initial_state[i].0.value().cloned(); let var = region.assign_advice( || format!("load state_{}", i), config.state[i], 0, || value.ok_or(Error::Synthesis), )?; - region.constrain_equal(initial_state[i].var, var)?; - Ok(StateWord { var, value }) + region.constrain_equal(initial_state[i].0.cell(), var.cell())?; + Ok(StateWord(var)) }; let state: Result, _> = (0..WIDTH).map(load_state_word).collect(); @@ -597,7 +592,7 @@ impl Pow5State { offset + 1, || value.ok_or(Error::Synthesis), )?; - Ok(StateWord { var, value }) + Ok(StateWord(var)) }; let next_state: Result, _> = (0..WIDTH).map(next_state_word).collect(); @@ -674,7 +669,7 @@ mod tests { 0, || value.ok_or(Error::Synthesis), )?; - Ok(StateWord { var, value }) + Ok(StateWord(var)) }; let state: Result, Error> = (0..WIDTH).map(state_word).collect(); @@ -713,7 +708,7 @@ mod tests { 0, || Ok(expected_final_state[i]), )?; - region.constrain_equal(final_state[i].var, var) + region.constrain_equal(final_state[i].0.cell(), var.cell()) }; for i in 0..(WIDTH) { @@ -821,7 +816,7 @@ mod tests { 0, || self.output.ok_or(Error::Synthesis), )?; - region.constrain_equal(output.cell(), expected_var) + region.constrain_equal(output.cell(), expected_var.cell()) }, ) } diff --git a/src/circuit/gadget/sinsemilla/chip/hash_to_point.rs b/src/circuit/gadget/sinsemilla/chip/hash_to_point.rs index 60b8f23c..8c8aff5b 100644 --- a/src/circuit/gadget/sinsemilla/chip/hash_to_point.rs +++ b/src/circuit/gadget/sinsemilla/chip/hash_to_point.rs @@ -142,7 +142,7 @@ impl SinsemillaChip { .chunks(K) .fold(Q.to_curve(), |acc, chunk| (acc + S(chunk)) + acc); let actual_point = - pallas::Affine::from_xy(x_a.value().unwrap(), y_a.value().unwrap()).unwrap(); + pallas::Affine::from_xy(*x_a.value().unwrap(), *y_a.value().unwrap()).unwrap(); assert_eq!(expected_point.to_affine(), actual_point); } } @@ -270,7 +270,7 @@ impl SinsemillaChip { offset, || piece.field_elem().ok_or(Error::Synthesis), )?; - region.constrain_equal(piece.cell(), cell)?; + region.constrain_equal(piece.cell(), cell.cell())?; zs.push(CellValue::new(cell, piece.field_elem())); // Assign cumulative sum such that for 0 <= i < n, diff --git a/src/circuit/gadget/sinsemilla/commit_ivk.rs b/src/circuit/gadget/sinsemilla/commit_ivk.rs index f089472e..a10f8c0c 100644 --- a/src/circuit/gadget/sinsemilla/commit_ivk.rs +++ b/src/circuit/gadget/sinsemilla/commit_ivk.rs @@ -8,7 +8,7 @@ use pasta_curves::{arithmetic::FieldExt, pallas}; use crate::{ circuit::gadget::{ ecc::{chip::EccChip, X}, - utilities::{bitrange_subset, bool_check, copy, CellValue, Var}, + utilities::{bitrange_subset, bool_check, copy, CellValue}, }, constants::T_P, }; @@ -641,7 +641,7 @@ mod tests { ecc::chip::{EccChip, EccConfig}, sinsemilla::chip::SinsemillaChip, utilities::{ - lookup_range_check::LookupRangeCheckConfig, CellValue, UtilitiesInstructions, Var, + lookup_range_check::LookupRangeCheckConfig, CellValue, UtilitiesInstructions, }, }, constants::{COMMIT_IVK_PERSONALIZATION, L_ORCHARD_BASE, T_Q}, @@ -803,7 +803,7 @@ mod tests { .unwrap() }; - assert_eq!(expected_ivk, ivk.inner().value().unwrap()); + assert_eq!(&expected_ivk, ivk.inner().value().unwrap()); Ok(()) } diff --git a/src/circuit/gadget/sinsemilla/merkle.rs b/src/circuit/gadget/sinsemilla/merkle.rs index a5d24d48..c28f83e4 100644 --- a/src/circuit/gadget/sinsemilla/merkle.rs +++ b/src/circuit/gadget/sinsemilla/merkle.rs @@ -139,7 +139,7 @@ pub mod tests { use crate::{ circuit::gadget::{ sinsemilla::chip::{SinsemillaChip, SinsemillaHashDomains}, - utilities::{lookup_range_check::LookupRangeCheckConfig, UtilitiesInstructions, Var}, + utilities::{lookup_range_check::LookupRangeCheckConfig, UtilitiesInstructions}, }, constants::MERKLE_DEPTH_ORCHARD, note::commitment::ExtractedNoteCommitment, @@ -266,7 +266,7 @@ pub mod tests { }; // Check the computed final root against the expected final root. - assert_eq!(computed_final_root.value().unwrap(), final_root.inner()); + assert_eq!(computed_final_root.value().unwrap(), &final_root.inner()); } Ok(()) diff --git a/src/circuit/gadget/sinsemilla/merkle/chip.rs b/src/circuit/gadget/sinsemilla/merkle/chip.rs index 46107c86..310a7ae2 100644 --- a/src/circuit/gadget/sinsemilla/merkle/chip.rs +++ b/src/circuit/gadget/sinsemilla/merkle/chip.rs @@ -15,7 +15,7 @@ use crate::{ circuit::gadget::utilities::{ bitrange_subset, cond_swap::{CondSwapChip, CondSwapConfig, CondSwapInstructions}, - copy, CellValue, UtilitiesInstructions, Var, + copy, CellValue, UtilitiesInstructions, }, constants::{L_ORCHARD_BASE, MERKLE_DEPTH_ORCHARD}, primitives::sinsemilla, @@ -185,7 +185,7 @@ impl MerkleInstructions { } impl MessagePiece { - pub fn new(cell: Cell, field_elem: Option, num_words: usize) -> Self { + pub fn new(cell: CellValue, field_elem: Option, num_words: usize) -> Self { assert!(num_words * K < F::NUM_BITS as usize); let cell_value = CellValue::new(cell, field_elem); Self { @@ -58,7 +58,7 @@ impl MessagePiece { } pub fn field_elem(&self) -> Option { - self.cell_value.value() + self.cell_value.value().cloned() } pub fn cell_value(&self) -> CellValue { diff --git a/src/circuit/gadget/sinsemilla/note_commit.rs b/src/circuit/gadget/sinsemilla/note_commit.rs index 6ad2e821..c50b9fdd 100644 --- a/src/circuit/gadget/sinsemilla/note_commit.rs +++ b/src/circuit/gadget/sinsemilla/note_commit.rs @@ -530,8 +530,10 @@ impl NoteCommitConfig { psi: CellValue, rcm: Option, ) -> Result, Error> { - let (gd_x, gd_y) = (g_d.x().value(), g_d.y().value()); - let (pkd_x, pkd_y) = (pk_d.x().value(), pk_d.y().value()); + let (gd_x, gd_y) = (g_d.x(), g_d.y()); + let (pkd_x, pkd_y) = (pk_d.x(), pk_d.y()); + let (gd_x, gd_y) = (gd_x.value(), gd_y.value()); + let (pkd_x, pkd_y) = (pkd_x.value(), pkd_y.value()); let value_val = value.value(); let rho_val = rho.value(); let psi_val = psi.value(); diff --git a/src/circuit/gadget/utilities.rs b/src/circuit/gadget/utilities.rs index 3bfed1b9..12997ce1 100644 --- a/src/circuit/gadget/utilities.rs +++ b/src/circuit/gadget/utilities.rs @@ -2,7 +2,7 @@ use ff::PrimeFieldBits; use halo2::{ - circuit::{Cell, Layouter, Region}, + circuit::{AssignedCell, Cell, Layouter, Region}, plonk::{Advice, Column, Error, Expression}, }; use pasta_curves::arithmetic::FieldExt; @@ -13,16 +13,12 @@ pub(crate) mod decompose_running_sum; pub(crate) mod lookup_range_check; /// A variable representing a field element. -#[derive(Clone, Debug)] -pub struct CellValue { - cell: Cell, - value: Option, -} +pub type CellValue = AssignedCell; /// Trait for a variable in the circuit. pub trait Var: Clone + std::fmt::Debug { /// Construct a new variable. - fn new(cell: Cell, value: Option) -> Self; + fn new(cell: AssignedCell, value: Option) -> Self; /// The cell at which this variable was allocated. fn cell(&self) -> Cell; @@ -32,16 +28,16 @@ pub trait Var: Clone + std::fmt::Debug { } impl Var for CellValue { - fn new(cell: Cell, value: Option) -> Self { - Self { cell, value } + fn new(cell: AssignedCell, _value: Option) -> Self { + cell } fn cell(&self) -> Cell { - self.cell + self.cell() } fn value(&self) -> Option { - self.value + self.value().cloned() } } @@ -90,13 +86,9 @@ where A: Fn() -> AR, AR: Into, { - let cell = region.assign_advice(annotation, column, offset, || { - copy.value.ok_or(Error::Synthesis) - })?; - - region.constrain_equal(cell, copy.cell)?; - - Ok(CellValue::new(cell, copy.value)) + // Temporarily implement `copy()` in terms of `AssignedCell::copy_advice`. + // We will remove this in a subsequent commit. + copy.copy_advice(annotation, region, column, offset) } pub(crate) fn transpose_option_array( @@ -126,7 +118,7 @@ pub fn ternary(a: Expression, b: Expression, c: Expression /// Takes a specified subsequence of the little-endian bit representation of a field element. /// The bits are numbered from 0 for the LSB. -pub fn bitrange_subset(field_elem: F, bitrange: Range) -> F { +pub fn bitrange_subset(field_elem: &F, bitrange: Range) -> F { assert!(bitrange.end <= F::NUM_BITS as usize); let bits: Vec = field_elem @@ -251,7 +243,7 @@ mod tests { { let field_elem = pallas::Base::rand(); let bitrange = 0..(pallas::Base::NUM_BITS as usize); - let subset = bitrange_subset(field_elem, bitrange); + let subset = bitrange_subset(&field_elem, bitrange); assert_eq!(field_elem, subset); } @@ -259,7 +251,7 @@ mod tests { { let field_elem = pallas::Base::rand(); let bitrange = 0..0; - let subset = bitrange_subset(field_elem, bitrange); + let subset = bitrange_subset(&field_elem, bitrange); assert_eq!(pallas::Base::zero(), subset); } @@ -286,7 +278,7 @@ mod tests { let subsets = ranges .iter() - .map(|range| bitrange_subset(field_elem, range.clone())) + .map(|range| bitrange_subset(&field_elem, range.clone())) .collect::>(); let mut sum = subsets[0]; diff --git a/src/circuit/gadget/utilities/cond_swap.rs b/src/circuit/gadget/utilities/cond_swap.rs index 365ba500..e979048e 100644 --- a/src/circuit/gadget/utilities/cond_swap.rs +++ b/src/circuit/gadget/utilities/cond_swap.rs @@ -98,39 +98,33 @@ impl CondSwapInstructions for CondSwapChip { // Conditionally swap a let a_swapped = { let a_swapped = a - .value - .zip(b.value) + .value() + .zip(b.value()) .zip(swap) - .map(|((a, b), swap)| if swap { b } else { a }); - let a_swapped_cell = region.assign_advice( + .map(|((a, b), swap)| if swap { b } else { a }) + .cloned(); + region.assign_advice( || "a_swapped", config.a_swapped, 0, || a_swapped.ok_or(Error::Synthesis), - )?; - CellValue { - cell: a_swapped_cell, - value: a_swapped, - } + )? }; // Conditionally swap b let b_swapped = { let b_swapped = a - .value - .zip(b.value) + .value() + .zip(b.value()) .zip(swap) - .map(|((a, b), swap)| if swap { a } else { b }); - let b_swapped_cell = region.assign_advice( + .map(|((a, b), swap)| if swap { a } else { b }) + .cloned(); + region.assign_advice( || "b_swapped", config.b_swapped, 0, || b_swapped.ok_or(Error::Synthesis), - )?; - CellValue { - cell: b_swapped_cell, - value: b_swapped, - } + )? }; // Return swapped pair @@ -261,12 +255,12 @@ mod tests { if let Some(swap) = self.swap { if swap { // Check that `a` and `b` have been swapped - assert_eq!(swapped_pair.0.value.unwrap(), self.b.unwrap()); - assert_eq!(swapped_pair.1.value.unwrap(), a.value.unwrap()); + assert_eq!(swapped_pair.0.value().unwrap(), &self.b.unwrap()); + assert_eq!(swapped_pair.1.value().unwrap(), a.value().unwrap()); } else { // Check that `a` and `b` have not been swapped - assert_eq!(swapped_pair.0.value.unwrap(), a.value.unwrap()); - assert_eq!(swapped_pair.1.value.unwrap(), self.b.unwrap()); + assert_eq!(swapped_pair.0.value().unwrap(), a.value().unwrap()); + assert_eq!(swapped_pair.1.value().unwrap(), &self.b.unwrap()); } } diff --git a/src/circuit/gadget/utilities/decompose_running_sum.rs b/src/circuit/gadget/utilities/decompose_running_sum.rs index 43e4c9b9..6b2da383 100644 --- a/src/circuit/gadget/utilities/decompose_running_sum.rs +++ b/src/circuit/gadget/utilities/decompose_running_sum.rs @@ -193,7 +193,7 @@ impl let z_next_val = z .value() .zip(word) - .map(|(z_cur_val, word)| (z_cur_val - word) * two_pow_k_inv); + .map(|(z_cur_val, word)| (*z_cur_val - word) * two_pow_k_inv); let cell = region.assign_advice( || format!("z_{:?}", i + 1), self.z, diff --git a/src/circuit/gadget/utilities/lookup_range_check.rs b/src/circuit/gadget/utilities/lookup_range_check.rs index 65ef4b59..0f535721 100644 --- a/src/circuit/gadget/utilities/lookup_range_check.rs +++ b/src/circuit/gadget/utilities/lookup_range_check.rs @@ -244,7 +244,7 @@ impl LookupRangeCheckConfig let z_val = z .value() .zip(*word) - .map(|(z, word)| (z - word) * inv_two_pow_k); + .map(|(z, word)| (*z - word) * inv_two_pow_k); // Assign z_next let z_cell = region.assign_advice( @@ -344,7 +344,7 @@ impl LookupRangeCheckConfig // Assign shifted `element * 2^{K - num_bits}` let shifted = element.value().map(|element| { let shift = F::from_u64(1 << (K - num_bits)); - element * shift + *element * shift }); region.assign_advice( @@ -369,7 +369,6 @@ impl LookupRangeCheckConfig #[cfg(test)] mod tests { - use super::super::Var; use super::LookupRangeCheckConfig; use crate::primitives::sinsemilla::{INV_TWO_POW_K, K}; @@ -468,7 +467,7 @@ mod tests { for (expected_z, z) in expected_zs.into_iter().zip(zs.iter()) { if let Some(z) = z.value() { - assert_eq!(expected_z, z); + assert_eq!(&expected_z, z); } } } diff --git a/src/constants/util.rs b/src/constants/util.rs index bdb84901..b325091c 100644 --- a/src/constants/util.rs +++ b/src/constants/util.rs @@ -10,7 +10,7 @@ use halo2::arithmetic::{CurveAffine, FieldExt}; /// We are returning a `Vec` which means the window size is limited to /// <= 8 bits. pub fn decompose_word( - word: F, + word: &F, word_num_bits: usize, window_num_bits: usize, ) -> Vec { @@ -86,7 +86,7 @@ mod tests { window_num_bits in 1u8..9 ) { // Get decomposition into `window_num_bits` bits - let decomposed = decompose_word(scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize); + let decomposed = decompose_word(&scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize); // Flatten bits let bits = decomposed