[zk-token-sdk] Allow discrete log to be executed in the current thread (#443)

This commit is contained in:
samkim-crypto 2024-03-30 06:37:43 +09:00 committed by GitHub
parent c5b9196df7
commit fb1ee7842f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 55 additions and 52 deletions

View File

@ -28,7 +28,7 @@ use {
},
itertools::Itertools,
serde::{Deserialize, Serialize},
std::collections::HashMap,
std::{collections::HashMap, num::NonZeroUsize},
thiserror::Error,
};
@ -57,14 +57,14 @@ pub struct DiscreteLog {
/// Target point for discrete log
pub target: RistrettoPoint,
/// Number of threads used for discrete log computation
num_threads: usize,
num_threads: Option<NonZeroUsize>,
/// Range bound for discrete log search derived from the max value to search for and
/// `num_threads`
range_bound: usize,
range_bound: NonZeroUsize,
/// Ristretto point representing each step of the discrete log search
step_point: RistrettoPoint,
/// Ristretto point compression batch size
compression_batch_size: usize,
compression_batch_size: NonZeroUsize,
}
#[derive(Serialize, Deserialize, Default)]
@ -107,24 +107,27 @@ impl DiscreteLog {
Self {
generator,
target,
num_threads: 1,
range_bound: TWO16 as usize,
num_threads: None,
range_bound: (TWO16 as usize).try_into().unwrap(),
step_point: G,
compression_batch_size: 32,
compression_batch_size: 32.try_into().unwrap(),
}
}
/// Adjusts number of threads in a discrete log instance.
#[cfg(not(target_arch = "wasm32"))]
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
// number of threads must be a positive power-of-two integer
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > MAX_THREAD {
if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
return Err(DiscreteLogError::DiscreteLogThreads);
}
self.num_threads = num_threads;
self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
self.step_point = Scalar::from(num_threads as u64) * G;
self.num_threads = Some(num_threads);
self.range_bound = (TWO16 as usize)
.checked_div(num_threads.get())
.and_then(|range_bound| range_bound.try_into().ok())
.unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
self.step_point = Scalar::from(num_threads.get() as u64) * G;
Ok(())
}
@ -132,9 +135,9 @@ impl DiscreteLog {
/// Adjusts inversion batch size in a discrete log instance.
pub fn set_compression_batch_size(
&mut self,
compression_batch_size: usize,
compression_batch_size: NonZeroUsize,
) -> Result<(), DiscreteLogError> {
if compression_batch_size >= TWO16 as usize || compression_batch_size == 0 {
if compression_batch_size.get() >= TWO16 as usize {
return Err(DiscreteLogError::DiscreteLogBatchSize);
}
self.compression_batch_size = compression_batch_size;
@ -145,41 +148,41 @@ impl DiscreteLog {
/// Solves the discrete log problem under the assumption that the solution
/// is a positive 32-bit number.
pub fn decode_u32(self) -> Option<u64> {
#[cfg(not(target_arch = "wasm32"))]
{
let mut starting_point = self.target;
let handles = (0..self.num_threads)
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), self.num_threads as u64),
);
if let Some(num_threads) = self.num_threads {
#[cfg(not(target_arch = "wasm32"))]
{
let mut starting_point = self.target;
let handles = (0..num_threads.get())
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), num_threads.get() as u64),
);
let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});
let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});
starting_point -= G;
handle
})
.collect::<Vec<_>>();
starting_point -= G;
handle
})
.collect::<Vec<_>>();
handles
.into_iter()
.map_while(|h| h.join().ok())
.find(|x| x.is_some())
.flatten()
}
#[cfg(target_arch = "wasm32")]
{
let ristretto_iterator = RistrettoIterator::new(
(self.target, 0_u64),
(-(&self.step_point), self.num_threads as u64),
);
handles
.into_iter()
.map_while(|h| h.join().ok())
.find(|x| x.is_some())
.flatten()
}
#[cfg(target_arch = "wasm32")]
unreachable!() // `self.num_threads` always `None` on wasm target
} else {
let ristretto_iterator =
RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
Self::decode_range(
ristretto_iterator,
@ -191,15 +194,15 @@ impl DiscreteLog {
fn decode_range(
ristretto_iterator: RistrettoIterator,
range_bound: usize,
compression_batch_size: usize,
range_bound: NonZeroUsize,
compression_batch_size: NonZeroUsize,
) -> Option<u64> {
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
let mut decoded = None;
for batch in &ristretto_iterator
.take(range_bound)
.chunks(compression_batch_size)
.take(range_bound.get())
.chunks(compression_batch_size.get())
{
// batch compression currently errors if any point in the batch is the identity point
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
@ -298,7 +301,7 @@ mod tests {
let amount: u64 = 55;
let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
instance.num_threads(4).unwrap();
instance.num_threads(4.try_into().unwrap()).unwrap();
// Very informal measurements for now
let start_computation = Instant::now();

View File

@ -799,7 +799,7 @@ mod tests {
let ciphertext = ElGamal::encrypt(&public, amount);
let mut instance = ElGamal::decrypt(&secret, &ciphertext);
instance.num_threads(4).unwrap();
instance.num_threads(4.try_into().unwrap()).unwrap();
assert_eq!(57_u64, instance.decode_u32().unwrap());
}