diff --git a/zk-token-sdk/src/encryption/discrete_log.rs b/zk-token-sdk/src/encryption/discrete_log.rs index 5ffc1c206..55b9c82ad 100644 --- a/zk-token-sdk/src/encryption/discrete_log.rs +++ b/zk-token-sdk/src/encryption/discrete_log.rs @@ -28,7 +28,7 @@ use { }, itertools::Itertools, serde::{Deserialize, Serialize}, - std::collections::HashMap, + std::{collections::HashMap, num::NonZeroUsize}, thiserror::Error, }; @@ -57,14 +57,14 @@ pub struct DiscreteLog { /// Target point for discrete log pub target: RistrettoPoint, /// Number of threads used for discrete log computation - num_threads: usize, + num_threads: Option, /// Range bound for discrete log search derived from the max value to search for and /// `num_threads` - range_bound: usize, + range_bound: NonZeroUsize, /// Ristretto point representing each step of the discrete log search step_point: RistrettoPoint, /// Ristretto point compression batch size - compression_batch_size: usize, + compression_batch_size: NonZeroUsize, } #[derive(Serialize, Deserialize, Default)] @@ -107,24 +107,27 @@ impl DiscreteLog { Self { generator, target, - num_threads: 1, - range_bound: TWO16 as usize, + num_threads: None, + range_bound: (TWO16 as usize).try_into().unwrap(), step_point: G, - compression_batch_size: 32, + compression_batch_size: 32.try_into().unwrap(), } } /// Adjusts number of threads in a discrete log instance. #[cfg(not(target_arch = "wasm32"))] - pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> { + pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> { // number of threads must be a positive power-of-two integer - if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > MAX_THREAD { + if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD { return Err(DiscreteLogError::DiscreteLogThreads); } - self.num_threads = num_threads; - self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap(); - self.step_point = Scalar::from(num_threads as u64) * G; + self.num_threads = Some(num_threads); + self.range_bound = (TWO16 as usize) + .checked_div(num_threads.get()) + .and_then(|range_bound| range_bound.try_into().ok()) + .unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero + self.step_point = Scalar::from(num_threads.get() as u64) * G; Ok(()) } @@ -132,9 +135,9 @@ impl DiscreteLog { /// Adjusts inversion batch size in a discrete log instance. pub fn set_compression_batch_size( &mut self, - compression_batch_size: usize, + compression_batch_size: NonZeroUsize, ) -> Result<(), DiscreteLogError> { - if compression_batch_size >= TWO16 as usize || compression_batch_size == 0 { + if compression_batch_size.get() >= TWO16 as usize { return Err(DiscreteLogError::DiscreteLogBatchSize); } self.compression_batch_size = compression_batch_size; @@ -145,41 +148,41 @@ impl DiscreteLog { /// Solves the discrete log problem under the assumption that the solution /// is a positive 32-bit number. pub fn decode_u32(self) -> Option { - #[cfg(not(target_arch = "wasm32"))] - { - let mut starting_point = self.target; - let handles = (0..self.num_threads) - .map(|i| { - let ristretto_iterator = RistrettoIterator::new( - (starting_point, i as u64), - (-(&self.step_point), self.num_threads as u64), - ); + if let Some(num_threads) = self.num_threads { + #[cfg(not(target_arch = "wasm32"))] + { + let mut starting_point = self.target; + let handles = (0..num_threads.get()) + .map(|i| { + let ristretto_iterator = RistrettoIterator::new( + (starting_point, i as u64), + (-(&self.step_point), num_threads.get() as u64), + ); - let handle = thread::spawn(move || { - Self::decode_range( - ristretto_iterator, - self.range_bound, - self.compression_batch_size, - ) - }); + let handle = thread::spawn(move || { + Self::decode_range( + ristretto_iterator, + self.range_bound, + self.compression_batch_size, + ) + }); - starting_point -= G; - handle - }) - .collect::>(); + starting_point -= G; + handle + }) + .collect::>(); - handles - .into_iter() - .map_while(|h| h.join().ok()) - .find(|x| x.is_some()) - .flatten() - } - #[cfg(target_arch = "wasm32")] - { - let ristretto_iterator = RistrettoIterator::new( - (self.target, 0_u64), - (-(&self.step_point), self.num_threads as u64), - ); + handles + .into_iter() + .map_while(|h| h.join().ok()) + .find(|x| x.is_some()) + .flatten() + } + #[cfg(target_arch = "wasm32")] + unreachable!() // `self.num_threads` always `None` on wasm target + } else { + let ristretto_iterator = + RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64)); Self::decode_range( ristretto_iterator, @@ -191,15 +194,15 @@ impl DiscreteLog { fn decode_range( ristretto_iterator: RistrettoIterator, - range_bound: usize, - compression_batch_size: usize, + range_bound: NonZeroUsize, + compression_batch_size: NonZeroUsize, ) -> Option { let hashmap = &DECODE_PRECOMPUTATION_FOR_G; let mut decoded = None; for batch in &ristretto_iterator - .take(range_bound) - .chunks(compression_batch_size) + .take(range_bound.get()) + .chunks(compression_batch_size.get()) { // batch compression currently errors if any point in the batch is the identity point let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch @@ -298,7 +301,7 @@ mod tests { let amount: u64 = 55; let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G); - instance.num_threads(4).unwrap(); + instance.num_threads(4.try_into().unwrap()).unwrap(); // Very informal measurements for now let start_computation = Instant::now(); diff --git a/zk-token-sdk/src/encryption/elgamal.rs b/zk-token-sdk/src/encryption/elgamal.rs index e499106e1..0bc9eb051 100644 --- a/zk-token-sdk/src/encryption/elgamal.rs +++ b/zk-token-sdk/src/encryption/elgamal.rs @@ -799,7 +799,7 @@ mod tests { let ciphertext = ElGamal::encrypt(&public, amount); let mut instance = ElGamal::decrypt(&secret, &ciphertext); - instance.num_threads(4).unwrap(); + instance.num_threads(4.try_into().unwrap()).unwrap(); assert_eq!(57_u64, instance.decode_u32().unwrap()); }