Threads for discrete log (#23867)
* zk-token-sdk: add multi-thread for discrete log * zk-token-sdk: some clean-up * zk-token-sdk: change default discrete log thread to 1 * zk-token-sdk: allow discrete log thread nums to be chosen as param * zk-token-sdk: join discrete log threads * zk-token-sdk: join thread handles before returning * zk-token-sdk: Apply suggestions from code review Co-authored-by: Michael Vines <mvines@gmail.com> * zk-token-sdk: update tests to use num_threads * zk-token-sdk: simplify discrete log by removing mpsc and just using join * zk-token-sdk: minor Co-authored-by: Michael Vines <mvines@gmail.com>
This commit is contained in:
parent
8a18c48e47
commit
f1f8f5458d
|
@ -1,18 +1,16 @@
|
||||||
#![cfg(not(target_arch = "bpf"))]
|
#![cfg(not(target_arch = "bpf"))]
|
||||||
|
|
||||||
use {
|
use {
|
||||||
curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::Identity},
|
crate::errors::ProofError,
|
||||||
|
curve25519_dalek::{
|
||||||
|
constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar,
|
||||||
|
traits::Identity,
|
||||||
|
},
|
||||||
serde::{Deserialize, Serialize},
|
serde::{Deserialize, Serialize},
|
||||||
std::collections::HashMap,
|
std::{collections::HashMap, thread},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
const TWO15: u64 = 32768;
|
|
||||||
#[allow(dead_code)]
|
|
||||||
const TWO14: u64 = 16384; // 2^14
|
|
||||||
const TWO16: u64 = 65536; // 2^16
|
const TWO16: u64 = 65536; // 2^16
|
||||||
#[allow(dead_code)]
|
|
||||||
const TWO18: u64 = 262144; // 2^18
|
|
||||||
|
|
||||||
/// Type that captures a discrete log challenge.
|
/// Type that captures a discrete log challenge.
|
||||||
///
|
///
|
||||||
|
@ -23,6 +21,13 @@ pub struct DiscreteLog {
|
||||||
pub generator: RistrettoPoint,
|
pub generator: RistrettoPoint,
|
||||||
/// Target point for discrete log
|
/// Target point for discrete log
|
||||||
pub target: RistrettoPoint,
|
pub target: RistrettoPoint,
|
||||||
|
/// Number of threads used for discrete log computation
|
||||||
|
num_threads: usize,
|
||||||
|
/// Range bound for discrete log search derived from the max value to search for and
|
||||||
|
/// `num_threads`
|
||||||
|
range_bound: usize,
|
||||||
|
/// Ristretto point representing each step of the discrete log search
|
||||||
|
step_point: RistrettoPoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Default)]
|
#[derive(Serialize, Deserialize, Default)]
|
||||||
|
@ -38,11 +43,11 @@ fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation
|
||||||
let generator = two16_scalar * generator; // 2^16 * G
|
let generator = two16_scalar * generator; // 2^16 * G
|
||||||
|
|
||||||
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
|
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
|
||||||
let ristretto_iter = RistrettoIterator::new(identity, generator);
|
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
|
||||||
ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| {
|
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
|
||||||
let key = elem.compress().to_bytes();
|
let key = point.compress().to_bytes();
|
||||||
hashmap.insert(key, x_hi as u16);
|
hashmap.insert(key, x_hi as u16);
|
||||||
});
|
}
|
||||||
|
|
||||||
DecodePrecomputation(hashmap)
|
DecodePrecomputation(hashmap)
|
||||||
}
|
}
|
||||||
|
@ -58,26 +63,73 @@ lazy_static::lazy_static! {
|
||||||
|
|
||||||
/// Solves the discrete log instance using a 16/16 bit offline/online split
|
/// Solves the discrete log instance using a 16/16 bit offline/online split
|
||||||
impl DiscreteLog {
|
impl DiscreteLog {
|
||||||
|
/// Discrete log instance constructor.
|
||||||
|
///
|
||||||
|
/// Default number of threads set to 1.
|
||||||
|
pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
|
||||||
|
Self {
|
||||||
|
generator,
|
||||||
|
target,
|
||||||
|
num_threads: 1,
|
||||||
|
range_bound: TWO16 as usize,
|
||||||
|
step_point: G,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Adjusts number of threads in a discrete log instance.
|
||||||
|
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> {
|
||||||
|
// number of threads must be a positive power-of-two integer
|
||||||
|
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 {
|
||||||
|
return Err(ProofError::DiscreteLogThreads);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.num_threads = num_threads;
|
||||||
|
self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
|
||||||
|
self.step_point = Scalar::from(num_threads as u64) * G;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Solves the discrete log problem under the assumption that the solution
|
/// Solves the discrete log problem under the assumption that the solution
|
||||||
/// is a 32-bit number.
|
/// is a 32-bit number.
|
||||||
pub(crate) fn decode_u32(self) -> Option<u64> {
|
pub fn decode_u32(self) -> Option<u64> {
|
||||||
self.decode_online(&DECODE_PRECOMPUTATION_FOR_G, TWO16)
|
let mut starting_point = self.target;
|
||||||
|
let handles = (0..self.num_threads)
|
||||||
|
.into_iter()
|
||||||
|
.map(|i| {
|
||||||
|
let ristretto_iterator = RistrettoIterator::new(
|
||||||
|
(starting_point, i as u64),
|
||||||
|
(-(&self.step_point), self.num_threads as u64),
|
||||||
|
);
|
||||||
|
|
||||||
|
let handle =
|
||||||
|
thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound));
|
||||||
|
|
||||||
|
starting_point -= G;
|
||||||
|
handle
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let mut solution = None;
|
||||||
|
for handle in handles {
|
||||||
|
let discrete_log = handle.join().unwrap();
|
||||||
|
if discrete_log.is_some() {
|
||||||
|
solution = discrete_log;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
solution
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode_online(self, hashmap: &DecodePrecomputation, solution_bound: u64) -> Option<u64> {
|
fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option<u64> {
|
||||||
// iterator for 0G, -1G, -2G, ...
|
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
|
||||||
let ristretto_iter = RistrettoIterator::new(self.target, -self.generator);
|
|
||||||
|
|
||||||
let mut decoded = None;
|
let mut decoded = None;
|
||||||
ristretto_iter
|
for (point, x_lo) in ristretto_iterator.take(range_bound) {
|
||||||
.zip(0..solution_bound)
|
let key = point.compress().to_bytes();
|
||||||
.for_each(|(elem, x_lo)| {
|
|
||||||
let key = elem.compress().to_bytes();
|
|
||||||
if hashmap.0.contains_key(&key) {
|
if hashmap.0.contains_key(&key) {
|
||||||
let x_hi = hashmap.0[&key];
|
let x_hi = hashmap.0[&key];
|
||||||
decoded = Some(x_lo + solution_bound * x_hi as u64);
|
decoded = Some(x_lo + TWO16 * x_hi as u64);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
|
||||||
decoded
|
decoded
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -87,31 +139,29 @@ impl DiscreteLog {
|
||||||
/// Given an initial point X and a stepping point P, the iterator iterates through
|
/// Given an initial point X and a stepping point P, the iterator iterates through
|
||||||
/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
|
/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
|
||||||
struct RistrettoIterator {
|
struct RistrettoIterator {
|
||||||
pub curr: RistrettoPoint,
|
pub current: (RistrettoPoint, u64),
|
||||||
pub step: RistrettoPoint,
|
pub step: (RistrettoPoint, u64),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RistrettoIterator {
|
impl RistrettoIterator {
|
||||||
fn new(curr: RistrettoPoint, step: RistrettoPoint) -> Self {
|
fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
|
||||||
RistrettoIterator { curr, step }
|
RistrettoIterator { current, step }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Iterator for RistrettoIterator {
|
impl Iterator for RistrettoIterator {
|
||||||
type Item = RistrettoPoint;
|
type Item = (RistrettoPoint, u64);
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
let r = self.curr;
|
let r = self.current;
|
||||||
self.curr += self.step;
|
self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
|
||||||
Some(r)
|
Some(r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use {
|
use {super::*, std::time::Instant};
|
||||||
super::*, curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G, std::time::Instant,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
|
@ -132,25 +182,82 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_decode_correctness() {
|
fn test_decode_correctness() {
|
||||||
let amount: u64 = 65545;
|
// general case
|
||||||
|
let amount: u64 = 55;
|
||||||
|
|
||||||
let instance = DiscreteLog {
|
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
generator: G,
|
|
||||||
target: Scalar::from(amount) * G,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Very informal measurements for now
|
// Very informal measurements for now
|
||||||
let start_precomputation = Instant::now();
|
let start_computation = Instant::now();
|
||||||
let precomputed_hashmap = decode_u32_precomputation(G);
|
let decoded = instance.decode_u32();
|
||||||
let precomputation_secs = start_precomputation.elapsed().as_secs_f64();
|
let computation_secs = start_computation.elapsed().as_secs_f64();
|
||||||
|
|
||||||
let start_online = Instant::now();
|
assert_eq!(amount, decoded.unwrap());
|
||||||
let computed_amount = instance.decode_online(&precomputed_hashmap, TWO16).unwrap();
|
|
||||||
let online_secs = start_online.elapsed().as_secs_f64();
|
|
||||||
|
|
||||||
assert_eq!(amount, computed_amount);
|
println!(
|
||||||
|
"single thread discrete log computation secs: {:?} sec",
|
||||||
|
computation_secs
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
println!("16/16 Split precomputation: {:?} sec", precomputation_secs);
|
#[test]
|
||||||
println!("16/16 Split online computation: {:?} sec", online_secs);
|
fn test_decode_correctness_threaded() {
|
||||||
|
// general case
|
||||||
|
let amount: u64 = 55;
|
||||||
|
|
||||||
|
let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
|
instance.num_threads(4).unwrap();
|
||||||
|
|
||||||
|
// Very informal measurements for now
|
||||||
|
let start_computation = Instant::now();
|
||||||
|
let decoded = instance.decode_u32();
|
||||||
|
let computation_secs = start_computation.elapsed().as_secs_f64();
|
||||||
|
|
||||||
|
assert_eq!(amount, decoded.unwrap());
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"4 thread discrete log computation: {:?} sec",
|
||||||
|
computation_secs
|
||||||
|
);
|
||||||
|
|
||||||
|
// amount 0
|
||||||
|
let amount: u64 = 0;
|
||||||
|
|
||||||
|
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
|
|
||||||
|
let decoded = instance.decode_u32();
|
||||||
|
assert_eq!(amount, decoded.unwrap());
|
||||||
|
|
||||||
|
// amount 1
|
||||||
|
let amount: u64 = 1;
|
||||||
|
|
||||||
|
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
|
|
||||||
|
let decoded = instance.decode_u32();
|
||||||
|
assert_eq!(amount, decoded.unwrap());
|
||||||
|
|
||||||
|
// amount 2
|
||||||
|
let amount: u64 = 2;
|
||||||
|
|
||||||
|
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
|
|
||||||
|
let decoded = instance.decode_u32();
|
||||||
|
assert_eq!(amount, decoded.unwrap());
|
||||||
|
|
||||||
|
// amount 3
|
||||||
|
let amount: u64 = 3;
|
||||||
|
|
||||||
|
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
|
|
||||||
|
let decoded = instance.decode_u32();
|
||||||
|
assert_eq!(amount, decoded.unwrap());
|
||||||
|
|
||||||
|
// max amount
|
||||||
|
let amount: u64 = ((1_u64 << 32) - 1) as u64;
|
||||||
|
|
||||||
|
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
|
|
||||||
|
let decoded = instance.decode_u32();
|
||||||
|
assert_eq!(amount, decoded.unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,10 +123,10 @@ impl ElGamal {
|
||||||
/// message, use `DiscreteLog::decode`.
|
/// message, use `DiscreteLog::decode`.
|
||||||
#[cfg(not(target_arch = "bpf"))]
|
#[cfg(not(target_arch = "bpf"))]
|
||||||
fn decrypt(secret: &ElGamalSecretKey, ciphertext: &ElGamalCiphertext) -> DiscreteLog {
|
fn decrypt(secret: &ElGamalSecretKey, ciphertext: &ElGamalCiphertext) -> DiscreteLog {
|
||||||
DiscreteLog {
|
DiscreteLog::new(
|
||||||
generator: *G,
|
*G,
|
||||||
target: &ciphertext.commitment.0 - &(&secret.0 * &ciphertext.handle.0),
|
&ciphertext.commitment.0 - &(&secret.0 * &ciphertext.handle.0),
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// On input a secret key and a ciphertext, the function returns the decrypted message
|
/// On input a secret key and a ciphertext, the function returns the decrypted message
|
||||||
|
@ -584,15 +584,23 @@ mod tests {
|
||||||
let amount: u32 = 57;
|
let amount: u32 = 57;
|
||||||
let ciphertext = ElGamal::encrypt(&public, amount);
|
let ciphertext = ElGamal::encrypt(&public, amount);
|
||||||
|
|
||||||
let expected_instance = DiscreteLog {
|
let expected_instance = DiscreteLog::new(*G, Scalar::from(amount) * &(*G));
|
||||||
generator: *G,
|
|
||||||
target: Scalar::from(amount) * &(*G),
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(expected_instance, ElGamal::decrypt(&secret, &ciphertext));
|
assert_eq!(expected_instance, ElGamal::decrypt(&secret, &ciphertext));
|
||||||
assert_eq!(57_u64, secret.decrypt_u32(&ciphertext).unwrap());
|
assert_eq!(57_u64, secret.decrypt_u32(&ciphertext).unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encrypt_decrypt_correctness_multithreaded() {
|
||||||
|
let ElGamalKeypair { public, secret } = ElGamalKeypair::new_rand();
|
||||||
|
let amount: u32 = 57;
|
||||||
|
let ciphertext = ElGamal::encrypt(&public, amount);
|
||||||
|
|
||||||
|
let mut instance = ElGamal::decrypt(&secret, &ciphertext);
|
||||||
|
instance.num_threads(4).unwrap();
|
||||||
|
assert_eq!(57_u64, instance.decode_u32().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_decrypt_handle() {
|
fn test_decrypt_handle() {
|
||||||
let ElGamalKeypair {
|
let ElGamalKeypair {
|
||||||
|
@ -619,10 +627,7 @@ mod tests {
|
||||||
handle: handle_1,
|
handle: handle_1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let expected_instance = DiscreteLog {
|
let expected_instance = DiscreteLog::new(*G, Scalar::from(amount) * &(*G));
|
||||||
generator: *G,
|
|
||||||
target: Scalar::from(amount) * (*G),
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(expected_instance, secret_0.decrypt(&ciphertext_0));
|
assert_eq!(expected_instance, secret_0.decrypt(&ciphertext_0));
|
||||||
assert_eq!(expected_instance, secret_1.decrypt(&ciphertext_1));
|
assert_eq!(expected_instance, secret_1.decrypt(&ciphertext_1));
|
||||||
|
|
|
@ -28,6 +28,8 @@ pub enum ProofError {
|
||||||
InconsistentCTData,
|
InconsistentCTData,
|
||||||
#[error("failed to decrypt ciphertext from transfer data")]
|
#[error("failed to decrypt ciphertext from transfer data")]
|
||||||
Decryption,
|
Decryption,
|
||||||
|
#[error("discrete log number of threads not power-of-two")]
|
||||||
|
DiscreteLogThreads,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
||||||
|
|
Loading…
Reference in New Issue