323 lines
11 KiB
Rust
323 lines
11 KiB
Rust
//! The discrete log implementation for the twisted ElGamal decryption.
|
|
//!
|
|
//! The implementation uses the baby-step giant-step method, which consists of a precomputation
|
|
//! step and an online step. The precomputation step involves computing a hash table of a number
|
|
//! of Ristretto points that is independent of a discrete log instance. The online phase computes
|
|
//! the final discrete log solution using the discrete log instance and the pre-computed hash
|
|
//! table. More details on the baby-step giant-step algorithm and the implementation can be found
|
|
//! in the [spl documentation](https://spl.solana.com).
|
|
//!
|
|
//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
|
|
//! straightforward timing attacks. For instance, it does not short-circuit the search when a
|
|
//! solution is found. However, the use of hash tables, batching, and threads makes the
|
|
//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
|
|
//! information on a discrete log solution depending on the execution time of the implementation.
|
|
//!
|
|
|
|
#![cfg(not(target_os = "solana"))]
|
|
|
|
use {
|
|
crate::encryption::errors::DiscreteLogError,
|
|
curve25519_dalek::{
|
|
constants::RISTRETTO_BASEPOINT_POINT as G,
|
|
ristretto::RistrettoPoint,
|
|
scalar::Scalar,
|
|
traits::{Identity, IsIdentity},
|
|
},
|
|
itertools::Itertools,
|
|
serde::{Deserialize, Serialize},
|
|
std::{collections::HashMap, thread},
|
|
};
|
|
|
|
// Number of giant-step table entries and of baby steps: the solution is split 16/16 bits.
const TWO16: u64 = 1 << 16;
// Giant-step stride used to build the table. It is 2^17 rather than 2^16 because the online
// search compresses *doubled* points.
const TWO17: u64 = 1 << 17;
|
|
|
|
/// Type that captures a discrete log challenge.
///
/// The goal of discrete log is to find x such that x * generator = target.
///
/// NOTE: `DiscreteLog::decode_u32` steps by the base point `G` and looks points up in
/// `DECODE_PRECOMPUTATION_FOR_G`, so it only finds x when `generator` is `G`; the
/// `generator` field itself is not read by the search.
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
pub struct DiscreteLog {
    /// Generator point for discrete log
    pub generator: RistrettoPoint,
    /// Target point for discrete log
    pub target: RistrettoPoint,
    /// Number of threads used for discrete log computation. It is 1 by default and otherwise
    /// a power of two no larger than 2^16, as enforced by `DiscreteLog::num_threads`.
    num_threads: usize,
    /// Number of baby steps each thread searches: `2^16 / num_threads`, derived from the
    /// 16-bit baby-step range and `num_threads`.
    range_bound: usize,
    /// Stride between consecutive baby steps within a single thread: `num_threads * G`
    step_point: RistrettoPoint,
    /// Number of points compressed together per batch during the search
    compression_batch_size: usize,
}
|
|
|
|
/// Giant-step lookup table for the discrete log search.
///
/// Maps the 32-byte compressed encoding of `x_hi * 2^17 * generator` to `x_hi`, for every
/// 16-bit `x_hi`; `decode_u32_precomputation` builds it. `Default` yields an empty table.
#[derive(Serialize, Deserialize, Default)]
pub struct DecodePrecomputation(HashMap<[u8; 32], u16>);
|
|
|
|
/// Builds a HashMap of 2^16 elements
|
|
#[allow(dead_code)]
|
|
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
|
|
let mut hashmap = HashMap::new();
|
|
|
|
let two17_scalar = Scalar::from(TWO17);
|
|
let identity = RistrettoPoint::identity(); // 0 * G
|
|
let generator = two17_scalar * generator; // 2^17 * G
|
|
|
|
// iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
|
|
let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
|
|
for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
|
|
let key = point.compress().to_bytes();
|
|
hashmap.insert(key, x_hi as u16);
|
|
}
|
|
|
|
DecodePrecomputation(hashmap)
|
|
}
|
|
|
|
lazy_static::lazy_static! {
    /// Pre-computed HashMap needed for decryption. The HashMap is independent of (works for) any key.
    ///
    /// This is the giant-step table for the base point `G`. It is embedded at compile time
    /// from `decode_u32_precomputation_for_G.bincode` and deserialized on first access.
    pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = {
        static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
            include_bytes!("decode_u32_precomputation_for_G.bincode");
        // If the embedded bytes fail to deserialize, fall back to an empty table rather than
        // panicking. With an empty table, only solutions below 2^16 are found (via the
        // identity-point check in the search).
        // NOTE(review): presumably the fallback exists so the regeneration test can run
        // against a stale or empty asset — confirm.
        bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
    };
}
|
|
|
|
/// Solves the discrete log instance using a 16/16 bit offline/online split
|
|
impl DiscreteLog {
|
|
/// Discrete log instance constructor.
|
|
///
|
|
/// Default number of threads set to 1.
|
|
pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
|
|
Self {
|
|
generator,
|
|
target,
|
|
num_threads: 1,
|
|
range_bound: TWO16 as usize,
|
|
step_point: G,
|
|
compression_batch_size: 32,
|
|
}
|
|
}
|
|
|
|
/// Adjusts number of threads in a discrete log instance.
|
|
pub fn num_threads(&mut self, num_threads: usize) -> Result<(), DiscreteLogError> {
|
|
// number of threads must be a positive power-of-two integer
|
|
if num_threads == 0 || (num_threads & (num_threads - 1)) != 0 || num_threads > 65536 {
|
|
return Err(DiscreteLogError::DiscreteLogThreads);
|
|
}
|
|
|
|
self.num_threads = num_threads;
|
|
self.range_bound = (TWO16 as usize).checked_div(num_threads).unwrap();
|
|
self.step_point = Scalar::from(num_threads as u64) * G;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Adjusts inversion batch size in a discrete log instance.
|
|
pub fn set_compression_batch_size(
|
|
&mut self,
|
|
compression_batch_size: usize,
|
|
) -> Result<(), DiscreteLogError> {
|
|
if compression_batch_size >= TWO16 as usize {
|
|
return Err(DiscreteLogError::DiscreteLogBatchSize);
|
|
}
|
|
self.compression_batch_size = compression_batch_size;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Solves the discrete log problem under the assumption that the solution
|
|
/// is a positive 32-bit number.
|
|
pub fn decode_u32(self) -> Option<u64> {
|
|
let mut starting_point = self.target;
|
|
let handles = (0..self.num_threads)
|
|
.map(|i| {
|
|
let ristretto_iterator = RistrettoIterator::new(
|
|
(starting_point, i as u64),
|
|
(-(&self.step_point), self.num_threads as u64),
|
|
);
|
|
|
|
let handle = thread::spawn(move || {
|
|
Self::decode_range(
|
|
ristretto_iterator,
|
|
self.range_bound,
|
|
self.compression_batch_size,
|
|
)
|
|
});
|
|
|
|
starting_point -= G;
|
|
handle
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
let mut solution = None;
|
|
for handle in handles {
|
|
let discrete_log = handle.join().unwrap();
|
|
if discrete_log.is_some() {
|
|
solution = discrete_log;
|
|
}
|
|
}
|
|
solution
|
|
}
|
|
|
|
fn decode_range(
|
|
ristretto_iterator: RistrettoIterator,
|
|
range_bound: usize,
|
|
compression_batch_size: usize,
|
|
) -> Option<u64> {
|
|
let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
|
|
let mut decoded = None;
|
|
|
|
for batch in &ristretto_iterator
|
|
.take(range_bound)
|
|
.chunks(compression_batch_size)
|
|
{
|
|
// batch compression currently errors if any point in the batch is the identity point
|
|
let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
|
|
.filter(|(point, index)| {
|
|
if point.is_identity() {
|
|
decoded = Some(*index);
|
|
return false;
|
|
}
|
|
true
|
|
})
|
|
.unzip();
|
|
|
|
let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
|
|
|
|
for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
|
|
let key = point.to_bytes();
|
|
if hashmap.0.contains_key(&key) {
|
|
let x_hi = hashmap.0[&key];
|
|
decoded = Some(x_lo + TWO16 * x_hi as u64);
|
|
}
|
|
}
|
|
}
|
|
|
|
decoded
|
|
}
|
|
}
|
|
|
|
/// Iterator over an arithmetic progression of Ristretto points, each paired with an index.
///
/// Given an initial point X and a stepping point P, the iterator iterates through
/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
/// The `u64` index advances by the step's index component in the same way. The iterator
/// never ends on its own, so callers bound it (e.g. with `take`).
struct RistrettoIterator {
    /// Next (point, index) pair to yield
    pub current: (RistrettoPoint, u64),
    /// Amount added to `current` after each yield: (point increment, index increment)
    pub step: (RistrettoPoint, u64),
}
|
|
|
|
impl RistrettoIterator {
|
|
fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
|
|
RistrettoIterator { current, step }
|
|
}
|
|
}
|
|
|
|
impl Iterator for RistrettoIterator {
|
|
type Item = (RistrettoPoint, u64);
|
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
let r = self.current;
|
|
self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
|
|
Some(r)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use {super::*, std::time::Instant};

    /// Decodes `amount * G` and asserts that the round trip works.
    ///
    /// Builds an instance for `amount * G`, configures `num_threads`, and decodes it. Returns
    /// the (very informal) wall-clock decode time in seconds.
    fn decode_and_check(amount: u64, num_threads: usize) -> f64 {
        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
        instance.num_threads(num_threads).unwrap();

        let start_computation = Instant::now();
        let decoded = instance.decode_u32();
        let computation_secs = start_computation.elapsed().as_secs_f64();

        assert_eq!(
            Some(amount),
            decoded,
            "failed to decode {amount} with {num_threads} thread(s)"
        );
        computation_secs
    }

    #[test]
    // The name mirrors the `DECODE_PRECOMPUTATION_FOR_G` static, hence the capital `G`.
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);

        // If the embedded table is stale, regenerate the asset and fail. The crate must then
        // be rebuilt so that `include_bytes!` picks up the new file.
        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
            use std::{fs::File, io::Write, path::PathBuf};
            let mut f = File::create(PathBuf::from(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
            ))
            .unwrap();
            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
                .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    #[test]
    fn test_decode_correctness() {
        // general case (also the largest u32 value)
        let computation_secs = decode_and_check(4294967295, 1);
        println!("single thread discrete log computation secs: {computation_secs:?} sec");

        // small amounts at the bottom of the range
        for amount in [0, 1, 2, 3] {
            decode_and_check(amount, 1);
        }
    }

    #[test]
    fn test_decode_correctness_threaded() {
        // general case
        let computation_secs = decode_and_check(55, 4);
        println!("4 thread discrete log computation: {computation_secs:?} sec");

        // Edge cases are exercised on the threaded path as well: amounts 0..=3 and the max.
        for amount in [0, 1, 2, 3, (1_u64 << 32) - 1] {
            decode_and_check(amount, 4);
        }
    }

    #[test]
    fn test_num_threads_validation() {
        let mut instance = DiscreteLog::new(G, G);
        assert!(instance.num_threads(0).is_err());
        assert!(instance.num_threads(3).is_err());
        assert!(instance.num_threads(1 << 17).is_err());
        assert!(instance.num_threads(1).is_ok());
        assert!(instance.num_threads(1 << 16).is_ok());
    }
}
|