Zk token sdk/batch discrete log (#27412)
* zk-token-sdk: optimize discrete log search with batch compression * zk-token-sdk: include batch size as part of discrete log struct * zk-token-sdk: add a note on discrete log timings * zk-token-sdk: add upper bound on the number of threads * zk-token-sdk: minor * zk-token-sdk: cargo.lock
This commit is contained in:
parent
d0983c3cf7
commit
bd88e2a11c
|
@ -6773,6 +6773,7 @@ dependencies = [
|
||||||
"cipher 0.4.3",
|
"cipher 0.4.3",
|
||||||
"curve25519-dalek",
|
"curve25519-dalek",
|
||||||
"getrandom 0.1.16",
|
"getrandom 0.1.16",
|
||||||
|
"itertools",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"merlin",
|
"merlin",
|
||||||
"num-derive",
|
"num-derive",
|
||||||
|
|
|
@ -5982,6 +5982,7 @@ dependencies = [
|
||||||
"cipher 0.4.3",
|
"cipher 0.4.3",
|
||||||
"curve25519-dalek",
|
"curve25519-dalek",
|
||||||
"getrandom 0.1.14",
|
"getrandom 0.1.14",
|
||||||
|
"itertools",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"merlin",
|
"merlin",
|
||||||
"num-derive",
|
"num-derive",
|
||||||
|
|
|
@ -22,6 +22,7 @@ byteorder = "1"
|
||||||
cipher = "0.4"
|
cipher = "0.4"
|
||||||
curve25519-dalek = { version = "3.2.1", features = ["serde"] }
|
curve25519-dalek = { version = "3.2.1", features = ["serde"] }
|
||||||
getrandom = { version = "0.1", features = ["dummy"] }
|
getrandom = { version = "0.1", features = ["dummy"] }
|
||||||
|
itertools = "0.10.3"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
merlin = "3"
|
merlin = "3"
|
||||||
rand = "0.7"
|
rand = "0.7"
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 2.1 MiB After Width: | Height: | Size: 2.1 MiB |
|
@ -1,16 +1,36 @@
|
||||||
|
//! The discrete log implementation for the twisted ElGamal decryption.
|
||||||
|
//!
|
||||||
|
//! The implementation uses the baby-step giant-step method, which consists of a precomputation
|
||||||
|
//! step and an online step. The precomputation step involves computing a hash table of a number
|
||||||
|
//! of Ristretto points that is independent of a discrete log instance. The online phase computes
|
||||||
|
//! the final discrete log solution using the discrete log instance and the pre-computed hash
|
||||||
|
//! table. More details on the baby-step giant-step algorithm and the implementation can be found
|
||||||
|
//! in the [spl documentation](https://spl.solana.com).
|
||||||
|
//!
|
||||||
|
//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
|
||||||
|
//! straightforward timing attacks. For instance, it does not short-circuit the search when a
|
||||||
|
//! solution is found. However, the use of hashtables, batching, and threads make the
|
||||||
|
//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
|
||||||
|
//! information on a discrete log solution depending on the execution time of the implementation.
|
||||||
|
//!
|
||||||
|
|
||||||
#![cfg(not(target_os = "solana"))]
|
#![cfg(not(target_os = "solana"))]
|
||||||
|
|
||||||
use {
|
use {
|
||||||
crate::errors::ProofError,
|
crate::errors::ProofError,
|
||||||
curve25519_dalek::{
|
curve25519_dalek::{
|
||||||
constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar,
|
constants::RISTRETTO_BASEPOINT_POINT as G,
|
||||||
traits::Identity,
|
ristretto::RistrettoPoint,
|
||||||
|
scalar::Scalar,
|
||||||
|
traits::{Identity, IsIdentity},
|
||||||
},
|
},
|
||||||
|
itertools::Itertools,
|
||||||
serde::{Deserialize, Serialize},
|
serde::{Deserialize, Serialize},
|
||||||
std::{collections::HashMap, thread},
|
std::{collections::HashMap, thread},
|
||||||
};
|
};
|
||||||
|
|
||||||
const TWO16: u64 = 65536; // 2^16
|
const TWO16: u64 = 65536; // 2^16
|
||||||
|
const TWO17: u64 = 131072; // 2^17
|
||||||
|
|
||||||
/// Type that captures a discrete log challenge.
|
/// Type that captures a discrete log challenge.
|
||||||
///
|
///
|
||||||
|
@ -28,6 +48,8 @@ pub struct DiscreteLog {
|
||||||
range_bound: usize,
|
range_bound: usize,
|
||||||
/// Ristretto point representing each step of the discrete log search
|
/// Ristretto point representing each step of the discrete log search
|
||||||
step_point: RistrettoPoint,
|
step_point: RistrettoPoint,
|
||||||
|
/// Ristretto point compression batch size
|
||||||
|
compression_batch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Default)]
|
#[derive(Serialize, Deserialize, Default)]
|
||||||
|
@ -38,11 +60,11 @@ pub struct DecodePrecomputation(HashMap<[u8; 32], u16>);
|
||||||
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
|
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
|
||||||
let mut hashmap = HashMap::new();
|
let mut hashmap = HashMap::new();
|
||||||
|
|
||||||
let two16_scalar = Scalar::from(TWO16);
|
let two17_scalar = Scalar::from(TWO17);
|
||||||
let identity = RistrettoPoint::identity(); // 0 * G
|
let identity = RistrettoPoint::identity(); // 0 * G
|
||||||
let generator = two16_scalar * generator; // 2^16 * G
|
let generator = two17_scalar * generator; // 2^17 * G
|
||||||
|
|
||||||
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
|
// iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
|
||||||
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
|
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
|
||||||
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
|
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
|
||||||
let key = point.compress().to_bytes();
|
let key = point.compress().to_bytes();
|
||||||
|
@ -73,13 +95,14 @@ impl DiscreteLog {
|
||||||
num_threads: 1,
|
num_threads: 1,
|
||||||
range_bound: TWO16 as usize,
|
range_bound: TWO16 as usize,
|
||||||
step_point: G,
|
step_point: G,
|
||||||
|
compression_batch_size: 32,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Adjusts number of threads in a discrete log instance.
|
/// Adjusts number of threads in a discrete log instance.
|
||||||
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> {
|
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> {
|
||||||
// number of threads must be a positive power-of-two integer
|
// number of threads must be a positive power-of-two integer
|
||||||
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 {
|
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 {
|
||||||
return Err(ProofError::DiscreteLogThreads);
|
return Err(ProofError::DiscreteLogThreads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,6 +113,19 @@ impl DiscreteLog {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Adjusts inversion batch size in a discrete log instance.
|
||||||
|
pub fn set_compression_batch_size(
|
||||||
|
&mut self,
|
||||||
|
compression_batch_size: usize,
|
||||||
|
) -> Result<(), ProofError> {
|
||||||
|
if compression_batch_size >= TWO16 as usize {
|
||||||
|
return Err(ProofError::DiscreteLogBatchSize);
|
||||||
|
}
|
||||||
|
self.compression_batch_size = compression_batch_size;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Solves the discrete log problem under the assumption that the solution
|
/// Solves the discrete log problem under the assumption that the solution
|
||||||
/// is a 32-bit number.
|
/// is a 32-bit number.
|
||||||
pub fn decode_u32(self) -> Option<u64> {
|
pub fn decode_u32(self) -> Option<u64> {
|
||||||
|
@ -102,8 +138,14 @@ impl DiscreteLog {
|
||||||
(-(&self.step_point), self.num_threads as u64),
|
(-(&self.step_point), self.num_threads as u64),
|
||||||
);
|
);
|
||||||
|
|
||||||
let handle =
|
let handle = thread::spawn(move || {
|
||||||
thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound));
|
Self::decode_range(
|
||||||
|
ristretto_iterator,
|
||||||
|
self.range_bound,
|
||||||
|
self.compression_batch_size,
|
||||||
|
)
|
||||||
|
// Self::decode_range(ristretto_iterator, self.range_bound)
|
||||||
|
});
|
||||||
|
|
||||||
starting_point -= G;
|
starting_point -= G;
|
||||||
handle
|
handle
|
||||||
|
@ -120,16 +162,39 @@ impl DiscreteLog {
|
||||||
solution
|
solution
|
||||||
}
|
}
|
||||||
|
|
||||||
fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option<u64> {
|
fn decode_range(
|
||||||
|
ristretto_iterator: RistrettoIterator,
|
||||||
|
range_bound: usize,
|
||||||
|
compression_batch_size: usize,
|
||||||
|
) -> Option<u64> {
|
||||||
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
|
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
|
||||||
let mut decoded = None;
|
let mut decoded = None;
|
||||||
for (point, x_lo) in ristretto_iterator.take(range_bound) {
|
|
||||||
let key = point.compress().to_bytes();
|
for batch in &ristretto_iterator
|
||||||
|
.take(range_bound)
|
||||||
|
.chunks(compression_batch_size)
|
||||||
|
{
|
||||||
|
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
|
||||||
|
.filter(|(point, index)| {
|
||||||
|
if point.is_identity() {
|
||||||
|
decoded = Some(*index);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
true
|
||||||
|
})
|
||||||
|
.unzip();
|
||||||
|
|
||||||
|
let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
|
||||||
|
|
||||||
|
for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
|
||||||
|
let key = point.to_bytes();
|
||||||
if hashmap.0.contains_key(&key) {
|
if hashmap.0.contains_key(&key) {
|
||||||
let x_hi = hashmap.0[&key];
|
let x_hi = hashmap.0[&key];
|
||||||
decoded = Some(x_lo + TWO16 * x_hi as u64);
|
decoded = Some(x_lo + TWO16 * x_hi as u64);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
decoded
|
decoded
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -167,6 +232,7 @@ mod tests {
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
fn test_serialize_decode_u32_precomputation_for_G() {
|
fn test_serialize_decode_u32_precomputation_for_G() {
|
||||||
let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
|
let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
|
||||||
|
// let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
|
||||||
|
|
||||||
if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
|
if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
|
||||||
use std::{fs::File, io::Write, path::PathBuf};
|
use std::{fs::File, io::Write, path::PathBuf};
|
||||||
|
@ -183,7 +249,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_decode_correctness() {
|
fn test_decode_correctness() {
|
||||||
// general case
|
// general case
|
||||||
let amount: u64 = 55;
|
let amount: u64 = 4294967295;
|
||||||
|
|
||||||
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,8 @@ pub enum ProofError {
|
||||||
Decryption,
|
Decryption,
|
||||||
#[error("discrete log number of threads not power-of-two")]
|
#[error("discrete log number of threads not power-of-two")]
|
||||||
DiscreteLogThreads,
|
DiscreteLogThreads,
|
||||||
|
#[error("discrete log batch size too large")]
|
||||||
|
DiscreteLogBatchSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
||||||
|
|
Loading…
Reference in New Issue