clean up ElGamal decryption

This commit is contained in:
Sam Kim 2021-09-30 11:54:14 -04:00 committed by Michael Vines
parent 409b55ad81
commit d20d03cd7f
3 changed files with 74 additions and 187 deletions

View File

@ -1,17 +1,12 @@
use core::ops::{Add, Neg, Sub};
use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G;
use curve25519_dalek::ristretto::RistrettoPoint;
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::Identity;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use serde::{Deserialize, Serialize};
use {
curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::Identity},
serde::{Deserialize, Serialize},
std::collections::HashMap,
};
const TWO15: u32 = 32768;
const TWO14: u32 = 16384; // 2^14
const TWO16: u32 = 65536; // 2^16
// const TWO16: u32 = 65536; // 2^16
const TWO18: u32 = 262144; // 2^18
/// Type that captures a discrete log challenge.
@ -25,93 +20,52 @@ pub struct DiscreteLogInstance {
pub target: RistrettoPoint,
}
/// Solves the discrete log instance using a 16/16 bit offline/online split
impl DiscreteLogInstance {
/// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number.
pub fn decode_u32(self) -> Option<u32> {
let hashmap = DiscreteLogInstance::decode_u32_precomputation(self.generator);
self.decode_u32_online(&hashmap)
}
/// Builds a HashMap of 2^18 elements
pub fn decode_u32_precomputation(generator: RistrettoPoint) -> HashMap<[u8; 32], u32> {
let mut hashmap = HashMap::new();
/// Builds a HashMap of 2^16 elements
pub fn decode_u32_precomputation(generator: RistrettoPoint) -> HashMap<HashableRistretto, u32> {
let mut hashmap = HashMap::new();
let two12_scalar = Scalar::from(TWO14);
let identity = RistrettoPoint::identity(); // 0 * G
let generator = two12_scalar * generator; // 2^12 * G
let two16_scalar = Scalar::from(TWO16);
let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G
let generator = HashableRistretto(two16_scalar * generator); // 2^16 * G
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
let ristretto_iter = RistrettoIterator::new(identity, generator);
let mut steps_for_breakpoint = 0;
ristretto_iter.zip(0..TWO18).for_each(|(elem, x_hi)| {
let key = elem.compress().to_bytes();
hashmap.insert(key, x_hi);
// iterator for 2^16*0G , 2^16*1G, 2^16*2G, ...
let ristretto_iter = RistrettoIterator::new(identity, generator);
ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| {
hashmap.insert(elem, x_hi);
});
// unclean way to print status update; will clean up later
if x_hi % TWO15 == 0 {
println!(" [{:?}/8] completed", steps_for_breakpoint);
steps_for_breakpoint += 1;
}
});
println!(" [8/8] completed");
hashmap
}
/// Solves the discrete log instance using the pre-computed HashMap by enumerating through 2^16
/// possible solutions
pub fn decode_u32_online(self, hashmap: &HashMap<HashableRistretto, u32>) -> Option<u32> {
// iterator for 0G, -1G, -2G, ...
let ristretto_iter = RistrettoIterator::new(
HashableRistretto(self.target),
HashableRistretto(-self.generator),
);
let mut decoded = None;
ristretto_iter.zip(0..TWO16).for_each(|(elem, x_lo)| {
if hashmap.contains_key(&elem) {
let x_hi = hashmap[&elem];
decoded = Some(x_lo + TWO16 * x_hi);
}
});
decoded
}
hashmap
}
/// Solves the discrete log instance using a 18/14 bit offline/online split
impl DiscreteLogInstance {
/// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number.
pub fn decode_u32_alt(self) -> Option<u32> {
let hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(self.generator);
self.decode_u32_online_alt(&hashmap)
}
/// Builds a HashMap of 2^18 elements
pub fn decode_u32_precomputation_alt(
generator: RistrettoPoint,
) -> HashMap<HashableRistretto, u32> {
let mut hashmap = HashMap::new();
let two12_scalar = Scalar::from(TWO14);
let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G
let generator = HashableRistretto(two12_scalar * generator); // 2^12 * G
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
let ristretto_iter = RistrettoIterator::new(identity, generator);
ristretto_iter.zip(0..TWO18).for_each(|(elem, x_hi)| {
hashmap.insert(elem, x_hi);
});
hashmap
pub fn decode_u32(self) -> Option<u32> {
let hashmap = decode_u32_precomputation(self.generator);
self.decode_u32_online(&hashmap)
}
/// Solves the discrete log instance using the pre-computed HashMap by enumerating through 2^14
/// possible solutions
pub fn decode_u32_online_alt(self, hashmap: &HashMap<HashableRistretto, u32>) -> Option<u32> {
pub fn decode_u32_online(self, hashmap: &HashMap<[u8; 32], u32>) -> Option<u32> {
// iterator for 0G, -1G, -2G, ...
let ristretto_iter = RistrettoIterator::new(
HashableRistretto(self.target),
HashableRistretto(-self.generator),
);
let ristretto_iter = RistrettoIterator::new(self.target, -self.generator);
let mut decoded = None;
ristretto_iter.zip(0..TWO14).for_each(|(elem, x_lo)| {
if hashmap.contains_key(&elem) {
let x_hi = hashmap[&elem];
let key = elem.compress().to_bytes();
if hashmap.contains_key(&key) {
let x_hi = hashmap[&key];
decoded = Some(x_lo + TWO14 * x_hi);
}
});
@ -119,92 +73,34 @@ impl DiscreteLogInstance {
}
}
/// Type wrapper for RistrettoPoint that implements the Hash trait
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq)]
pub struct HashableRistretto(pub RistrettoPoint);
impl HashableRistretto {
pub fn encode<T: Into<Scalar>>(amount: T) -> Self {
HashableRistretto(amount.into() * G)
}
}
impl Hash for HashableRistretto {
fn hash<H: Hasher>(&self, state: &mut H) {
bincode::serialize(self).unwrap().hash(state);
}
}
impl PartialEq for HashableRistretto {
fn eq(&self, other: &Self) -> bool {
self == other
}
}
/// HashableRistretto iterator.
///
/// Given an initial point X and a stepping point P, the iterator iterates through
/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
struct RistrettoIterator {
pub curr: HashableRistretto,
pub step: HashableRistretto,
pub curr: RistrettoPoint,
pub step: RistrettoPoint,
}
impl RistrettoIterator {
fn new(curr: HashableRistretto, step: HashableRistretto) -> Self {
fn new(curr: RistrettoPoint, step: RistrettoPoint) -> Self {
RistrettoIterator { curr, step }
}
}
impl Iterator for RistrettoIterator {
type Item = HashableRistretto;
type Item = RistrettoPoint;
fn next(&mut self) -> Option<Self::Item> {
let r = self.curr;
self.curr = self.curr + self.step;
self.curr += self.step;
Some(r)
}
}
impl<'a, 'b> Add<&'b HashableRistretto> for &'a HashableRistretto {
type Output = HashableRistretto;
fn add(self, other: &HashableRistretto) -> HashableRistretto {
HashableRistretto(self.0 + other.0)
}
}
define_add_variants!(
LHS = HashableRistretto,
RHS = HashableRistretto,
Output = HashableRistretto
);
impl<'a, 'b> Sub<&'b HashableRistretto> for &'a HashableRistretto {
type Output = HashableRistretto;
fn sub(self, other: &HashableRistretto) -> HashableRistretto {
HashableRistretto(self.0 - other.0)
}
}
define_sub_variants!(
LHS = HashableRistretto,
RHS = HashableRistretto,
Output = HashableRistretto
);
impl Neg for HashableRistretto {
type Output = HashableRistretto;
fn neg(self) -> HashableRistretto {
HashableRistretto(-self.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
use {super::*, curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G};
/// Discrete log test for 16/16 split
///
@ -212,7 +108,6 @@ mod tests {
/// - 8 sec for precomputation
/// - 3 sec for online computation
#[test]
#[ignore]
fn test_decode_correctness() {
let amount: u32 = 65545;
@ -223,7 +118,7 @@ mod tests {
// Very informal measurements for now
let start_precomputation = time::precise_time_s();
let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation(G);
let precomputed_hashmap = decode_u32_precomputation(G);
let end_precomputation = time::precise_time_s();
let start_online = time::precise_time_s();
@ -241,42 +136,4 @@ mod tests {
end_online - start_online
);
}
/// Discrete log test for 18/14 split
///
/// Very informal measurements on my machine:
/// - 33 sec for precomputation
/// - 0.8 sec for online computation
#[test]
#[ignore]
fn test_decode_alt_correctness() {
let amount: u32 = 65545;
let instance = DiscreteLogInstance {
generator: G,
target: Scalar::from(amount) * G,
};
// Very informal measurements for now
let start_precomputation = time::precise_time_s();
let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(G);
let end_precomputation = time::precise_time_s();
let start_online = time::precise_time_s();
let computed_amount = instance
.decode_u32_online_alt(&precomputed_hashmap)
.unwrap();
let end_online = time::precise_time_s();
assert_eq!(amount, computed_amount);
println!(
"18/14 Split precomputation: {:?} sec",
end_precomputation - start_precomputation
);
println!(
"18/14 Split online computation: {:?} sec",
end_online - start_online
);
}
}

View File

@ -2,7 +2,7 @@
use rand::{rngs::OsRng, CryptoRng, RngCore};
use {
crate::encryption::{
encode::DiscreteLogInstance,
dlog::DiscreteLogInstance,
pedersen::{Pedersen, PedersenBase, PedersenComm, PedersenDecHandle, PedersenOpen},
},
arrayref::{array_ref, array_refs},
@ -12,6 +12,7 @@ use {
scalar::Scalar,
},
serde::{Deserialize, Serialize},
std::collections::HashMap,
std::convert::TryInto,
subtle::{Choice, ConstantTimeEq},
zeroize::Zeroize,
@ -100,6 +101,17 @@ impl ElGamal {
let discrete_log_instance = ElGamal::decrypt(sk, ct);
discrete_log_instance.decode_u32()
}
/// On input a secret key, ciphertext, and hashmap, the function decrypts the
/// ciphertext for a u32 value.
pub fn decrypt_u32_online(
sk: &ElGamalSK,
ct: &ElGamalCiphertext,
hashmap: &HashMap<[u8; 32], u32>,
) -> Option<u32> {
let discrete_log_instance = ElGamal::decrypt(sk, ct);
discrete_log_instance.decode_u32_online(hashmap)
}
}
/// Public key for the ElGamal encryption scheme.
@ -164,6 +176,15 @@ impl ElGamalSK {
ElGamal::decrypt_u32(self, ct)
}
/// Utility method for code ergonomics.
pub fn decrypt_u32_online(
&self,
ct: &ElGamalCiphertext,
hashmap: &HashMap<[u8; 32], u32>,
) -> Option<u32> {
ElGamal::decrypt_u32_online(self, ct, hashmap)
}
pub fn to_bytes(&self) -> [u8; 32] {
self.0.to_bytes()
}
@ -249,6 +270,15 @@ impl ElGamalCiphertext {
pub fn decrypt_u32(&self, sk: &ElGamalSK) -> Option<u32> {
ElGamal::decrypt_u32(sk, self)
}
/// Utility method for code ergonomics.
pub fn decrypt_u32_online(
&self,
sk: &ElGamalSK,
hashmap: &HashMap<[u8; 32], u32>,
) -> Option<u32> {
ElGamal::decrypt_u32_online(sk, self, hashmap)
}
}
impl<'a, 'b> Add<&'b ElGamalCiphertext> for &'a ElGamalCiphertext {

View File

@ -1,3 +1,3 @@
pub mod elgamal;
pub mod dlog;
pub mod elgamal;
pub mod pedersen;