From f1f8f5458d8e5d12d3f5d2840fbeef8c3df3bb50 Mon Sep 17 00:00:00 2001 From: samkim-crypto Date: Fri, 1 Apr 2022 21:01:24 -0300 Subject: [PATCH] Threads for discrete log (#23867) * zk-token-sdk: add multi-thread for discrete log * zk-token-sdk: some clean-up * zk-token-sdk: change default discrete log thread to 1 * zk-token-sdk: allow discrete log thread nums to be chosen as param * zk-token-sdk: join discrete log threads * zk-token-sdk: join thread handles before returning * zk-token-sdk: Apply suggestions from code review Co-authored-by: Michael Vines * zk-token-sdk: update tests to use num_threads * zk-token-sdk: simplify discrete log by removing mpsc and just using join * zk-token-sdk: minor Co-authored-by: Michael Vines --- zk-token-sdk/src/encryption/discrete_log.rs | 211 +++++++++++++++----- zk-token-sdk/src/encryption/elgamal.rs | 29 +-- zk-token-sdk/src/errors.rs | 2 + 3 files changed, 178 insertions(+), 64 deletions(-) diff --git a/zk-token-sdk/src/encryption/discrete_log.rs b/zk-token-sdk/src/encryption/discrete_log.rs index 2af579de29..95729c51a6 100644 --- a/zk-token-sdk/src/encryption/discrete_log.rs +++ b/zk-token-sdk/src/encryption/discrete_log.rs @@ -1,18 +1,16 @@ #![cfg(not(target_arch = "bpf"))] use { - curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::Identity}, + crate::errors::ProofError, + curve25519_dalek::{ + constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar, + traits::Identity, + }, serde::{Deserialize, Serialize}, - std::collections::HashMap, + std::{collections::HashMap, thread}, }; -#[allow(dead_code)] -const TWO15: u64 = 32768; -#[allow(dead_code)] -const TWO14: u64 = 16384; // 2^14 const TWO16: u64 = 65536; // 2^16 -#[allow(dead_code)] -const TWO18: u64 = 262144; // 2^18 /// Type that captures a discrete log challenge. /// @@ -23,6 +21,13 @@ pub struct DiscreteLog { pub generator: RistrettoPoint, /// Target point for discrete log pub target: RistrettoPoint, + /// Number of threads used for discrete log computation + num_threads: usize, + /// Range bound for discrete log search derived from the max value to search for and + /// `num_threads` + range_bound: usize, + /// Ristretto point representing each step of the discrete log search + step_point: RistrettoPoint, } #[derive(Serialize, Deserialize, Default)] @@ -38,11 +43,11 @@ fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation let generator = two16_scalar * generator; // 2^16 * G // iterator for 2^12*0G , 2^12*1G, 2^12*2G, ... - let ristretto_iter = RistrettoIterator::new(identity, generator); - ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| { - let key = elem.compress().to_bytes(); + let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1)); + for (point, x_hi) in ristretto_iter.take(TWO16 as usize) { + let key = point.compress().to_bytes(); hashmap.insert(key, x_hi as u16); - }); + } DecodePrecomputation(hashmap) } @@ -58,26 +63,73 @@ lazy_static::lazy_static! { /// Solves the discrete log instance using a 16/16 bit offline/online split impl DiscreteLog { - /// Solves the discrete log problem under the assumption that the solution - /// is a 32-bit number. - pub(crate) fn decode_u32(self) -> Option { - self.decode_online(&DECODE_PRECOMPUTATION_FOR_G, TWO16) + /// Discrete log instance constructor. + /// + /// Default number of threads set to 1. + pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self { + Self { + generator, + target, + num_threads: 1, + range_bound: TWO16 as usize, + step_point: G, + } } - pub fn decode_online(self, hashmap: &DecodePrecomputation, solution_bound: u64) -> Option { - // iterator for 0G, -1G, -2G, ... - let ristretto_iter = RistrettoIterator::new(self.target, -self.generator); + /// Adjusts number of threads in a discrete log instance. + pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> { + // number of threads must be a positive power-of-two integer + if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 { + return Err(ProofError::DiscreteLogThreads); + } + self.num_threads = num_threads; + self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap(); + self.step_point = Scalar::from(num_threads as u64) * G; + + Ok(()) + } + + /// Solves the discrete log problem under the assumption that the solution + /// is a 32-bit number. + pub fn decode_u32(self) -> Option { + let mut starting_point = self.target; + let handles = (0..self.num_threads) + .into_iter() + .map(|i| { + let ristretto_iterator = RistrettoIterator::new( + (starting_point, i as u64), + (-(&self.step_point), self.num_threads as u64), + ); + + let handle = + thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound)); + + starting_point -= G; + handle + }) + .collect::>(); + + let mut solution = None; + for handle in handles { + let discrete_log = handle.join().unwrap(); + if discrete_log.is_some() { + solution = discrete_log; + } + } + solution + } + + fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option { + let hashmap = &DECODE_PRECOMPUTATION_FOR_G; let mut decoded = None; - ristretto_iter - .zip(0..solution_bound) - .for_each(|(elem, x_lo)| { - let key = elem.compress().to_bytes(); - if hashmap.0.contains_key(&key) { - let x_hi = hashmap.0[&key]; - decoded = Some(x_lo + solution_bound * x_hi as u64); - } - }); + for (point, x_lo) in ristretto_iterator.take(range_bound) { + let key = point.compress().to_bytes(); + if hashmap.0.contains_key(&key) { + let x_hi = hashmap.0[&key]; + decoded = Some(x_lo + TWO16 * x_hi as u64); + } + } decoded } } @@ -87,31 +139,29 @@ impl DiscreteLog { /// Given an initial point X and a stepping point P, the iterator iterates through /// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ... struct RistrettoIterator { - pub curr: RistrettoPoint, - pub step: RistrettoPoint, + pub current: (RistrettoPoint, u64), + pub step: (RistrettoPoint, u64), } impl RistrettoIterator { - fn new(curr: RistrettoPoint, step: RistrettoPoint) -> Self { - RistrettoIterator { curr, step } + fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self { + RistrettoIterator { current, step } } } impl Iterator for RistrettoIterator { - type Item = RistrettoPoint; + type Item = (RistrettoPoint, u64); fn next(&mut self) -> Option { - let r = self.curr; - self.curr += self.step; + let r = self.current; + self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1); Some(r) } } #[cfg(test)] mod tests { - use { - super::*, curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G, std::time::Instant, - }; + use {super::*, std::time::Instant}; #[test] #[allow(non_snake_case)] @@ -132,25 +182,82 @@ mod tests { #[test] fn test_decode_correctness() { - let amount: u64 = 65545; + // general case + let amount: u64 = 55; - let instance = DiscreteLog { - generator: G, - target: Scalar::from(amount) * G, - }; + let instance = DiscreteLog::new(G, Scalar::from(amount) * G); // Very informal measurements for now - let start_precomputation = Instant::now(); - let precomputed_hashmap = decode_u32_precomputation(G); - let precomputation_secs = start_precomputation.elapsed().as_secs_f64(); + let start_computation = Instant::now(); + let decoded = instance.decode_u32(); + let computation_secs = start_computation.elapsed().as_secs_f64(); - let start_online = Instant::now(); - let computed_amount = instance.decode_online(&precomputed_hashmap, TWO16).unwrap(); - let online_secs = start_online.elapsed().as_secs_f64(); + assert_eq!(amount, decoded.unwrap()); - assert_eq!(amount, computed_amount); + println!( + "single thread discrete log computation secs: {:?} sec", + computation_secs + ); + } - println!("16/16 Split precomputation: {:?} sec", precomputation_secs); - println!("16/16 Split online computation: {:?} sec", online_secs); + #[test] + fn test_decode_correctness_threaded() { + // general case + let amount: u64 = 55; + + let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G); + instance.num_threads(4).unwrap(); + + // Very informal measurements for now + let start_computation = Instant::now(); + let decoded = instance.decode_u32(); + let computation_secs = start_computation.elapsed().as_secs_f64(); + + assert_eq!(amount, decoded.unwrap()); + + println!( + "4 thread discrete log computation: {:?} sec", + computation_secs + ); + + // amount 0 + let amount: u64 = 0; + + let instance = DiscreteLog::new(G, Scalar::from(amount) * G); + + let decoded = instance.decode_u32(); + assert_eq!(amount, decoded.unwrap()); + + // amount 1 + let amount: u64 = 1; + + let instance = DiscreteLog::new(G, Scalar::from(amount) * G); + + let decoded = instance.decode_u32(); + assert_eq!(amount, decoded.unwrap()); + + // amount 2 + let amount: u64 = 2; + + let instance = DiscreteLog::new(G, Scalar::from(amount) * G); + + let decoded = instance.decode_u32(); + assert_eq!(amount, decoded.unwrap()); + + // amount 3 + let amount: u64 = 3; + + let instance = DiscreteLog::new(G, Scalar::from(amount) * G); + + let decoded = instance.decode_u32(); + assert_eq!(amount, decoded.unwrap()); + + // max amount + let amount: u64 = ((1_u64 << 32) - 1) as u64; + + let instance = DiscreteLog::new(G, Scalar::from(amount) * G); + + let decoded = instance.decode_u32(); + assert_eq!(amount, decoded.unwrap()); } } diff --git a/zk-token-sdk/src/encryption/elgamal.rs b/zk-token-sdk/src/encryption/elgamal.rs index 67667c3357..14b72af699 100644 --- a/zk-token-sdk/src/encryption/elgamal.rs +++ b/zk-token-sdk/src/encryption/elgamal.rs @@ -123,10 +123,10 @@ impl ElGamal { /// message, use `DiscreteLog::decode`. #[cfg(not(target_arch = "bpf"))] fn decrypt(secret: &ElGamalSecretKey, ciphertext: &ElGamalCiphertext) -> DiscreteLog { - DiscreteLog { - generator: *G, - target: &ciphertext.commitment.0 - &(&secret.0 * &ciphertext.handle.0), - } + DiscreteLog::new( + *G, + &ciphertext.commitment.0 - &(&secret.0 * &ciphertext.handle.0), + ) } /// On input a secret key and a ciphertext, the function returns the decrypted message @@ -584,15 +584,23 @@ mod tests { let amount: u32 = 57; let ciphertext = ElGamal::encrypt(&public, amount); - let expected_instance = DiscreteLog { - generator: *G, - target: Scalar::from(amount) * &(*G), - }; + let expected_instance = DiscreteLog::new(*G, Scalar::from(amount) * &(*G)); assert_eq!(expected_instance, ElGamal::decrypt(&secret, &ciphertext)); assert_eq!(57_u64, secret.decrypt_u32(&ciphertext).unwrap()); } + #[test] + fn test_encrypt_decrypt_correctness_multithreaded() { + let ElGamalKeypair { public, secret } = ElGamalKeypair::new_rand(); + let amount: u32 = 57; + let ciphertext = ElGamal::encrypt(&public, amount); + + let mut instance = ElGamal::decrypt(&secret, &ciphertext); + instance.num_threads(4).unwrap(); + assert_eq!(57_u64, instance.decode_u32().unwrap()); + } + #[test] fn test_decrypt_handle() { let ElGamalKeypair { @@ -619,10 +627,7 @@ mod tests { handle: handle_1, }; - let expected_instance = DiscreteLog { - generator: *G, - target: Scalar::from(amount) * (*G), - }; + let expected_instance = DiscreteLog::new(*G, Scalar::from(amount) * &(*G)); assert_eq!(expected_instance, secret_0.decrypt(&ciphertext_0)); assert_eq!(expected_instance, secret_1.decrypt(&ciphertext_1)); diff --git a/zk-token-sdk/src/errors.rs b/zk-token-sdk/src/errors.rs index c5a59c4b7c..61f958a2d6 100644 --- a/zk-token-sdk/src/errors.rs +++ b/zk-token-sdk/src/errors.rs @@ -28,6 +28,8 @@ pub enum ProofError { InconsistentCTData, #[error("failed to decrypt ciphertext from transfer data")] Decryption, + #[error("discrete log number of threads not power-of-two")] + DiscreteLogThreads, } #[derive(Error, Clone, Debug, Eq, PartialEq)]