[zk-token-sdk] Allow discrete log to be executed in the current thread (#443)
This commit is contained in:
parent
c5b9196df7
commit
fb1ee7842f
|
@ -28,7 +28,7 @@ use {
|
|||
},
|
||||
itertools::Itertools,
|
||||
serde::{Deserialize, Serialize},
|
||||
std::collections::HashMap,
|
||||
std::{collections::HashMap, num::NonZeroUsize},
|
||||
thiserror::Error,
|
||||
};
|
||||
|
||||
|
@ -57,14 +57,14 @@ pub struct DiscreteLog {
|
|||
/// Target point for discrete log
|
||||
pub target: RistrettoPoint,
|
||||
/// Number of threads used for discrete log computation
|
||||
num_threads: usize,
|
||||
num_threads: Option<NonZeroUsize>,
|
||||
/// Range bound for discrete log search derived from the max value to search for and
|
||||
/// `num_threads`
|
||||
range_bound: usize,
|
||||
range_bound: NonZeroUsize,
|
||||
/// Ristretto point representing each step of the discrete log search
|
||||
step_point: RistrettoPoint,
|
||||
/// Ristretto point compression batch size
|
||||
compression_batch_size: usize,
|
||||
compression_batch_size: NonZeroUsize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
|
@ -107,24 +107,27 @@ impl DiscreteLog {
|
|||
Self {
|
||||
generator,
|
||||
target,
|
||||
num_threads: 1,
|
||||
range_bound: TWO16 as usize,
|
||||
num_threads: None,
|
||||
range_bound: (TWO16 as usize).try_into().unwrap(),
|
||||
step_point: G,
|
||||
compression_batch_size: 32,
|
||||
compression_batch_size: 32.try_into().unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjusts number of threads in a discrete log instance.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
|
||||
pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
|
||||
// number of threads must be a positive power-of-two integer
|
||||
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > MAX_THREAD {
|
||||
if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
|
||||
return Err(DiscreteLogError::DiscreteLogThreads);
|
||||
}
|
||||
|
||||
self.num_threads = num_threads;
|
||||
self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
|
||||
self.step_point = Scalar::from(num_threads as u64) * G;
|
||||
self.num_threads = Some(num_threads);
|
||||
self.range_bound = (TWO16 as usize)
|
||||
.checked_div(num_threads.get())
|
||||
.and_then(|range_bound| range_bound.try_into().ok())
|
||||
.unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
|
||||
self.step_point = Scalar::from(num_threads.get() as u64) * G;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -132,9 +135,9 @@ impl DiscreteLog {
|
|||
/// Adjusts inversion batch size in a discrete log instance.
|
||||
pub fn set_compression_batch_size(
|
||||
&mut self,
|
||||
compression_batch_size: usize,
|
||||
compression_batch_size: NonZeroUsize,
|
||||
) -> Result<(), DiscreteLogError> {
|
||||
if compression_batch_size >= TWO16 as usize || compression_batch_size == 0 {
|
||||
if compression_batch_size.get() >= TWO16 as usize {
|
||||
return Err(DiscreteLogError::DiscreteLogBatchSize);
|
||||
}
|
||||
self.compression_batch_size = compression_batch_size;
|
||||
|
@ -145,14 +148,15 @@ impl DiscreteLog {
|
|||
/// Solves the discrete log problem under the assumption that the solution
|
||||
/// is a positive 32-bit number.
|
||||
pub fn decode_u32(self) -> Option<u64> {
|
||||
if let Some(num_threads) = self.num_threads {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
{
|
||||
let mut starting_point = self.target;
|
||||
let handles = (0..self.num_threads)
|
||||
let handles = (0..num_threads.get())
|
||||
.map(|i| {
|
||||
let ristretto_iterator = RistrettoIterator::new(
|
||||
(starting_point, i as u64),
|
||||
(-(&self.step_point), self.num_threads as u64),
|
||||
(-(&self.step_point), num_threads.get() as u64),
|
||||
);
|
||||
|
||||
let handle = thread::spawn(move || {
|
||||
|
@ -175,11 +179,10 @@ impl DiscreteLog {
|
|||
.flatten()
|
||||
}
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
{
|
||||
let ristretto_iterator = RistrettoIterator::new(
|
||||
(self.target, 0_u64),
|
||||
(-(&self.step_point), self.num_threads as u64),
|
||||
);
|
||||
unreachable!() // `self.num_threads` always `None` on wasm target
|
||||
} else {
|
||||
let ristretto_iterator =
|
||||
RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
|
||||
|
||||
Self::decode_range(
|
||||
ristretto_iterator,
|
||||
|
@ -191,15 +194,15 @@ impl DiscreteLog {
|
|||
|
||||
fn decode_range(
|
||||
ristretto_iterator: RistrettoIterator,
|
||||
range_bound: usize,
|
||||
compression_batch_size: usize,
|
||||
range_bound: NonZeroUsize,
|
||||
compression_batch_size: NonZeroUsize,
|
||||
) -> Option<u64> {
|
||||
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
|
||||
let mut decoded = None;
|
||||
|
||||
for batch in &ristretto_iterator
|
||||
.take(range_bound)
|
||||
.chunks(compression_batch_size)
|
||||
.take(range_bound.get())
|
||||
.chunks(compression_batch_size.get())
|
||||
{
|
||||
// batch compression currently errors if any point in the batch is the identity point
|
||||
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
|
||||
|
@ -298,7 +301,7 @@ mod tests {
|
|||
let amount: u64 = 55;
|
||||
|
||||
let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||
instance.num_threads(4).unwrap();
|
||||
instance.num_threads(4.try_into().unwrap()).unwrap();
|
||||
|
||||
// Very informal measurements for now
|
||||
let start_computation = Instant::now();
|
||||
|
|
|
@ -799,7 +799,7 @@ mod tests {
|
|||
let ciphertext = ElGamal::encrypt(&public, amount);
|
||||
|
||||
let mut instance = ElGamal::decrypt(&secret, &ciphertext);
|
||||
instance.num_threads(4).unwrap();
|
||||
instance.num_threads(4.try_into().unwrap()).unwrap();
|
||||
assert_eq!(57_u64, instance.decode_u32().unwrap());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue