From f2d1f1d56a0b5d98dbfe34172828f96112df6a96 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Fri, 3 Dec 2021 03:13:12 +0000 Subject: [PATCH] sha256: Add `InitialRound` and `MainRoundIdx` structs This enables the runtime `assert!(matches!(..))` on `RoundIdx` to be replaced by type system checks. --- .../table16/compression/compression_util.rs | 129 +++++++++--------- .../table16/compression/subregion_initial.rs | 24 ++-- .../table16/compression/subregion_main.rs | 12 +- 3 files changed, 81 insertions(+), 84 deletions(-) diff --git a/halo2_gadgets/src/sha256/table16/compression/compression_util.rs b/halo2_gadgets/src/sha256/table16/compression/compression_util.rs index 42f758cd..afa9c78e 100644 --- a/halo2_gadgets/src/sha256/table16/compression/compression_util.rs +++ b/halo2_gadgets/src/sha256/table16/compression/compression_util.rs @@ -39,70 +39,77 @@ pub const SUBREGION_MAIN_WORD: usize = DECOMPOSE_ABCD + SIGMA_0_ROWS + DECOMPOSE_EFGH + SIGMA_1_ROWS + CH_ROWS + MAJ_ROWS; pub const SUBREGION_MAIN_ROWS: usize = SUBREGION_MAIN_LEN * SUBREGION_MAIN_WORD; +/// The initial round. +pub struct InitialRound; + +/// A main round index. +#[derive(Debug, Copy, Clone)] +pub struct MainRoundIdx(usize); + /// Round index. #[derive(Debug, Copy, Clone)] pub enum RoundIdx { Init, - Main(usize), + Main(MainRoundIdx), } -impl RoundIdx { +impl From for RoundIdx { + fn from(_: InitialRound) -> Self { + RoundIdx::Init + } +} + +impl From for RoundIdx { + fn from(idx: MainRoundIdx) -> Self { + RoundIdx::Main(idx) + } +} + +impl MainRoundIdx { pub(crate) fn as_usize(&self) -> usize { - match self { - Self::Main(idx) => *idx, - _ => panic!(), - } + self.0 } } -impl From for RoundIdx { +impl From for MainRoundIdx { fn from(idx: usize) -> Self { - Self::Main(idx) + MainRoundIdx(idx) } } -impl std::ops::Add for RoundIdx { +impl std::ops::Add for MainRoundIdx { type Output = Self; fn add(self, rhs: usize) -> Self::Output { - match self { - Self::Main(idx) => Self::Main(idx + rhs), - _ => panic!(), - } + MainRoundIdx(self.0 + rhs) } } -impl Ord for RoundIdx { +impl Ord for MainRoundIdx { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match (self, other) { - (Self::Main(idx_0), Self::Main(idx_1)) => idx_0.cmp(idx_1), - _ => panic!(), - } + self.0.cmp(&other.0) } } -impl PartialOrd for RoundIdx { +impl PartialOrd for MainRoundIdx { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl PartialEq for RoundIdx { +impl PartialEq for MainRoundIdx { fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Main(idx_0), Self::Main(idx_1)) => idx_0 == idx_1, - _ => panic!(), - } + self.0 == other.0 } } -impl Eq for RoundIdx {} +impl Eq for MainRoundIdx {} /// Returns starting row number of a compression round pub fn get_round_row(round_idx: RoundIdx) -> usize { match round_idx { RoundIdx::Init => 0, - RoundIdx::Main(idx) => { + RoundIdx::Main(MainRoundIdx(idx)) => { assert!(idx < 64); (idx as usize) * SUBREGION_MAIN_WORD } @@ -113,81 +120,73 @@ pub fn get_decompose_e_row(round_idx: RoundIdx) -> usize { get_round_row(round_idx) } -pub fn get_decompose_f_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Init)); - get_decompose_e_row(round_idx) + DECOMPOSE_EFGH +pub fn get_decompose_f_row(round_idx: InitialRound) -> usize { + get_decompose_e_row(round_idx.into()) + DECOMPOSE_EFGH } -pub fn get_decompose_g_row(round_idx: RoundIdx) -> usize { +pub fn get_decompose_g_row(round_idx: InitialRound) -> usize { get_decompose_f_row(round_idx) + DECOMPOSE_EFGH } -pub fn get_upper_sigma_1_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Main(_))); - get_decompose_e_row(round_idx) + DECOMPOSE_EFGH + 1 +pub fn get_upper_sigma_1_row(round_idx: MainRoundIdx) -> usize { + get_decompose_e_row(round_idx.into()) + DECOMPOSE_EFGH + 1 } -pub fn get_ch_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Main(_))); - get_decompose_e_row(round_idx) + DECOMPOSE_EFGH + SIGMA_1_ROWS + 1 +pub fn get_ch_row(round_idx: MainRoundIdx) -> usize { + get_decompose_e_row(round_idx.into()) + DECOMPOSE_EFGH + SIGMA_1_ROWS + 1 } -pub fn get_ch_neg_row(round_idx: RoundIdx) -> usize { +pub fn get_ch_neg_row(round_idx: MainRoundIdx) -> usize { get_ch_row(round_idx) + CH_ROWS / 2 } pub fn get_decompose_a_row(round_idx: RoundIdx) -> usize { match round_idx { RoundIdx::Init => get_h_row(round_idx) + DECOMPOSE_EFGH, - _ => get_ch_neg_row(round_idx) - 1 + CH_ROWS / 2, + RoundIdx::Main(mri) => get_ch_neg_row(mri) - 1 + CH_ROWS / 2, } } -pub fn get_upper_sigma_0_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Main(_))); - get_decompose_a_row(round_idx) + DECOMPOSE_ABCD + 1 +pub fn get_upper_sigma_0_row(round_idx: MainRoundIdx) -> usize { + get_decompose_a_row(round_idx.into()) + DECOMPOSE_ABCD + 1 } -pub fn get_decompose_b_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Init)); - get_decompose_a_row(round_idx) + DECOMPOSE_ABCD +pub fn get_decompose_b_row(round_idx: InitialRound) -> usize { + get_decompose_a_row(round_idx.into()) + DECOMPOSE_ABCD } -pub fn get_decompose_c_row(round_idx: RoundIdx) -> usize { +pub fn get_decompose_c_row(round_idx: InitialRound) -> usize { get_decompose_b_row(round_idx) + DECOMPOSE_ABCD } -pub fn get_maj_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Main(_))); +pub fn get_maj_row(round_idx: MainRoundIdx) -> usize { get_upper_sigma_0_row(round_idx) + SIGMA_0_ROWS } // Get state word rows pub fn get_h_row(round_idx: RoundIdx) -> usize { match round_idx { - RoundIdx::Init => get_decompose_g_row(round_idx) + DECOMPOSE_EFGH, - _ => get_ch_row(round_idx) - 1, + RoundIdx::Init => get_decompose_g_row(InitialRound) + DECOMPOSE_EFGH, + RoundIdx::Main(mri) => get_ch_row(mri) - 1, } } -pub fn get_h_prime_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Main(_))); +pub fn get_h_prime_row(round_idx: MainRoundIdx) -> usize { get_ch_row(round_idx) } pub fn get_d_row(round_idx: RoundIdx) -> usize { match round_idx { - RoundIdx::Init => get_decompose_c_row(round_idx) + DECOMPOSE_ABCD, - _ => get_ch_row(round_idx) + 2, + RoundIdx::Init => get_decompose_c_row(InitialRound) + DECOMPOSE_ABCD, + RoundIdx::Main(mri) => get_ch_row(mri) + 2, } } -pub fn get_e_new_row(round_idx: RoundIdx) -> usize { - assert!(matches!(round_idx, RoundIdx::Main(_))); - get_d_row(round_idx) +pub fn get_e_new_row(round_idx: MainRoundIdx) -> usize { + get_d_row(round_idx.into()) } -pub fn get_a_new_row(round_idx: RoundIdx) -> usize { +pub fn get_a_new_row(round_idx: MainRoundIdx) -> usize { get_maj_row(round_idx) } @@ -371,7 +370,7 @@ impl CompressionConfig { pub(super) fn assign_upper_sigma_0( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, word: AbcdVar, ) -> Result<(AssignedBits<16>, AssignedBits<16>), Error> { // Rename these here for ease of matching the gates to the specification. @@ -429,7 +428,7 @@ impl CompressionConfig { pub(super) fn assign_upper_sigma_1( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, word: EfghVar, ) -> Result<(AssignedBits<16>, AssignedBits<16>), Error> { // Rename these here for ease of matching the gates to the specification. @@ -513,7 +512,7 @@ impl CompressionConfig { pub(super) fn assign_ch( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, spread_halves_e: RoundWordSpread, spread_halves_f: RoundWordSpread, ) -> Result<(AssignedBits<16>, AssignedBits<16>), Error> { @@ -559,7 +558,7 @@ impl CompressionConfig { pub(super) fn assign_ch_neg( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, spread_halves_e: RoundWordSpread, spread_halves_g: RoundWordSpread, ) -> Result<(AssignedBits<16>, AssignedBits<16>), Error> { @@ -662,7 +661,7 @@ impl CompressionConfig { pub(super) fn assign_maj( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, spread_halves_a: RoundWordSpread, spread_halves_b: RoundWordSpread, spread_halves_c: RoundWordSpread, @@ -720,7 +719,7 @@ impl CompressionConfig { pub(super) fn assign_h_prime( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, h: RoundWordDense, ch: (AssignedBits<16>, AssignedBits<16>), ch_neg: (AssignedBits<16>, AssignedBits<16>), @@ -805,7 +804,7 @@ impl CompressionConfig { pub(super) fn assign_e_new( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, d: &RoundWordDense, h_prime: &RoundWordDense, ) -> Result { @@ -842,7 +841,7 @@ impl CompressionConfig { pub(super) fn assign_a_new( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, maj: (AssignedBits<16>, AssignedBits<16>), sigma_0: (AssignedBits<16>, AssignedBits<16>), h_prime: RoundWordDense, diff --git a/halo2_gadgets/src/sha256/table16/compression/subregion_initial.rs b/halo2_gadgets/src/sha256/table16/compression/subregion_initial.rs index 367c5ad7..3a15e318 100644 --- a/halo2_gadgets/src/sha256/table16/compression/subregion_initial.rs +++ b/halo2_gadgets/src/sha256/table16/compression/subregion_initial.rs @@ -15,8 +15,8 @@ impl CompressionConfig { let e = self.decompose_e(region, RoundIdx::Init, Some(iv[4]))?; // Decompose F, G - let f = self.decompose_f(region, RoundIdx::Init, Some(iv[5]))?; - let g = self.decompose_g(region, RoundIdx::Init, Some(iv[6]))?; + let f = self.decompose_f(region, InitialRound, Some(iv[5]))?; + let g = self.decompose_g(region, InitialRound, Some(iv[6]))?; // Assign H let h_row = get_h_row(RoundIdx::Init); @@ -26,8 +26,8 @@ impl CompressionConfig { let a = self.decompose_a(region, RoundIdx::Init, Some(iv[0]))?; // Decompose B, C - let b = self.decompose_b(region, RoundIdx::Init, Some(iv[1]))?; - let c = self.decompose_c(region, RoundIdx::Init, Some(iv[2]))?; + let b = self.decompose_b(region, InitialRound, Some(iv[1]))?; + let c = self.decompose_c(region, InitialRound, Some(iv[2]))?; // Assign D let d_row = get_d_row(RoundIdx::Init); @@ -60,9 +60,9 @@ impl CompressionConfig { // Decompose F, G let f = f.dense_halves.value(); - let f = self.decompose_f(region, RoundIdx::Init, f)?; + let f = self.decompose_f(region, InitialRound, f)?; let g = g.dense_halves.value(); - let g = self.decompose_g(region, RoundIdx::Init, g)?; + let g = self.decompose_g(region, InitialRound, g)?; // Assign H let h = h.value(); @@ -75,9 +75,9 @@ impl CompressionConfig { // Decompose B, C let b = b.dense_halves.value(); - let b = self.decompose_b(region, RoundIdx::Init, b)?; + let b = self.decompose_b(region, InitialRound, b)?; let c = c.dense_halves.value(); - let c = self.decompose_c(region, RoundIdx::Init, c)?; + let c = self.decompose_c(region, InitialRound, c)?; // Assign D let d = d.value(); @@ -99,7 +99,7 @@ impl CompressionConfig { fn decompose_b( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: InitialRound, b_val: Option, ) -> Result { let row = get_decompose_b_row(round_idx); @@ -112,7 +112,7 @@ impl CompressionConfig { fn decompose_c( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: InitialRound, c_val: Option, ) -> Result { let row = get_decompose_c_row(round_idx); @@ -125,7 +125,7 @@ impl CompressionConfig { fn decompose_f( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: InitialRound, f_val: Option, ) -> Result { let row = get_decompose_f_row(round_idx); @@ -138,7 +138,7 @@ impl CompressionConfig { fn decompose_g( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: InitialRound, g_val: Option, ) -> Result { let row = get_decompose_g_row(round_idx); diff --git a/halo2_gadgets/src/sha256/table16/compression/subregion_main.rs b/halo2_gadgets/src/sha256/table16/compression/subregion_main.rs index a9f4a224..2a0f433c 100644 --- a/halo2_gadgets/src/sha256/table16/compression/subregion_main.rs +++ b/halo2_gadgets/src/sha256/table16/compression/subregion_main.rs @@ -7,12 +7,10 @@ impl CompressionConfig { pub fn assign_round( &self, region: &mut Region<'_, pallas::Base>, - round_idx: RoundIdx, + round_idx: MainRoundIdx, state: State, schedule_word: &(AssignedBits<16>, AssignedBits<16>), ) -> Result { - assert!(matches!(round_idx, RoundIdx::Main(_))); - let a_3 = self.extras[0]; let a_4 = self.extras[1]; let a_7 = self.extras[3]; @@ -70,7 +68,7 @@ impl CompressionConfig { if round_idx < 63.into() { // Assign and copy A_new - let a_new_row = get_decompose_a_row(round_idx + 1); + let a_new_row = get_decompose_a_row((round_idx + 1).into()); a_new_dense .0 .copy_advice(|| "a_new_lo", region, a_7, a_new_row)?; @@ -79,7 +77,7 @@ impl CompressionConfig { .copy_advice(|| "a_new_hi", region, a_7, a_new_row + 1)?; // Assign and copy E_new - let e_new_row = get_decompose_e_row(round_idx + 1); + let e_new_row = get_decompose_e_row((round_idx + 1).into()); e_new_dense .0 .copy_advice(|| "e_new_lo", region, a_7, e_new_row)?; @@ -88,10 +86,10 @@ impl CompressionConfig { .copy_advice(|| "e_new_hi", region, a_7, e_new_row + 1)?; // Decompose A into (2, 11, 9, 10)-bit chunks - let a_new = self.decompose_a(region, round_idx + 1, a_new_val)?; + let a_new = self.decompose_a(region, (round_idx + 1).into(), a_new_val)?; // Decompose E into (6, 5, 14, 7)-bit chunks - let e_new = self.decompose_e(region, round_idx + 1, e_new_val)?; + let e_new = self.decompose_e(region, (round_idx + 1).into(), e_new_val)?; Ok(State::new( StateWord::A(a_new),