Zk token sdk/batch discrete log (#27412)

* zk-token-sdk: optimize discrete log search with batch compression

* zk-token-sdk: include batch size as part of discrete log struct

* zk-token-sdk: add a note on discrete log timings

* zk-token-sdk: add upper bound on the number of threads

* zk-token-sdk: minor

* zk-token-sdk: cargo.lock
This commit is contained in:
samkim-crypto 2022-08-27 06:54:59 +09:00 committed by GitHub
parent d0983c3cf7
commit bd88e2a11c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 15 deletions

1
Cargo.lock generated
View File

@ -6773,6 +6773,7 @@ dependencies = [
"cipher 0.4.3", "cipher 0.4.3",
"curve25519-dalek", "curve25519-dalek",
"getrandom 0.1.16", "getrandom 0.1.16",
"itertools",
"lazy_static", "lazy_static",
"merlin", "merlin",
"num-derive", "num-derive",

View File

@ -5982,6 +5982,7 @@ dependencies = [
"cipher 0.4.3", "cipher 0.4.3",
"curve25519-dalek", "curve25519-dalek",
"getrandom 0.1.14", "getrandom 0.1.14",
"itertools",
"lazy_static", "lazy_static",
"merlin", "merlin",
"num-derive", "num-derive",

View File

@ -22,6 +22,7 @@ byteorder = "1"
cipher = "0.4" cipher = "0.4"
curve25519-dalek = { version = "3.2.1", features = ["serde"] } curve25519-dalek = { version = "3.2.1", features = ["serde"] }
getrandom = { version = "0.1", features = ["dummy"] } getrandom = { version = "0.1", features = ["dummy"] }
itertools = "0.10.3"
lazy_static = "1.4.0" lazy_static = "1.4.0"
merlin = "3" merlin = "3"
rand = "0.7" rand = "0.7"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 MiB

After

Width:  |  Height:  |  Size: 2.1 MiB

View File

@ -1,16 +1,36 @@
//! The discrete log implementation for the twisted ElGamal decryption.
//!
//! The implementation uses the baby-step giant-step method, which consists of a precomputation
//! step and an online step. The precomputation step involves computing a hash table of a number
//! of Ristretto points that is independent of a discrete log instance. The online phase computes
//! the final discrete log solution using the discrete log instance and the pre-computed hash
//! table. More details on the baby-step giant-step algorithm and the implementation can be found
//! in the [spl documentation](https://spl.solana.com).
//!
//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
//! straightforward timing attacks. For instance, it does not short-circuit the search when a
//! solution is found. However, the use of hashtables, batching, and threads make the
//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
//! information on a discrete log solution depending on the execution time of the implementation.
//!
#![cfg(not(target_os = "solana"))] #![cfg(not(target_os = "solana"))]
use { use {
crate::errors::ProofError, crate::errors::ProofError,
curve25519_dalek::{ curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar, constants::RISTRETTO_BASEPOINT_POINT as G,
traits::Identity, ristretto::RistrettoPoint,
scalar::Scalar,
traits::{Identity, IsIdentity},
}, },
itertools::Itertools,
serde::{Deserialize, Serialize}, serde::{Deserialize, Serialize},
std::{collections::HashMap, thread}, std::{collections::HashMap, thread},
}; };
const TWO16: u64 = 65536; // 2^16 const TWO16: u64 = 65536; // 2^16
const TWO17: u64 = 131072; // 2^17
/// Type that captures a discrete log challenge. /// Type that captures a discrete log challenge.
/// ///
@ -28,6 +48,8 @@ pub struct DiscreteLog {
range_bound: usize, range_bound: usize,
/// Ristretto point representing each step of the discrete log search /// Ristretto point representing each step of the discrete log search
step_point: RistrettoPoint, step_point: RistrettoPoint,
/// Ristretto point compression batch size
compression_batch_size: usize,
} }
#[derive(Serialize, Deserialize, Default)] #[derive(Serialize, Deserialize, Default)]
@ -38,11 +60,11 @@ pub struct DecodePrecomputation(HashMap<[u8; 32], u16>);
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation { fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
let mut hashmap = HashMap::new(); let mut hashmap = HashMap::new();
let two16_scalar = Scalar::from(TWO16); let two17_scalar = Scalar::from(TWO17);
let identity = RistrettoPoint::identity(); // 0 * G let identity = RistrettoPoint::identity(); // 0 * G
let generator = two16_scalar * generator; // 2^16 * G let generator = two17_scalar * generator; // 2^17 * G
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ... // iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1)); let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) { for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
let key = point.compress().to_bytes(); let key = point.compress().to_bytes();
@ -73,13 +95,14 @@ impl DiscreteLog {
num_threads: 1, num_threads: 1,
range_bound: TWO16 as usize, range_bound: TWO16 as usize,
step_point: G, step_point: G,
compression_batch_size: 32,
} }
} }
/// Adjusts number of threads in a discrete log instance. /// Adjusts number of threads in a discrete log instance.
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> { pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> {
// number of threads must be a positive power-of-two integer // number of threads must be a positive power-of-two integer
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 { if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 {
return Err(ProofError::DiscreteLogThreads); return Err(ProofError::DiscreteLogThreads);
} }
@ -90,6 +113,19 @@ impl DiscreteLog {
Ok(()) Ok(())
} }
/// Adjusts inversion batch size in a discrete log instance.
pub fn set_compression_batch_size(
&mut self,
compression_batch_size: usize,
) -> Result<(), ProofError> {
if compression_batch_size >= TWO16 as usize {
return Err(ProofError::DiscreteLogBatchSize);
}
self.compression_batch_size = compression_batch_size;
Ok(())
}
/// Solves the discrete log problem under the assumption that the solution /// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number. /// is a 32-bit number.
pub fn decode_u32(self) -> Option<u64> { pub fn decode_u32(self) -> Option<u64> {
@ -102,8 +138,14 @@ impl DiscreteLog {
(-(&self.step_point), self.num_threads as u64), (-(&self.step_point), self.num_threads as u64),
); );
let handle = let handle = thread::spawn(move || {
thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound)); Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
// Self::decode_range(ristretto_iterator, self.range_bound)
});
starting_point -= G; starting_point -= G;
handle handle
@ -120,16 +162,39 @@ impl DiscreteLog {
solution solution
} }
fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option<u64> { fn decode_range(
ristretto_iterator: RistrettoIterator,
range_bound: usize,
compression_batch_size: usize,
) -> Option<u64> {
let hashmap = &DECODE_PRECOMPUTATION_FOR_G; let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
let mut decoded = None; let mut decoded = None;
for (point, x_lo) in ristretto_iterator.take(range_bound) {
let key = point.compress().to_bytes(); for batch in &ristretto_iterator
if hashmap.0.contains_key(&key) { .take(range_bound)
let x_hi = hashmap.0[&key]; .chunks(compression_batch_size)
decoded = Some(x_lo + TWO16 * x_hi as u64); {
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
.filter(|(point, index)| {
if point.is_identity() {
decoded = Some(*index);
return false;
}
true
})
.unzip();
let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
let key = point.to_bytes();
if hashmap.0.contains_key(&key) {
let x_hi = hashmap.0[&key];
decoded = Some(x_lo + TWO16 * x_hi as u64);
}
} }
} }
decoded decoded
} }
} }
@ -167,6 +232,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_serialize_decode_u32_precomputation_for_G() { fn test_serialize_decode_u32_precomputation_for_G() {
let decode_u32_precomputation_for_G = decode_u32_precomputation(G); let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
// let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 { if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
use std::{fs::File, io::Write, path::PathBuf}; use std::{fs::File, io::Write, path::PathBuf};
@ -183,7 +249,7 @@ mod tests {
#[test] #[test]
fn test_decode_correctness() { fn test_decode_correctness() {
// general case // general case
let amount: u64 = 55; let amount: u64 = 4294967295;
let instance = DiscreteLog::new(G, Scalar::from(amount) * G); let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

View File

@ -30,6 +30,8 @@ pub enum ProofError {
Decryption, Decryption,
#[error("discrete log number of threads not power-of-two")] #[error("discrete log number of threads not power-of-two")]
DiscreteLogThreads, DiscreteLogThreads,
#[error("discrete log batch size too large")]
DiscreteLogBatchSize,
} }
#[derive(Error, Clone, Debug, Eq, PartialEq)] #[derive(Error, Clone, Debug, Eq, PartialEq)]