diff --git a/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode b/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode index eb8b2a15a..7be1353b3 100644 Binary files a/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode and b/zk-token-sdk/src/encryption/decode_u32_precomputation_for_G.bincode differ diff --git a/zk-token-sdk/src/encryption/discrete_log.rs b/zk-token-sdk/src/encryption/discrete_log.rs index a704ef806..2af579de2 100644 --- a/zk-token-sdk/src/encryption/discrete_log.rs +++ b/zk-token-sdk/src/encryption/discrete_log.rs @@ -6,10 +6,13 @@ use { std::collections::HashMap, }; -const TWO15: u32 = 32768; -const TWO14: u32 = 16384; // 2^14 - // const TWO16: u32 = 65536; // 2^16 -const TWO18: u32 = 262144; // 2^18 +#[allow(dead_code)] +const TWO15: u64 = 32768; +#[allow(dead_code)] +const TWO14: u64 = 16384; // 2^14 +const TWO16: u64 = 65536; // 2^16 +#[allow(dead_code)] +const TWO18: u64 = 262144; // 2^18 /// Type that captures a discrete log challenge. /// @@ -23,65 +26,58 @@ pub struct DiscreteLog { } #[derive(Serialize, Deserialize, Default)] -pub struct DecodeU32Precomputation(HashMap<[u8; 32], u32>); +pub struct DecodePrecomputation(HashMap<[u8; 32], u16>); -/// Builds a HashMap of 2^18 elements -fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodeU32Precomputation { +/// Builds a HashMap of 2^16 elements +#[allow(dead_code)] +fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation { let mut hashmap = HashMap::new(); - let two12_scalar = Scalar::from(TWO14); + let two16_scalar = Scalar::from(TWO16); let identity = RistrettoPoint::identity(); // 0 * G - let generator = two12_scalar * generator; // 2^12 * G + let generator = two16_scalar * generator; // 2^16 * G // iterator for 2^12*0G , 2^12*1G, 2^12*2G, ... let ristretto_iter = RistrettoIterator::new(identity, generator); - let mut steps_for_breakpoint = 0; - ristretto_iter.zip(0..TWO18).for_each(|(elem, x_hi)| { + ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| { let key = elem.compress().to_bytes(); - hashmap.insert(key, x_hi); - - // unclean way to print status update; will clean up later - if x_hi % TWO15 == 0 { - println!(" [{:?}/8] completed", steps_for_breakpoint); - steps_for_breakpoint += 1; - } + hashmap.insert(key, x_hi as u16); }); - println!(" [8/8] completed"); - DecodeU32Precomputation(hashmap) + DecodePrecomputation(hashmap) } lazy_static::lazy_static! { /// Pre-computed HashMap needed for decryption. The HashMap is independent of (works for) any key. - pub static ref DECODE_U32_PRECOMPUTATION_FOR_G: DecodeU32Precomputation = { - static DECODE_U32_PRECOMPUTATION_FOR_G_BINCODE: &[u8] = + pub static ref DECODE_PRECOMPUTATION_FOR_G: DecodePrecomputation = { + static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] = include_bytes!("decode_u32_precomputation_for_G.bincode"); - bincode::deserialize(DECODE_U32_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default() + bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default() }; } -/// Solves the discrete log instance using a 18/14 bit offline/online split +/// Solves the discrete log instance using a 16/16 bit offline/online split impl DiscreteLog { /// Solves the discrete log problem under the assumption that the solution /// is a 32-bit number. - pub(crate) fn decode_u32(self) -> Option { - self.decode_u32_online(&decode_u32_precomputation(self.generator)) + pub(crate) fn decode_u32(self) -> Option { + self.decode_online(&DECODE_PRECOMPUTATION_FOR_G, TWO16) } - /// Solves the discrete log instance using the pre-computed HashMap by enumerating through 2^14 - /// possible solutions - pub fn decode_u32_online(self, hashmap: &DecodeU32Precomputation) -> Option { + pub fn decode_online(self, hashmap: &DecodePrecomputation, solution_bound: u64) -> Option { // iterator for 0G, -1G, -2G, ... let ristretto_iter = RistrettoIterator::new(self.target, -self.generator); let mut decoded = None; - ristretto_iter.zip(0..TWO14).for_each(|(elem, x_lo)| { - let key = elem.compress().to_bytes(); - if hashmap.0.contains_key(&key) { - let x_hi = hashmap.0[&key]; - decoded = Some(x_lo + TWO14 * x_hi); - } - }); + ristretto_iter + .zip(0..solution_bound) + .for_each(|(elem, x_lo)| { + let key = elem.compress().to_bytes(); + if hashmap.0.contains_key(&key) { + let x_hi = hashmap.0[&key]; + decoded = Some(x_lo + solution_bound * x_hi as u64); + } + }); decoded } } @@ -122,7 +118,7 @@ mod tests { fn test_serialize_decode_u32_precomputation_for_G() { let decode_u32_precomputation_for_G = decode_u32_precomputation(G); - if decode_u32_precomputation_for_G.0 != DECODE_U32_PRECOMPUTATION_FOR_G.0 { + if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 { use std::{fs::File, io::Write, path::PathBuf}; let mut f = File::create(PathBuf::from( "src/encryption/decode_u32_precomputation_for_G.bincode", @@ -134,14 +130,9 @@ mod tests { } } - /// Discrete log test for 16/16 split - /// - /// Very informal measurements on my machine: - /// - 8 sec for precomputation - /// - 3 sec for online computation #[test] fn test_decode_correctness() { - let amount: u32 = 65545; + let amount: u64 = 65545; let instance = DiscreteLog { generator: G, @@ -154,7 +145,7 @@ mod tests { let precomputation_secs = start_precomputation.elapsed().as_secs_f64(); let start_online = Instant::now(); - let computed_amount = instance.decode_u32_online(&precomputed_hashmap).unwrap(); + let computed_amount = instance.decode_online(&precomputed_hashmap, TWO16).unwrap(); let online_secs = start_online.elapsed().as_secs_f64(); assert_eq!(amount, computed_amount); diff --git a/zk-token-sdk/src/encryption/elgamal.rs b/zk-token-sdk/src/encryption/elgamal.rs index 1cf02a6d2..67667c335 100644 --- a/zk-token-sdk/src/encryption/elgamal.rs +++ b/zk-token-sdk/src/encryption/elgamal.rs @@ -15,7 +15,7 @@ use { crate::encryption::{ - discrete_log::{DecodeU32Precomputation, DiscreteLog}, + discrete_log::DiscreteLog, pedersen::{Pedersen, PedersenCommitment, PedersenOpening, G, H}, }, arrayref::{array_ref, array_refs}, @@ -132,22 +132,10 @@ impl ElGamal { /// On input a secret key and a ciphertext, the function returns the decrypted message /// interpretted as type `u32`. #[cfg(not(target_arch = "bpf"))] - fn decrypt_u32(secret: &ElGamalSecretKey, ciphertext: &ElGamalCiphertext) -> Option { + fn decrypt_u32(secret: &ElGamalSecretKey, ciphertext: &ElGamalCiphertext) -> Option { let discrete_log_instance = Self::decrypt(secret, ciphertext); discrete_log_instance.decode_u32() } - - /// On input a secret key, a ciphertext, and a pre-computed hashmap, the function returns the - /// decrypted message interpretted as type `u32`. - #[cfg(not(target_arch = "bpf"))] - fn decrypt_u32_online( - secret: &ElGamalSecretKey, - ciphertext: &ElGamalCiphertext, - hashmap: &DecodeU32Precomputation, - ) -> Option { - let discrete_log_instance = Self::decrypt(secret, ciphertext); - discrete_log_instance.decode_u32_online(hashmap) - } } /// A (twisted) ElGamal encryption keypair. @@ -375,20 +363,10 @@ impl ElGamalSecretKey { } /// Decrypts a ciphertext using the ElGamal secret key interpretting the message as type `u32`. - pub fn decrypt_u32(&self, ciphertext: &ElGamalCiphertext) -> Option { + pub fn decrypt_u32(&self, ciphertext: &ElGamalCiphertext) -> Option { ElGamal::decrypt_u32(self, ciphertext) } - /// Decrypts a ciphertext using the ElGamal secret key and a pre-computed hashmap. It - /// interprets the decrypted message as type `u32`. - pub fn decrypt_u32_online( - &self, - ciphertext: &ElGamalCiphertext, - hashmap: &DecodeU32Precomputation, - ) -> Option { - ElGamal::decrypt_u32_online(self, ciphertext, hashmap) - } - pub fn as_bytes(&self) -> &[u8; 32] { self.0.as_bytes() } @@ -474,19 +452,9 @@ impl ElGamalCiphertext { } /// Decrypts the ciphertext using an ElGamal secret key interpretting the message as type `u32`. - pub fn decrypt_u32(&self, secret: &ElGamalSecretKey) -> Option { + pub fn decrypt_u32(&self, secret: &ElGamalSecretKey) -> Option { ElGamal::decrypt_u32(secret, self) } - - /// Decrypts the ciphertext using an ElGamal secret key and a pre-computed hashmap. It - /// interprets the decrypted message as type `u32`. - pub fn decrypt_u32_online( - &self, - secret: &ElGamalSecretKey, - hashmap: &DecodeU32Precomputation, - ) -> Option { - ElGamal::decrypt_u32_online(secret, self, hashmap) - } } impl<'a, 'b> Add<&'b ElGamalCiphertext> for &'a ElGamalCiphertext { @@ -606,7 +574,7 @@ define_mul_variants!(LHS = DecryptHandle, RHS = Scalar, Output = DecryptHandle); mod tests { use { super::*, - crate::encryption::{discrete_log::DECODE_U32_PRECOMPUTATION_FOR_G, pedersen::Pedersen}, + crate::encryption::pedersen::Pedersen, solana_sdk::{signature::Keypair, signer::null_signer::NullSigner}, }; @@ -622,12 +590,7 @@ mod tests { }; assert_eq!(expected_instance, ElGamal::decrypt(&secret, &ciphertext)); - assert_eq!( - 57_u32, - secret - .decrypt_u32_online(&ciphertext, &(*DECODE_U32_PRECOMPUTATION_FOR_G)) - .unwrap() - ); + assert_eq!(57_u64, secret.decrypt_u32(&ciphertext).unwrap()); } #[test] diff --git a/zk-token-sdk/src/errors.rs b/zk-token-sdk/src/errors.rs index 72b3ab77f..78e79ac1e 100644 --- a/zk-token-sdk/src/errors.rs +++ b/zk-token-sdk/src/errors.rs @@ -4,7 +4,6 @@ use { thiserror::Error, }; -// TODO: clean up errors for encryption #[derive(Error, Clone, Debug, Eq, PartialEq)] pub enum ProofError { #[error("proof generation failed")] diff --git a/zk-token-sdk/src/instruction/transfer.rs b/zk-token-sdk/src/instruction/transfer.rs index 272cbd39b..23cb4d42b 100644 --- a/zk-token-sdk/src/instruction/transfer.rs +++ b/zk-token-sdk/src/instruction/transfer.rs @@ -6,7 +6,6 @@ use { use { crate::{ encryption::{ - discrete_log::*, elgamal::{ DecryptHandle, ElGamalCiphertext, ElGamalKeypair, ElGamalPubkey, ElGamalSecretKey, }, @@ -211,11 +210,11 @@ impl TransferData { let ciphertext_lo = self.ciphertext_lo(role)?; let ciphertext_hi = self.ciphertext_hi(role)?; - let amount_lo = ciphertext_lo.decrypt_u32_online(sk, &DECODE_U32_PRECOMPUTATION_FOR_G); - let amount_hi = ciphertext_hi.decrypt_u32_online(sk, &DECODE_U32_PRECOMPUTATION_FOR_G); + let amount_lo = ciphertext_lo.decrypt_u32(sk); + let amount_hi = ciphertext_hi.decrypt_u32(sk); if let (Some(amount_lo), Some(amount_hi)) = (amount_lo, amount_hi) { - Ok((amount_lo as u64) + (TWO_32 * amount_hi as u64)) + Ok(amount_lo + TWO_32 * amount_hi) } else { Err(ProofError::Verification) } diff --git a/zk-token-sdk/src/instruction/transfer_with_fee.rs b/zk-token-sdk/src/instruction/transfer_with_fee.rs index f76e67bcb..feccc3f17 100644 --- a/zk-token-sdk/src/instruction/transfer_with_fee.rs +++ b/zk-token-sdk/src/instruction/transfer_with_fee.rs @@ -6,7 +6,6 @@ use { use { crate::{ encryption::{ - discrete_log::*, elgamal::{ DecryptHandle, ElGamalCiphertext, ElGamalKeypair, ElGamalPubkey, ElGamalSecretKey, }, @@ -33,7 +32,9 @@ use { }; #[cfg(not(target_arch = "bpf"))] -const MAX_FEE_BASIS_POINTS: u64 = 10000; +const MAX_FEE_BASIS_POINTS: u64 = 10_000; +#[cfg(not(target_arch = "bpf"))] +const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128; #[cfg(not(target_arch = "bpf"))] const TRANSFER_WITH_FEE_SOURCE_AMOUNT_BIT_LENGTH: usize = 64; @@ -121,7 +122,8 @@ impl TransferWithFeeData { // calculate and encrypt fee let (fee_amount, delta_fee) = - calculate_fee(transfer_amount, fee_parameters.fee_rate_basis_points); + calculate_fee(transfer_amount, fee_parameters.fee_rate_basis_points) + .ok_or(ProofError::Generation)?; let below_max = u64::ct_gt(&fee_parameters.maximum_fee, &fee_amount); let fee_to_encrypt = @@ -219,11 +221,11 @@ impl TransferWithFeeData { let ciphertext_lo = self.ciphertext_lo(role)?; let ciphertext_hi = self.ciphertext_hi(role)?; - let amount_lo = ciphertext_lo.decrypt_u32_online(sk, &DECODE_U32_PRECOMPUTATION_FOR_G); - let amount_hi = ciphertext_hi.decrypt_u32_online(sk, &DECODE_U32_PRECOMPUTATION_FOR_G); + let amount_lo = ciphertext_lo.decrypt_u32(sk); + let amount_hi = ciphertext_hi.decrypt_u32(sk); if let (Some(amount_lo), Some(amount_hi)) = (amount_lo, amount_hi) { - Ok((amount_lo as u64) + (TWO_32 * amount_hi as u64)) + Ok(amount_lo + TWO_32 * amount_hi) } else { Err(ProofError::Verification) } @@ -632,17 +634,16 @@ impl FeeParameters { } #[cfg(not(target_arch = "bpf"))] -fn calculate_fee(transfer_amount: u64, fee_rate_basis_points: u16) -> (u64, u64) { - let fee_scaled = (transfer_amount as u128) * (fee_rate_basis_points as u128); - - let fee = (fee_scaled / MAX_FEE_BASIS_POINTS as u128) as u64; - let rem = (fee_scaled % MAX_FEE_BASIS_POINTS as u128) as u64; - - if rem == 0 { - (fee, rem) - } else { - (fee + 1, rem) +fn calculate_fee(transfer_amount: u64, fee_rate_basis_points: u16) -> Option<(u64, u64)> { + let numerator = (transfer_amount as u128).checked_mul(fee_rate_basis_points as u128)?; + let mut fee = numerator.checked_div(ONE_IN_BASIS_POINTS)?; + let remainder = numerator.checked_rem(ONE_IN_BASIS_POINTS)?; + if remainder > 0 { + fee = fee.checked_add(1)?; } + + let fee = u64::try_from(fee).ok()?; + Some((fee as u64, remainder as u64)) } #[cfg(not(target_arch = "bpf"))] diff --git a/zk-token-sdk/src/range_proof/inner_product.rs b/zk-token-sdk/src/range_proof/inner_product.rs index 60383a86d..45df5d39b 100644 --- a/zk-token-sdk/src/range_proof/inner_product.rs +++ b/zk-token-sdk/src/range_proof/inner_product.rs @@ -400,7 +400,6 @@ mod tests { }; #[test] - #[ignore] #[allow(non_snake_case)] fn test_basic_correctness() { let n = 32; diff --git a/zk-token-sdk/src/sigma_proofs/equality_proof.rs b/zk-token-sdk/src/sigma_proofs/equality_proof.rs index 7a5d4d370..079c93a9e 100644 --- a/zk-token-sdk/src/sigma_proofs/equality_proof.rs +++ b/zk-token-sdk/src/sigma_proofs/equality_proof.rs @@ -48,7 +48,7 @@ pub struct CtxtCommEqualityProof { #[allow(non_snake_case)] #[cfg(not(target_arch = "bpf"))] impl CtxtCommEqualityProof { - /// Equality proof constructor. + /// Equality proof constructor. The proof is with respect to a ciphertext and commitment. /// /// The function does *not* hash the public key, ciphertext, or commitment into the transcript. /// For security, the caller (the main protocol) should hash these public components prior to @@ -119,7 +119,7 @@ impl CtxtCommEqualityProof { } } - /// Equality proof verifier. TODO: wrt commitment + /// Equality proof verifier. The proof is with respect to a single ciphertext and commitment. /// /// * `source_pubkey` - The ElGamal pubkey associated with the ciphertext to be proved /// * `source_ciphertext` - The main ElGamal ciphertext to be proved @@ -245,7 +245,7 @@ pub struct CtxtCtxtEqualityProof { #[allow(non_snake_case)] #[cfg(not(target_arch = "bpf"))] impl CtxtCtxtEqualityProof { - /// Equality proof constructor. + /// Equality proof constructor. The proof is with respect to two ciphertexts. /// /// The function does *not* hash the public key, ciphertext, or commitment into the transcript. /// For security, the caller (the main protocol) should hash these public components prior to @@ -322,7 +322,7 @@ impl CtxtCtxtEqualityProof { } } - /// Equality proof verifier. + /// Equality proof verifier. The proof is with respect to two ciphertexts. /// /// * `source_pubkey` - The ElGamal pubkey associated with the first ciphertext to be proved /// * `destination_pubkey` - The ElGamal pubkey associated with the second ciphertext to be proved diff --git a/zk-token-sdk/src/zk_token_elgamal/decryption.rs b/zk-token-sdk/src/zk_token_elgamal/decryption.rs new file mode 100644 index 000000000..07867c379 --- /dev/null +++ b/zk-token-sdk/src/zk_token_elgamal/decryption.rs @@ -0,0 +1,35 @@ +#[cfg(not(target_arch = "bpf"))] +use crate::{ + encryption::elgamal::{ElGamalCiphertext, ElGamalSecretKey}, + zk_token_elgamal::pod, +}; + +#[cfg(not(target_arch = "bpf"))] +impl pod::ElGamalCiphertext { + pub fn decrypt(self, secret_key: &ElGamalSecretKey) -> Option { + let deserialized_ciphertext: Option = self.try_into().ok(); + if let Some(ciphertext) = deserialized_ciphertext { + ciphertext.decrypt_u32(secret_key) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use {super::*, crate::encryption::elgamal::ElGamalKeypair}; + + #[test] + fn test_pod_decryption() { + let keypair = ElGamalKeypair::new_rand(); + + let pod_ciphertext = pod::ElGamalCiphertext([0u8; 64]); + assert_eq!(pod_ciphertext.decrypt(&keypair.secret).unwrap(), 0); + + let amount = 55_u64; + let ciphertext = keypair.public.encrypt(amount); + let pod_ciphertext: pod::ElGamalCiphertext = ciphertext.into(); + assert_eq!(pod_ciphertext.decrypt(&keypair.secret).unwrap(), 55); + } +} diff --git a/zk-token-sdk/src/zk_token_elgamal/mod.rs b/zk-token-sdk/src/zk_token_elgamal/mod.rs index 2663aef6e..73a4ce08e 100644 --- a/zk-token-sdk/src/zk_token_elgamal/mod.rs +++ b/zk-token-sdk/src/zk_token_elgamal/mod.rs @@ -1,3 +1,4 @@ pub mod convert; +pub mod decryption; pub mod ops; pub mod pod;