Zk token sdk/batch discrete log (#27412)

* zk-token-sdk: optimize discrete log search with batch compression

* zk-token-sdk: include batch size as part of discrete log struct

* zk-token-sdk: add a note on discrete log timings

* zk-token-sdk: add upper bound on the number of threads

* zk-token-sdk: minor

* zk-token-sdk: cargo.lock
This commit is contained in:
samkim-crypto 2022-08-27 06:54:59 +09:00 committed by GitHub
parent d0983c3cf7
commit bd88e2a11c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 15 deletions

1
Cargo.lock generated
View File

@ -6773,6 +6773,7 @@ dependencies = [
"cipher 0.4.3",
"curve25519-dalek",
"getrandom 0.1.16",
"itertools",
"lazy_static",
"merlin",
"num-derive",

View File

@ -5982,6 +5982,7 @@ dependencies = [
"cipher 0.4.3",
"curve25519-dalek",
"getrandom 0.1.14",
"itertools",
"lazy_static",
"merlin",
"num-derive",

View File

@ -22,6 +22,7 @@ byteorder = "1"
cipher = "0.4"
curve25519-dalek = { version = "3.2.1", features = ["serde"] }
getrandom = { version = "0.1", features = ["dummy"] }
itertools = "0.10.3"
lazy_static = "1.4.0"
merlin = "3"
rand = "0.7"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 MiB

After

Width:  |  Height:  |  Size: 2.1 MiB

View File

@ -1,16 +1,36 @@
//! The discrete log implementation for the twisted ElGamal decryption.
//!
//! The implementation uses the baby-step giant-step method, which consists of a precomputation
//! step and an online step. The precomputation step involves computing a hash table of a number
//! of Ristretto points that is independent of a discrete log instance. The online phase computes
//! the final discrete log solution using the discrete log instance and the pre-computed hash
//! table. More details on the baby-step giant-step algorithm and the implementation can be found
//! in the [spl documentation](https://spl.solana.com).
//!
//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
//! straightforward timing attacks. For instance, it does not short-circuit the search when a
//! solution is found. However, the use of hashtables, batching, and threads make the
//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
//! information on a discrete log solution depending on the execution time of the implementation.
//!
#![cfg(not(target_os = "solana"))]
use {
crate::errors::ProofError,
curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT as G, ristretto::RistrettoPoint, scalar::Scalar,
traits::Identity,
constants::RISTRETTO_BASEPOINT_POINT as G,
ristretto::RistrettoPoint,
scalar::Scalar,
traits::{Identity, IsIdentity},
},
itertools::Itertools,
serde::{Deserialize, Serialize},
std::{collections::HashMap, thread},
};
const TWO16: u64 = 65536; // 2^16
const TWO17: u64 = 131072; // 2^17
/// Type that captures a discrete log challenge.
///
@ -28,6 +48,8 @@ pub struct DiscreteLog {
range_bound: usize,
/// Ristretto point representing each step of the discrete log search
step_point: RistrettoPoint,
/// Ristretto point compression batch size
compression_batch_size: usize,
}
#[derive(Serialize, Deserialize, Default)]
@ -38,11 +60,11 @@ pub struct DecodePrecomputation(HashMap<[u8; 32], u16>);
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
let mut hashmap = HashMap::new();
let two16_scalar = Scalar::from(TWO16);
let two17_scalar = Scalar::from(TWO17);
let identity = RistrettoPoint::identity(); // 0 * G
let generator = two16_scalar * generator; // 2^16 * G
let generator = two17_scalar * generator; // 2^17 * G
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
// iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
let key = point.compress().to_bytes();
@ -73,13 +95,14 @@ impl DiscreteLog {
num_threads: 1,
range_bound: TWO16 as usize,
step_point: G,
compression_batch_size: 32,
}
}
/// Adjusts number of threads in a discrete log instance.
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), ProofError> {
// number of threads must be a positive power-of-two integer
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 {
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 {
return Err(ProofError::DiscreteLogThreads);
}
@ -90,6 +113,19 @@ impl DiscreteLog {
Ok(())
}
/// Adjusts inversion batch size in a discrete log instance.
pub fn set_compression_batch_size(
&mut self,
compression_batch_size: usize,
) -> Result<(), ProofError> {
if compression_batch_size >= TWO16 as usize {
return Err(ProofError::DiscreteLogBatchSize);
}
self.compression_batch_size = compression_batch_size;
Ok(())
}
/// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number.
pub fn decode_u32(self) -> Option<u64> {
@ -102,8 +138,14 @@ impl DiscreteLog {
(-(&self.step_point), self.num_threads as u64),
);
let handle =
thread::spawn(move || Self::decode_range(ristretto_iterator, self.range_bound));
let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
// Self::decode_range(ristretto_iterator, self.range_bound)
});
starting_point -= G;
handle
@ -120,16 +162,39 @@ impl DiscreteLog {
solution
}
fn decode_range(ristretto_iterator: RistrettoIterator, range_bound: usize) -> Option<u64> {
fn decode_range(
ristretto_iterator: RistrettoIterator,
range_bound: usize,
compression_batch_size: usize,
) -> Option<u64> {
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
let mut decoded = None;
for (point, x_lo) in ristretto_iterator.take(range_bound) {
let key = point.compress().to_bytes();
if hashmap.0.contains_key(&key) {
let x_hi = hashmap.0[&key];
decoded = Some(x_lo + TWO16 * x_hi as u64);
for batch in &ristretto_iterator
.take(range_bound)
.chunks(compression_batch_size)
{
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
.filter(|(point, index)| {
if point.is_identity() {
decoded = Some(*index);
return false;
}
true
})
.unzip();
let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
let key = point.to_bytes();
if hashmap.0.contains_key(&key) {
let x_hi = hashmap.0[&key];
decoded = Some(x_lo + TWO16 * x_hi as u64);
}
}
}
decoded
}
}
@ -167,6 +232,7 @@ mod tests {
#[allow(non_snake_case)]
fn test_serialize_decode_u32_precomputation_for_G() {
let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
// let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
use std::{fs::File, io::Write, path::PathBuf};
@ -183,7 +249,7 @@ mod tests {
#[test]
fn test_decode_correctness() {
// general case
let amount: u64 = 55;
let amount: u64 = 4294967295;
let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

View File

@ -30,6 +30,8 @@ pub enum ProofError {
Decryption,
#[error("discrete log number of threads not power-of-two")]
DiscreteLogThreads,
#[error("discrete log batch size too large")]
DiscreteLogBatchSize,
}
#[derive(Error, Clone, Debug, Eq, PartialEq)]