Zk token sdk/batch discrete log (#27412)
* zk-token-sdk: optimize discrete log search with batch compression * zk-token-sdk: include batch size as part of discrete log struct * zk-token-sdk: add a note on discrete log timings * zk-token-sdk: add upper bound on the number of threads * zk-token-sdk: minor * zk-token-sdk: cargo.lock
This commit is contained in:
parent
d0983c3cf7
commit
bd88e2a11c
|
@ -6773,6 +6773,7 @@ dependencies = [
|
|||
"cipher 0.4.3",
|
||||
"curve25519-dalek",
|
||||
"getrandom 0.1.16",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
"merlin",
|
||||
"num-derive",
|
||||
|
|
|
@ -5982,6 +5982,7 @@ dependencies = [
|
|||
"cipher 0.4.3",
|
||||
"curve25519-dalek",
|
||||
"getrandom 0.1.14",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
"merlin",
|
||||
"num-derive",
|
||||
|
|
|
@ -22,6 +22,7 @@ byteorder = "1"
|
|||
cipher = "0.4"
|
||||
curve25519-dalek = { version = "3.2.1", features = ["serde"] }
|
||||
getrandom = { version = "0.1", features = ["dummy"] }
|
||||
itertools = "0.10.3"
|
||||
lazy_static = "1.4.0"
|
||||
merlin = "3"
|
||||
rand = "0.7"
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 2.1 MiB After Width: | Height: | Size: 2.1 MiB |
|
@ -1,16 +1,36 @@
|
|||
//! The discrete log implementation for the twisted ElGamal decryption.
|
||||
//!
|
||||
//! The implementation uses the baby-step giant-step method, which consists of a precomputation
|
||||
//! step and an online step. The precomputation step involves computing a hash table of a number
|
||||
//! of Ristretto points that is independent of a discrete log instance. The online phase computes
|
||||
//! the final discrete log solution using the discrete log instance and the pre-computed hash
|
||||
//! table. More details on the baby-step giant-step algorithm and the implementation can be found
|
||||
//! in the [spl documentation](https://spl.solana.com).
|
||||
//!
|
||||
//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
|
||||
//! straightforward timing attacks. For instance, it does not short-circuit the search when a
|
||||
//! solution is found. However, the use of hashtables, batching, and threads make the
|
||||
//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
|
||||
//! information on a discrete log solution depending on the execution time of the implementation.
|
||||
//!
|
||||
|
||||
#![cfg(not(target_os = "solana"))]
|
||||
|
||||
use {
|
||||
crate::errors::ProofError,
|
||||
curve25519_dalek::{
|
||||
constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar,
|
||||
traits::Identity,
|
||||
constants::RISTRETTO_BASEPOINT_POINT as G,
|
||||
ristretto::RistrettoPoint,
|
||||
scalar::Scalar,
|
||||
traits::{Identity, IsIdentity},
|
||||
},
|
||||
itertools::Itertools,
|
||||
serde::{Deserialize, Serialize},
|
||||
std::{collections::HashMap, thread},
|
||||
};
|
||||
|
||||
const TWO16: u64 = 65536; // 2^16
|
||||
const TWO17: u64 = 131072; // 2^17
|
||||
|
||||
/// Type that captures a discrete log challenge.
|
||||
///
|
||||
|
@ -28,6 +48,8 @@ pub struct DiscreteLog {
|
|||
range_bound: usize,
|
||||
/// Ristretto point representing each step of the discrete log search
|
||||
step_point: RistrettoPoint,
|
||||
/// Ristretto point compression batch size
|
||||
compression_batch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
|
@ -38,11 +60,11 @@ pub struct DecodePrecomputation(HashMap<[u8; 32], u16>);
|
|||
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
|
||||
let mut hashmap = HashMap::new();
|
||||
|
||||
let two16_scalar = Scalar::from(TWO16);
|
||||
let two17_scalar = Scalar::from(TWO17);
|
||||
let identity = RistrettoPoint::identity(); // 0 * G
|
||||
let generator = two16_scalar * generator; // 2^16 * G
|
||||
let generator = two17_scalar * generator; // 2^17 * G
|
||||
|
||||
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
|
||||
// iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
|
||||
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
|
||||
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
|
||||
let key = point.compress().to_bytes();
|
||||
|
@ -73,13 +95,14 @@ impl DiscreteLog {
|
|||
num_threads: 1,
|
||||
range_bound: TWO16 as usize,
|
||||
step_point: G,
|
||||
compression_batch_size: 32,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjusts number of threads in a discrete log instance.
|
||||
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> {
|
||||
// number of threads must be a positive power-of-two integer
|
||||
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 {
|
||||
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 {
|
||||
return Err(ProofError::DiscreteLogThreads);
|
||||
}
|
||||
|
||||
|
@ -90,6 +113,19 @@ impl DiscreteLog {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Adjusts inversion batch size in a discrete log instance.
|
||||
pub fn set_compression_batch_size(
|
||||
&mut self,
|
||||
compression_batch_size: usize,
|
||||
) -> Result<(), ProofError> {
|
||||
if compression_batch_size >= TWO16 as usize {
|
||||
return Err(ProofError::DiscreteLogBatchSize);
|
||||
}
|
||||
self.compression_batch_size = compression_batch_size;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Solves the discrete log problem under the assumption that the solution
|
||||
/// is a 32-bit number.
|
||||
pub fn decode_u32(self) -> Option<u64> {
|
||||
|
@ -102,8 +138,14 @@ impl DiscreteLog {
|
|||
(-(&self.step_point), self.num_threads as u64),
|
||||
);
|
||||
|
||||
let handle =
|
||||
thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound));
|
||||
let handle = thread::spawn(move || {
|
||||
Self::decode_range(
|
||||
ristretto_iterator,
|
||||
self.range_bound,
|
||||
self.compression_batch_size,
|
||||
)
|
||||
// Self::decode_range(ristretto_iterator, self.range_bound)
|
||||
});
|
||||
|
||||
starting_point -= G;
|
||||
handle
|
||||
|
@ -120,16 +162,39 @@ impl DiscreteLog {
|
|||
solution
|
||||
}
|
||||
|
||||
fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option<u64> {
|
||||
fn decode_range(
|
||||
ristretto_iterator: RistrettoIterator,
|
||||
range_bound: usize,
|
||||
compression_batch_size: usize,
|
||||
) -> Option<u64> {
|
||||
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
|
||||
let mut decoded = None;
|
||||
for (point, x_lo) in ristretto_iterator.take(range_bound) {
|
||||
let key = point.compress().to_bytes();
|
||||
if hashmap.0.contains_key(&key) {
|
||||
let x_hi = hashmap.0[&key];
|
||||
decoded = Some(x_lo + TWO16 * x_hi as u64);
|
||||
|
||||
for batch in &ristretto_iterator
|
||||
.take(range_bound)
|
||||
.chunks(compression_batch_size)
|
||||
{
|
||||
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
|
||||
.filter(|(point, index)| {
|
||||
if point.is_identity() {
|
||||
decoded = Some(*index);
|
||||
return false;
|
||||
}
|
||||
true
|
||||
})
|
||||
.unzip();
|
||||
|
||||
let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
|
||||
|
||||
for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
|
||||
let key = point.to_bytes();
|
||||
if hashmap.0.contains_key(&key) {
|
||||
let x_hi = hashmap.0[&key];
|
||||
decoded = Some(x_lo + TWO16 * x_hi as u64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
decoded
|
||||
}
|
||||
}
|
||||
|
@ -167,6 +232,7 @@ mod tests {
|
|||
#[allow(non_snake_case)]
|
||||
fn test_serialize_decode_u32_precomputation_for_G() {
|
||||
let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
|
||||
// let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
|
||||
|
||||
if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
|
||||
use std::{fs::File, io::Write, path::PathBuf};
|
||||
|
@ -183,7 +249,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_decode_correctness() {
|
||||
// general case
|
||||
let amount: u64 = 55;
|
||||
let amount: u64 = 4294967295;
|
||||
|
||||
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
|
||||
|
||||
|
|
|
@ -30,6 +30,8 @@ pub enum ProofError {
|
|||
Decryption,
|
||||
#[error("discrete log number of threads not power-of-two")]
|
||||
DiscreteLogThreads,
|
||||
#[error("discrete log batch size too large")]
|
||||
DiscreteLogBatchSize,
|
||||
}
|
||||
|
||||
#[derive(Error, Clone, Debug, Eq, PartialEq)]
|
||||
|
|
Loading…
Reference in New Issue