diff --git a/Cargo.lock b/Cargo.lock index 62293b6485..a60bb91ca1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6773,6 +6773,7 @@ dependencies = [ "cipher 0.4.3", "curve25519-dalek", "getrandom 0.1.16", + "itertools", "lazy_static", "merlin", "num-derive", diff --git a/programs/bpf/Cargo.lock b/programs/bpf/Cargo.lock index 0a110ea936..35756468f3 100644 --- a/programs/bpf/Cargo.lock +++ b/programs/bpf/Cargo.lock @@ -5982,6 +5982,7 @@ dependencies = [ "cipher 0.4.3", "curve25519-dalek", "getrandom 0.1.14", + "itertools", "lazy_static", "merlin", "num-derive", diff --git a/zk-token-sdk/Cargo.toml b/zk-token-sdk/Cargo.toml index 37809fc7f0..dedda34eda 100644 --- a/zk-token-sdk/Cargo.toml +++ b/zk-token-sdk/Cargo.toml @@ -22,6 +22,7 @@ byteorder = "1" cipher = "0.4" curve25519-dalek = { version = "3.2.1", features = ["serde"] } getrandom = { version = "0.1", features = ["dummy"] } +itertools = "0.10.3" lazy_static = "1.4.0" merlin = "3" rand = "0.7" diff --git a/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode b/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode index 7be1353b3e..f8aaaa937e 100644 Binary files a/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode and b/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode differ diff --git a/zk-token-sdk/src/encryption/discrete_log.rs b/zk-token-sdk/src/encryption/discrete_log.rs index 03beda9efd..bc50576998 100644 --- a/zk-token-sdk/src/encryption/discrete_log.rs +++ b/zk-token-sdk/src/encryption/discrete_log.rs @@ -1,16 +1,36 @@ +//! The discrete log implementation for the twisted ElGamal decryption. +//! +//! The implementation uses the baby-step giant-step method, which consists of a precomputation +//! step and an online step. The precomputation step involves computing a hash table of a number +//! of Ristretto points that is independent of a discrete log instance. The online phase computes +//! the final discrete log solution using the discrete log instance and the pre-computed hash +//! table. More details on the baby-step giant-step algorithm and the implementation can be found +//! in the [spl documentation](https://spl.solana.com). +//! +//! The implementation is NOT intended to run in constant-time. There are some measures to prevent +//! straightforward timing attacks. For instance, it does not short-circuit the search when a +//! solution is found. However, the use of hashtables, batching, and threads make the +//! implementation inherently not constant-time. This may theoretically allow an adversary to gain +//! information on a discrete log solution depending on the execution time of the implementation. +//! + #![cfg(not(target_os = "solana"))] use { crate::errors::ProofError, curve25519_dalek::{ - constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar, - traits::Identity, + constants::RISTRETTO_BASEPOINT_POINT as G, + ristretto::RistrettoPoint, + scalar::Scalar, + traits::{Identity, IsIdentity}, }, + itertools::Itertools, serde::{Deserialize, Serialize}, std::{collections::HashMap, thread}, }; const TWO16: u64 = 65536; // 2^16 +const TWO17: u64 = 131072; // 2^17 /// Type that captures a discrete log challenge. /// @@ -28,6 +48,8 @@ pub struct DiscreteLog { range_bound: usize, /// Ristretto point representing each step of the discrete log search step_point: RistrettoPoint, + /// Ristretto point compression batch size + compression_batch_size: usize, } #[derive(Serialize, Deserialize, Default)] @@ -38,11 +60,11 @@ pub struct DecodePrecomputation(HashMap<[u8; 32], u16>); fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation { let mut hashmap = HashMap::new(); - let two16_scalar = Scalar::from(TWO16); + let two17_scalar = Scalar::from(TWO17); let identity = RistrettoPoint::identity(); // 0 * G - let generator = two16_scalar * generator; // 2^16 * G + let generator = two17_scalar * generator; // 2^17 * G - // iterator for 2^12*0G , 2^12*1G, 2^12*2G, ... + // iterator for 2^17*0G , 2^17*1G, 2^17*2G, ... let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1)); for (point, x_hi) in ristretto_iter.take(TWO16 as usize) { let key = point.compress().to_bytes(); @@ -73,13 +95,14 @@ impl DiscreteLog { num_threads: 1, range_bound: TWO16 as usize, step_point: G, + compression_batch_size: 32, } } /// Adjusts number of threads in a discrete log instance. pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> { // number of threads must be a positive power-of-two integer - if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 { + if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 { return Err(ProofError::DiscreteLogThreads); } @@ -90,6 +113,19 @@ impl DiscreteLog { Ok(()) } + /// Adjusts inversion batch size in a discrete log instance. + pub fn set_compression_batch_size( + &mut self, + compression_batch_size: usize, + ) -> Result<(), ProofError> { + if compression_batch_size >= TWO16 as usize { + return Err(ProofError::DiscreteLogBatchSize); + } + self.compression_batch_size = compression_batch_size; + + Ok(()) + } + /// Solves the discrete log problem under the assumption that the solution /// is a 32-bit number. pub fn decode_u32(self) -> Option { @@ -102,8 +138,14 @@ impl DiscreteLog { (-(&self.step_point), self.num_threads as u64), ); - let handle = - thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound)); + let handle = thread::spawn(move || { + Self::decode_range( + ristretto_iterator, + self.range_bound, + self.compression_batch_size, + ) + // Self::decode_range(ristretto_iterator, self.range_bound) + }); starting_point -= G; handle @@ -120,16 +162,39 @@ impl DiscreteLog { solution } - fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option { + fn decode_range( + ristretto_iterator: RistrettoIterator, + range_bound: usize, + compression_batch_size: usize, + ) -> Option { let hashmap = &DECODE_PRECOMPUTATION_FOR_G; let mut decoded = None; - for (point, x_lo) in ristretto_iterator.take(range_bound) { - let key = point.compress().to_bytes(); - if hashmap.0.contains_key(&key) { - let x_hi = hashmap.0[&key]; - decoded = Some(x_lo + TWO16 * x_hi as u64); + + for batch in &ristretto_iterator + .take(range_bound) + .chunks(compression_batch_size) + { + let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch + .filter(|(point, index)| { + if point.is_identity() { + decoded = Some(*index); + return false; + } + true + }) + .unzip(); + + let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points); + + for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) { + let key = point.to_bytes(); + if hashmap.0.contains_key(&key) { + let x_hi = hashmap.0[&key]; + decoded = Some(x_lo + TWO16 * x_hi as u64); + } } } + decoded } } @@ -167,6 +232,7 @@ mod tests { #[allow(non_snake_case)] fn test_serialize_decode_u32_precomputation_for_G() { let decode_u32_precomputation_for_G = decode_u32_precomputation(G); + // let decode_u32_precomputation_for_G = decode_u32_precomputation(G); if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 { use std::{fs::File, io::Write, path::PathBuf}; @@ -183,7 +249,7 @@ mod tests { #[test] fn test_decode_correctness() { // general case - let amount: u64 = 55; + let amount: u64 = 4294967295; let instance = DiscreteLog::new(G, Scalar::from(amount) * G); diff --git a/zk-token-sdk/src/errors.rs b/zk-token-sdk/src/errors.rs index 61f958a2d6..bb25ee9115 100644 --- a/zk-token-sdk/src/errors.rs +++ b/zk-token-sdk/src/errors.rs @@ -30,6 +30,8 @@ pub enum ProofError { Decryption, #[error("discrete log number of threads not power-of-two")] DiscreteLogThreads, + #[error("discrete log batch size too large")] + DiscreteLogBatchSize, } #[derive(Error, Clone, Debug, Eq, PartialEq)]