Threads for discrete log (#23867)

* zk-token-sdk: add multi-thread for discrete log

* zk-token-sdk: some clean-up

* zk-token-sdk: change default discrete log thread to 1

* zk-token-sdk: allow discrete log thread nums to be chosen as param

* zk-token-sdk: join discrete log threads

* zk-token-sdk: join thread handles before returning

* zk-token-sdk: Apply suggestions from code review

Co-authored-by: Michael Vines <mvines@gmail.com>

* zk-token-sdk: update tests to use num_threads

* zk-token-sdk: simplify discrete log by removing mpsc and just using join

* zk-token-sdk: minor

Co-authored-by: Michael Vines <mvines@gmail.com>
This commit is contained in:
samkim-crypto 2022-04-01 21:01:24 -03:00 committed by GitHub
parent 8a18c48e47
commit f1f8f5458d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 178 additions and 64 deletions

View File

@ -1,18 +1,16 @@
#![cfg(not(target_arch = "bpf"))]
use {
curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::Identity},
crate::errors::ProofError,
curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar,
traits::Identity,
},
serde::{Deserialize, Serialize},
std::collections::HashMap,
std::{collections::HashMap, thread},
};
#[allow(dead_code)]
const TWO15: u64 = 32768;
#[allow(dead_code)]
const TWO14: u64 = 16384; // 2^14
const TWO16: u64 = 65536; // 2^16
#[allow(dead_code)]
const TWO18: u64 = 262144; // 2^18
/// Type that captures a discrete log challenge.
///
@ -23,6 +21,13 @@ pub struct DiscreteLog {
pub generator: RistrettoPoint,
/// Target point for discrete log
pub target: RistrettoPoint,
/// Number of threads used for discrete log computation
num_threads: usize,
/// Range bound for discrete log search derived from the max value to search for and
/// `num_threads`
range_bound: usize,
/// Ristretto point representing each step of the discrete log search
step_point: RistrettoPoint,
}
#[derive(Serialize, Deserialize, Default)]
@ -38,11 +43,11 @@ fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation
let generator = two16_scalar * generator; // 2^16 * G
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
let ristretto_iter = RistrettoIterator::new(identity, generator);
ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| {
let key = elem.compress().to_bytes();
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
let key = point.compress().to_bytes();
hashmap.insert(key, x_hi as u16);
});
}
DecodePrecomputation(hashmap)
}
@ -58,26 +63,73 @@ lazy_static::lazy_static! {
/// Solves the discrete log instance using a 16/16 bit offline/online split
impl DiscreteLog {
/// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number.
pub(crate) fn decode_u32(self) -> Option<u64> {
self.decode_online(&DECODE_PRECOMPUTATION_FOR_G, TWO16)
/// Discrete log instance constructor.
///
/// Default number of threads set to 1.
pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
Self {
generator,
target,
num_threads: 1,
range_bound: TWO16 as usize,
step_point: G,
}
}
pub fn decode_online(self, hashmap: &DecodePrecomputation, solution_bound: u64) -> Option<u64> {
// iterator for 0G, -1G, -2G, ...
let ristretto_iter = RistrettoIterator::new(self.target, -self.generator);
/// Adjusts number of threads in a discrete log instance.
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> {
// number of threads must be a positive power-of-two integer
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 {
return Err(ProofError::DiscreteLogThreads);
}
self.num_threads = num_threads;
self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
self.step_point = Scalar::from(num_threads as u64) * G;
Ok(())
}
/// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number.
pub fn decode_u32(self) -> Option<u64> {
let mut starting_point = self.target;
let handles = (0..self.num_threads)
.into_iter()
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), self.num_threads as u64),
);
let handle =
thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound));
starting_point -= G;
handle
})
.collect::<Vec<_>>();
let mut solution = None;
for handle in handles {
let discrete_log = handle.join().unwrap();
if discrete_log.is_some() {
solution = discrete_log;
}
}
solution
}
fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option<u64> {
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
let mut decoded = None;
ristretto_iter
.zip(0..solution_bound)
.for_each(|(elem, x_lo)| {
let key = elem.compress().to_bytes();
if hashmap.0.contains_key(&key) {
let x_hi = hashmap.0[&key];
decoded = Some(x_lo + solution_bound * x_hi as u64);
}
});
for (point, x_lo) in ristretto_iterator.take(range_bound) {
let key = point.compress().to_bytes();
if hashmap.0.contains_key(&key) {
let x_hi = hashmap.0[&key];
decoded = Some(x_lo + TWO16 * x_hi as u64);
}
}
decoded
}
}
@ -87,31 +139,29 @@ impl DiscreteLog {
/// Given an initial point X and a stepping point P, the iterator iterates through
/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
struct RistrettoIterator {
pub curr: RistrettoPoint,
pub step: RistrettoPoint,
pub current: (RistrettoPoint, u64),
pub step: (RistrettoPoint, u64),
}
impl RistrettoIterator {
fn new(curr: RistrettoPoint, step: RistrettoPoint) -> Self {
RistrettoIterator { curr, step }
fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
RistrettoIterator { current, step }
}
}
impl Iterator for RistrettoIterator {
type Item = RistrettoPoint;
type Item = (RistrettoPoint, u64);
fn next(&mut self) -> Option<Self::Item> {
let r = self.curr;
self.curr += self.step;
let r = self.current;
self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
Some(r)
}
}
#[cfg(test)]
mod tests {
use {
super::*, curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G, std::time::Instant,
};
use {super::*, std::time::Instant};
#[test]
#[allow(non_snake_case)]
@ -132,25 +182,82 @@ mod tests {
#[test]
fn test_decode_correctness() {
let amount: u64 = 65545;
// general case
let amount: u64 = 55;
let instance = DiscreteLog {
generator: G,
target: Scalar::from(amount) * G,
};
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
// Very informal measurements for now
let start_precomputation = Instant::now();
let precomputed_hashmap = decode_u32_precomputation(G);
let precomputation_secs = start_precomputation.elapsed().as_secs_f64();
let start_computation = Instant::now();
let decoded = instance.decode_u32();
let computation_secs = start_computation.elapsed().as_secs_f64();
let start_online = Instant::now();
let computed_amount = instance.decode_online(&precomputed_hashmap, TWO16).unwrap();
let online_secs = start_online.elapsed().as_secs_f64();
assert_eq!(amount, decoded.unwrap());
assert_eq!(amount, computed_amount);
println!(
"single thread discrete log computation secs: {:?} sec",
computation_secs
);
}
println!("16/16 Split precomputation: {:?} sec", precomputation_secs);
println!("16/16 Split online computation: {:?} sec", online_secs);
#[test]
fn test_decode_correctness_threaded() {
// general case
let amount: u64 = 55;
let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
instance.num_threads(4).unwrap();
// Very informal measurements for now
let start_computation = Instant::now();
let decoded = instance.decode_u32();
let computation_secs = start_computation.elapsed().as_secs_f64();
assert_eq!(amount, decoded.unwrap());
println!(
"4 thread discrete log computation: {:?} sec",
computation_secs
);
// amount 0
let amount: u64 = 0;
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
let decoded = instance.decode_u32();
assert_eq!(amount, decoded.unwrap());
// amount 1
let amount: u64 = 1;
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
let decoded = instance.decode_u32();
assert_eq!(amount, decoded.unwrap());
// amount 2
let amount: u64 = 2;
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
let decoded = instance.decode_u32();
assert_eq!(amount, decoded.unwrap());
// amount 3
let amount: u64 = 3;
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
let decoded = instance.decode_u32();
assert_eq!(amount, decoded.unwrap());
// max amount
let amount: u64 = ((1_u64 << 32) - 1) as u64;
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
let decoded = instance.decode_u32();
assert_eq!(amount, decoded.unwrap());
}
}

View File

@ -123,10 +123,10 @@ impl ElGamal {
/// message, use `DiscreteLog::decode`.
#[cfg(not(target_arch = "bpf"))]
fn decrypt(secret: &ElGamalSecretKey, ciphertext: &ElGamalCiphertext) -> DiscreteLog {
DiscreteLog {
generator: *G,
target: &ciphertext.commitment.0 - &(&secret.0 * &ciphertext.handle.0),
}
DiscreteLog::new(
*G,
&ciphertext.commitment.0 - &(&secret.0 * &ciphertext.handle.0),
)
}
/// On input a secret key and a ciphertext, the function returns the decrypted message
@ -584,15 +584,23 @@ mod tests {
let amount: u32 = 57;
let ciphertext = ElGamal::encrypt(&public, amount);
let expected_instance = DiscreteLog {
generator: *G,
target: Scalar::from(amount) * &(*G),
};
let expected_instance = DiscreteLog::new(*G, Scalar::from(amount) * &(*G));
assert_eq!(expected_instance, ElGamal::decrypt(&secret, &ciphertext));
assert_eq!(57_u64, secret.decrypt_u32(&ciphertext).unwrap());
}
#[test]
fn test_encrypt_decrypt_correctness_multithreaded() {
let ElGamalKeypair { public, secret } = ElGamalKeypair::new_rand();
let amount: u32 = 57;
let ciphertext = ElGamal::encrypt(&public, amount);
let mut instance = ElGamal::decrypt(&secret, &ciphertext);
instance.num_threads(4).unwrap();
assert_eq!(57_u64, instance.decode_u32().unwrap());
}
#[test]
fn test_decrypt_handle() {
let ElGamalKeypair {
@ -619,10 +627,7 @@ mod tests {
handle: handle_1,
};
let expected_instance = DiscreteLog {
generator: *G,
target: Scalar::from(amount) * (*G),
};
let expected_instance = DiscreteLog::new(*G, Scalar::from(amount) * &(*G));
assert_eq!(expected_instance, secret_0.decrypt(&ciphertext_0));
assert_eq!(expected_instance, secret_1.decrypt(&ciphertext_1));

View File

@ -28,6 +28,8 @@ pub enum ProofError {
InconsistentCTData,
#[error("failed to decrypt ciphertext from transfer data")]
Decryption,
#[error("discrete log number of threads not power-of-two")]
DiscreteLogThreads,
}
#[derive(Error, Clone, Debug, Eq, PartialEq)]