[zk-token-sdk] Allow discrete log to be executed in the current thread (#443)

This commit is contained in:
samkim-crypto 2024-03-30 06:37:43 +09:00 committed by GitHub
parent c5b9196df7
commit fb1ee7842f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 55 additions and 52 deletions

View File

@ -28,7 +28,7 @@ use {
}, },
itertools::Itertools, itertools::Itertools,
serde::{Deserialize, Serialize}, serde::{Deserialize, Serialize},
std::collections::HashMap, std::{collections::HashMap, num::NonZeroUsize},
thiserror::Error, thiserror::Error,
}; };
@ -57,14 +57,14 @@ pub struct DiscreteLog {
/// Target point for discrete log /// Target point for discrete log
pub target: RistrettoPoint, pub target: RistrettoPoint,
/// Number of threads used for discrete log computation /// Number of threads used for discrete log computation
num_threads: usize, num_threads: Option<NonZeroUsize>,
/// Range bound for discrete log search derived from the max value to search for and /// Range bound for discrete log search derived from the max value to search for and
/// `num_threads` /// `num_threads`
range_bound: usize, range_bound: NonZeroUsize,
/// Ristretto point representing each step of the discrete log search /// Ristretto point representing each step of the discrete log search
step_point: RistrettoPoint, step_point: RistrettoPoint,
/// Ristretto point compression batch size /// Ristretto point compression batch size
compression_batch_size: usize, compression_batch_size: NonZeroUsize,
} }
#[derive(Serialize, Deserialize, Default)] #[derive(Serialize, Deserialize, Default)]
@ -107,24 +107,27 @@ impl DiscreteLog {
Self { Self {
generator, generator,
target, target,
num_threads: 1, num_threads: None,
range_bound: TWO16 as usize, range_bound: (TWO16 as usize).try_into().unwrap(),
step_point: G, step_point: G,
compression_batch_size: 32, compression_batch_size: 32.try_into().unwrap(),
} }
} }
/// Adjusts number of threads in a discrete log instance. /// Adjusts number of threads in a discrete log instance.
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> { pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
// number of threads must be a positive power-of-two integer // number of threads must be a positive power-of-two integer
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > MAX_THREAD { if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
return Err(DiscreteLogError::DiscreteLogThreads); return Err(DiscreteLogError::DiscreteLogThreads);
} }
self.num_threads = num_threads; self.num_threads = Some(num_threads);
self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap(); self.range_bound = (TWO16 as usize)
self.step_point = Scalar::from(num_threads as u64) * G; .checked_div(num_threads.get())
.and_then(|range_bound| range_bound.try_into().ok())
.unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
self.step_point = Scalar::from(num_threads.get() as u64) * G;
Ok(()) Ok(())
} }
@ -132,9 +135,9 @@ impl DiscreteLog {
/// Adjusts inversion batch size in a discrete log instance. /// Adjusts inversion batch size in a discrete log instance.
pub fn set_compression_batch_size( pub fn set_compression_batch_size(
&mut self, &mut self,
compression_batch_size: usize, compression_batch_size: NonZeroUsize,
) -> Result<(), DiscreteLogError> { ) -> Result<(), DiscreteLogError> {
if compression_batch_size >= TWO16 as usize || compression_batch_size == 0 { if compression_batch_size.get() >= TWO16 as usize {
return Err(DiscreteLogError::DiscreteLogBatchSize); return Err(DiscreteLogError::DiscreteLogBatchSize);
} }
self.compression_batch_size = compression_batch_size; self.compression_batch_size = compression_batch_size;
@ -145,41 +148,41 @@ impl DiscreteLog {
/// Solves the discrete log problem under the assumption that the solution /// Solves the discrete log problem under the assumption that the solution
/// is a positive 32-bit number. /// is a positive 32-bit number.
pub fn decode_u32(self) -> Option<u64> { pub fn decode_u32(self) -> Option<u64> {
#[cfg(not(target_arch = "wasm32"))] if let Some(num_threads) = self.num_threads {
{ #[cfg(not(target_arch = "wasm32"))]
let mut starting_point = self.target; {
let handles = (0..self.num_threads) let mut starting_point = self.target;
.map(|i| { let handles = (0..num_threads.get())
let ristretto_iterator = RistrettoIterator::new( .map(|i| {
(starting_point, i as u64), let ristretto_iterator = RistrettoIterator::new(
(-(&self.step_point), self.num_threads as u64), (starting_point, i as u64),
); (-(&self.step_point), num_threads.get() as u64),
);
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
Self::decode_range( Self::decode_range(
ristretto_iterator, ristretto_iterator,
self.range_bound, self.range_bound,
self.compression_batch_size, self.compression_batch_size,
) )
}); });
starting_point -= G; starting_point -= G;
handle handle
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
handles handles
.into_iter() .into_iter()
.map_while(|h| h.join().ok()) .map_while(|h| h.join().ok())
.find(|x| x.is_some()) .find(|x| x.is_some())
.flatten() .flatten()
} }
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
{ unreachable!() // `self.num_threads` always `None` on wasm target
let ristretto_iterator = RistrettoIterator::new( } else {
(self.target, 0_u64), let ristretto_iterator =
(-(&self.step_point), self.num_threads as u64), RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
);
Self::decode_range( Self::decode_range(
ristretto_iterator, ristretto_iterator,
@ -191,15 +194,15 @@ impl DiscreteLog {
fn decode_range( fn decode_range(
ristretto_iterator: RistrettoIterator, ristretto_iterator: RistrettoIterator,
range_bound: usize, range_bound: NonZeroUsize,
compression_batch_size: usize, compression_batch_size: NonZeroUsize,
) -> Option<u64> { ) -> Option<u64> {
let hashmap = &DECODE_PRECOMPUTATION_FOR_G; let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
let mut decoded = None; let mut decoded = None;
for batch in &ristretto_iterator for batch in &ristretto_iterator
.take(range_bound) .take(range_bound.get())
.chunks(compression_batch_size) .chunks(compression_batch_size.get())
{ {
// batch compression currently errors if any point in the batch is the identity point // batch compression currently errors if any point in the batch is the identity point
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
@ -298,7 +301,7 @@ mod tests {
let amount: u64 = 55; let amount: u64 = 55;
let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G); let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
instance.num_threads(4).unwrap(); instance.num_threads(4.try_into().unwrap()).unwrap();
// Very informal measurements for now // Very informal measurements for now
let start_computation = Instant::now(); let start_computation = Instant::now();

View File

@ -799,7 +799,7 @@ mod tests {
let ciphertext = ElGamal::encrypt(&public, amount); let ciphertext = ElGamal::encrypt(&public, amount);
let mut instance = ElGamal::decrypt(&secret, &ciphertext); let mut instance = ElGamal::decrypt(&secret, &ciphertext);
instance.num_threads(4).unwrap(); instance.num_threads(4.try_into().unwrap()).unwrap();
assert_eq!(57_u64, instance.decode_u32().unwrap()); assert_eq!(57_u64, instance.decode_u32().unwrap());
} }