From d01d425e4bb5889a695eb246d70d08ed72808f95 Mon Sep 17 00:00:00 2001 From: Michael Vines Date: Wed, 29 Sep 2021 21:45:35 -0700 Subject: [PATCH] Rename crypto crate to sdk --- zk-token-sdk/Cargo.toml | 33 + zk-token-sdk/src/encryption/elgamal.rs | 468 ++++++++++++++ zk-token-sdk/src/encryption/encode.rs | 276 ++++++++ zk-token-sdk/src/encryption/mod.rs | 3 + zk-token-sdk/src/encryption/pedersen.rs | 484 ++++++++++++++ zk-token-sdk/src/errors.rs | 19 + zk-token-sdk/src/instruction/close_account.rs | 163 +++++ zk-token-sdk/src/instruction/mod.rs | 21 + zk-token-sdk/src/instruction/transfer.rs | 546 ++++++++++++++++ .../src/instruction/update_account_pk.rs | 271 ++++++++ zk-token-sdk/src/instruction/withdraw.rs | 207 ++++++ zk-token-sdk/src/lib.rs | 19 + zk-token-sdk/src/macros.rs | 101 +++ zk-token-sdk/src/pod.rs | 596 ++++++++++++++++++ zk-token-sdk/src/range_proof/generators.rs | 156 +++++ zk-token-sdk/src/range_proof/inner_product.rs | 505 +++++++++++++++ zk-token-sdk/src/range_proof/mod.rs | 456 ++++++++++++++ zk-token-sdk/src/range_proof/util.rs | 138 ++++ zk-token-sdk/src/transcript.rs | 117 ++++ .../src/zk_token_proof_instruction.rs | 91 +++ zk-token-sdk/src/zk_token_proof_program.rs | 2 + 21 files changed, 4672 insertions(+) create mode 100644 zk-token-sdk/Cargo.toml create mode 100644 zk-token-sdk/src/encryption/elgamal.rs create mode 100644 zk-token-sdk/src/encryption/encode.rs create mode 100644 zk-token-sdk/src/encryption/mod.rs create mode 100644 zk-token-sdk/src/encryption/pedersen.rs create mode 100644 zk-token-sdk/src/errors.rs create mode 100644 zk-token-sdk/src/instruction/close_account.rs create mode 100644 zk-token-sdk/src/instruction/mod.rs create mode 100644 zk-token-sdk/src/instruction/transfer.rs create mode 100644 zk-token-sdk/src/instruction/update_account_pk.rs create mode 100644 zk-token-sdk/src/instruction/withdraw.rs create mode 100644 zk-token-sdk/src/lib.rs create mode 100644 zk-token-sdk/src/macros.rs create mode 100644 zk-token-sdk/src/pod.rs create mode 100644 zk-token-sdk/src/range_proof/generators.rs create mode 100644 zk-token-sdk/src/range_proof/inner_product.rs create mode 100644 zk-token-sdk/src/range_proof/mod.rs create mode 100644 zk-token-sdk/src/range_proof/util.rs create mode 100644 zk-token-sdk/src/transcript.rs create mode 100644 zk-token-sdk/src/zk_token_proof_instruction.rs create mode 100644 zk-token-sdk/src/zk_token_proof_program.rs diff --git a/zk-token-sdk/Cargo.toml b/zk-token-sdk/Cargo.toml new file mode 100644 index 000000000..f80bdf94c --- /dev/null +++ b/zk-token-sdk/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "spl-zk-token-sdk" +description = "Solana Program Library ZkToken SDK" +authors = ["Solana Maintainers "] +repository = "https://github.com/solana-labs/solana-program-library" +version = "0.1.0" +license = "Apache-2.0" +edition = "2018" +publish = false + +[dependencies] +bytemuck = { version = "1.7.2", features = ["derive"] } +num-derive = "0.3" +num-traits = "0.2" +solana-program = "=1.7.13" + +[target.'cfg(not(target_arch = "bpf"))'.dependencies] +arrayref = "0.3.6" +bincode = "1" +byteorder = "1" +clear_on_drop = "0.2" +curve25519-dalek = { version = "3.2.0", features = ["serde"]} +getrandom = { version = "0.1", features = ["dummy"] } +merlin = "2" +rand = "0.7" +serde = { version = "1.0", features = ["derive"] } +sha3 = "0.9" +subtle = "2" +thiserror = "1" +zeroize = { version = "1.2.0", default-features = false, features = ["zeroize_derive"] } + +[dev-dependencies] +time = "0.1.40" diff --git a/zk-token-sdk/src/encryption/elgamal.rs b/zk-token-sdk/src/encryption/elgamal.rs new file mode 100644 index 000000000..e9e5af411 --- /dev/null +++ b/zk-token-sdk/src/encryption/elgamal.rs @@ -0,0 +1,468 @@ +#[cfg(not(target_arch = "bpf"))] +use rand::{rngs::OsRng, CryptoRng, RngCore}; +use { + crate::encryption::{ + encode::DiscreteLogInstance, + pedersen::{Pedersen, PedersenBase, PedersenComm, PedersenDecHandle, PedersenOpen}, + }, + arrayref::{array_ref, array_refs}, + core::ops::{Add, Div, Mul, Sub}, + curve25519_dalek::{ + ristretto::{CompressedRistretto, RistrettoPoint}, + scalar::Scalar, + }, + serde::{Deserialize, Serialize}, + std::convert::TryInto, + subtle::{Choice, ConstantTimeEq}, + zeroize::Zeroize, +}; + +/// Handle for the (twisted) ElGamal encryption scheme +pub struct ElGamal; +impl ElGamal { + /// Generates the public and secret keys for ElGamal encryption. + #[cfg(not(target_arch = "bpf"))] + pub fn keygen() -> (ElGamalPK, ElGamalSK) { + ElGamal::keygen_with(&mut OsRng) // using OsRng for now + } + + /// On input a randomness generator, the function generates the public and + /// secret keys for ElGamal encryption. + #[cfg(not(target_arch = "bpf"))] + #[allow(non_snake_case)] + pub fn keygen_with(rng: &mut T) -> (ElGamalPK, ElGamalSK) { + // sample a non-zero scalar + let mut s: Scalar; + loop { + s = Scalar::random(rng); + + if s != Scalar::zero() { + break; + } + } + + let H = PedersenBase::default().H; + let P = s.invert() * H; + + (ElGamalPK(P), ElGamalSK(s)) + } + + /// On input a public key and a message to be encrypted, the function + /// returns an ElGamal ciphertext of the message under the public key. + #[cfg(not(target_arch = "bpf"))] + pub fn encrypt>(pk: &ElGamalPK, amount: T) -> ElGamalCT { + let (message_comm, open) = Pedersen::commit(amount); + let decrypt_handle = pk.gen_decrypt_handle(&open); + + ElGamalCT { + message_comm, + decrypt_handle, + } + } + + /// On input a public key, message, and Pedersen opening, the function + /// returns an ElGamal ciphertext of the message under the public key using + /// the opening. + pub fn encrypt_with>( + pk: &ElGamalPK, + amount: T, + open: &PedersenOpen, + ) -> ElGamalCT { + let message_comm = Pedersen::commit_with(amount, open); + let decrypt_handle = pk.gen_decrypt_handle(open); + + ElGamalCT { + message_comm, + decrypt_handle, + } + } + + /// On input a secret key and a ciphertext, the function decrypts the ciphertext. + /// + /// The output of the function is of type `DiscreteLogInstance`. The exact message + /// can be recovered via the DiscreteLogInstance's decode method. + pub fn decrypt(sk: &ElGamalSK, ct: &ElGamalCT) -> DiscreteLogInstance { + let ElGamalSK(s) = sk; + let ElGamalCT { + message_comm, + decrypt_handle, + } = ct; + + DiscreteLogInstance { + generator: PedersenBase::default().G, + target: message_comm.get_point() - s * decrypt_handle.get_point(), + } + } + + /// On input a secret key and a ciphertext, the function decrypts the + /// ciphertext for a u32 value. + pub fn decrypt_u32(sk: &ElGamalSK, ct: &ElGamalCT) -> Option { + let discrete_log_instance = ElGamal::decrypt(sk, ct); + discrete_log_instance.decode_u32() + } +} + +/// Public key for the ElGamal encryption scheme. +#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)] +pub struct ElGamalPK(RistrettoPoint); +impl ElGamalPK { + pub fn get_point(&self) -> RistrettoPoint { + self.0 + } + + #[allow(clippy::wrong_self_convention)] + pub fn to_bytes(&self) -> [u8; 32] { + self.0.compress().to_bytes() + } + + pub fn from_bytes(bytes: &[u8]) -> Option { + Some(ElGamalPK( + CompressedRistretto::from_slice(bytes).decompress()?, + )) + } + + /// Utility method for code ergonomics. + #[cfg(not(target_arch = "bpf"))] + pub fn encrypt>(&self, msg: T) -> ElGamalCT { + ElGamal::encrypt(self, msg) + } + + /// Utility method for code ergonomics. + pub fn encrypt_with>(&self, msg: T, open: &PedersenOpen) -> ElGamalCT { + ElGamal::encrypt_with(self, msg, open) + } + + /// Generate a decryption token from an ElGamal public key and a Pedersen + /// opening. + pub fn gen_decrypt_handle(self, open: &PedersenOpen) -> PedersenDecHandle { + PedersenDecHandle::generate_handle(open, &self) + } +} + +impl From for ElGamalPK { + fn from(point: RistrettoPoint) -> ElGamalPK { + ElGamalPK(point) + } +} + +/// Secret key for the ElGamal encryption scheme. +#[derive(Serialize, Deserialize, Debug, Zeroize)] +#[zeroize(drop)] +pub struct ElGamalSK(Scalar); +impl ElGamalSK { + pub fn get_scalar(&self) -> Scalar { + self.0 + } + + /// Utility method for code ergonomics. + pub fn decrypt(&self, ct: &ElGamalCT) -> DiscreteLogInstance { + ElGamal::decrypt(self, ct) + } + + /// Utility method for code ergonomics. + pub fn decrypt_u32(&self, ct: &ElGamalCT) -> Option { + ElGamal::decrypt_u32(self, ct) + } + + pub fn to_bytes(&self) -> [u8; 32] { + self.0.to_bytes() + } + + pub fn from_bytes(bytes: &[u8]) -> Option { + match bytes.try_into() { + Ok(bytes) => Scalar::from_canonical_bytes(bytes).map(ElGamalSK), + _ => None, + } + } +} + +impl From for ElGamalSK { + fn from(scalar: Scalar) -> ElGamalSK { + ElGamalSK(scalar) + } +} + +impl Eq for ElGamalSK {} +impl PartialEq for ElGamalSK { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).unwrap_u8() == 1u8 + } +} +impl ConstantTimeEq for ElGamalSK { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +/// Ciphertext for the ElGamal encryption scheme. +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)] +pub struct ElGamalCT { + pub message_comm: PedersenComm, + pub decrypt_handle: PedersenDecHandle, +} +impl ElGamalCT { + pub fn add_to_msg>(&self, message: T) -> Self { + let diff_comm = Pedersen::commit_with(message, &PedersenOpen::default()); + ElGamalCT { + message_comm: self.message_comm + diff_comm, + decrypt_handle: self.decrypt_handle, + } + } + + pub fn sub_to_msg>(&self, message: T) -> Self { + let diff_comm = Pedersen::commit_with(message, &PedersenOpen::default()); + ElGamalCT { + message_comm: self.message_comm - diff_comm, + decrypt_handle: self.decrypt_handle, + } + } + + #[allow(clippy::wrong_self_convention)] + pub fn to_bytes(&self) -> [u8; 64] { + let mut bytes = [0u8; 64]; + + bytes[..32].copy_from_slice(self.message_comm.get_point().compress().as_bytes()); + bytes[32..].copy_from_slice(self.decrypt_handle.get_point().compress().as_bytes()); + bytes + } + + pub fn from_bytes(bytes: &[u8]) -> Option { + let bytes = array_ref![bytes, 0, 64]; + let (message_comm, decrypt_handle) = array_refs![bytes, 32, 32]; + + let message_comm = CompressedRistretto::from_slice(message_comm).decompress()?; + let decrypt_handle = CompressedRistretto::from_slice(decrypt_handle).decompress()?; + + Some(ElGamalCT { + message_comm: PedersenComm(message_comm), + decrypt_handle: PedersenDecHandle(decrypt_handle), + }) + } + + /// Utility method for code ergonomics. + pub fn decrypt(&self, sk: &ElGamalSK) -> DiscreteLogInstance { + ElGamal::decrypt(sk, self) + } + + /// Utility method for code ergonomics. + pub fn decrypt_u32(&self, sk: &ElGamalSK) -> Option { + ElGamal::decrypt_u32(sk, self) + } +} + +impl<'a, 'b> Add<&'b ElGamalCT> for &'a ElGamalCT { + type Output = ElGamalCT; + + fn add(self, other: &'b ElGamalCT) -> ElGamalCT { + ElGamalCT { + message_comm: self.message_comm + other.message_comm, + decrypt_handle: self.decrypt_handle + other.decrypt_handle, + } + } +} + +define_add_variants!(LHS = ElGamalCT, RHS = ElGamalCT, Output = ElGamalCT); + +impl<'a, 'b> Sub<&'b ElGamalCT> for &'a ElGamalCT { + type Output = ElGamalCT; + + fn sub(self, other: &'b ElGamalCT) -> ElGamalCT { + ElGamalCT { + message_comm: self.message_comm - other.message_comm, + decrypt_handle: self.decrypt_handle - other.decrypt_handle, + } + } +} + +define_sub_variants!(LHS = ElGamalCT, RHS = ElGamalCT, Output = ElGamalCT); + +impl<'a, 'b> Mul<&'b Scalar> for &'a ElGamalCT { + type Output = ElGamalCT; + + fn mul(self, other: &'b Scalar) -> ElGamalCT { + ElGamalCT { + message_comm: self.message_comm * other, + decrypt_handle: self.decrypt_handle * other, + } + } +} + +define_mul_variants!(LHS = ElGamalCT, RHS = Scalar, Output = ElGamalCT); + +impl<'a, 'b> Div<&'b Scalar> for &'a ElGamalCT { + type Output = ElGamalCT; + + fn div(self, other: &'b Scalar) -> ElGamalCT { + ElGamalCT { + message_comm: self.message_comm * other.invert(), + decrypt_handle: self.decrypt_handle * other.invert(), + } + } +} + +define_div_variants!(LHS = ElGamalCT, RHS = Scalar, Output = ElGamalCT); + +#[cfg(test)] +mod tests { + use super::*; + use crate::encryption::pedersen::Pedersen; + + #[test] + fn test_encrypt_decrypt_correctness() { + let (pk, sk) = ElGamal::keygen(); + let msg: u32 = 57; + let ct = ElGamal::encrypt(&pk, msg); + + let expected_instance = DiscreteLogInstance { + generator: PedersenBase::default().G, + target: Scalar::from(msg) * PedersenBase::default().G, + }; + + assert_eq!(expected_instance, ElGamal::decrypt(&sk, &ct)); + + // Commenting out for faster testing + // assert_eq!(msg, ElGamal::decrypt_u32(&sk, &ct).unwrap()); + } + + #[test] + fn test_decrypt_handle() { + let (pk_1, sk_1) = ElGamal::keygen(); + let (pk_2, sk_2) = ElGamal::keygen(); + + let msg: u32 = 77; + let (comm, open) = Pedersen::commit(msg); + + let decrypt_handle_1 = pk_1.gen_decrypt_handle(&open); + let decrypt_handle_2 = pk_2.gen_decrypt_handle(&open); + + let ct_1 = decrypt_handle_1.to_elgamal_ctxt(comm); + let ct_2 = decrypt_handle_2.to_elgamal_ctxt(comm); + + let expected_instance = DiscreteLogInstance { + generator: PedersenBase::default().G, + target: Scalar::from(msg) * PedersenBase::default().G, + }; + + assert_eq!(expected_instance, sk_1.decrypt(&ct_1)); + assert_eq!(expected_instance, sk_2.decrypt(&ct_2)); + + // Commenting out for faster testing + // assert_eq!(msg, sk_1.decrypt_u32(&ct_1).unwrap()); + // assert_eq!(msg, sk_2.decrypt_u32(&ct_2).unwrap()); + } + + #[test] + fn test_homomorphic_addition() { + let (pk, _) = ElGamal::keygen(); + let msg_0: u64 = 57; + let msg_1: u64 = 77; + + // Add two ElGamal ciphertexts + let open_0 = PedersenOpen::random(&mut OsRng); + let open_1 = PedersenOpen::random(&mut OsRng); + + let ct_0 = ElGamal::encrypt_with(&pk, msg_0, &open_0); + let ct_1 = ElGamal::encrypt_with(&pk, msg_1, &open_1); + + let ct_sum = ElGamal::encrypt_with(&pk, msg_0 + msg_1, &(open_0 + open_1)); + + assert_eq!(ct_sum, ct_0 + ct_1); + + // Add to ElGamal ciphertext + let open = PedersenOpen::random(&mut OsRng); + let ct = ElGamal::encrypt_with(&pk, msg_0, &open); + let ct_sum = ElGamal::encrypt_with(&pk, msg_0 + msg_1, &open); + + assert_eq!(ct_sum, ct.add_to_msg(msg_1)); + } + + #[test] + fn test_homomorphic_subtraction() { + let (pk, _) = ElGamal::keygen(); + let msg_0: u64 = 77; + let msg_1: u64 = 55; + + // Subtract two ElGamal ciphertexts + let open_0 = PedersenOpen::random(&mut OsRng); + let open_1 = PedersenOpen::random(&mut OsRng); + + let ct_0 = ElGamal::encrypt_with(&pk, msg_0, &open_0); + let ct_1 = ElGamal::encrypt_with(&pk, msg_1, &open_1); + + let ct_sub = ElGamal::encrypt_with(&pk, msg_0 - msg_1, &(open_0 - open_1)); + + assert_eq!(ct_sub, ct_0 - ct_1); + + // Subtract to ElGamal ciphertext + let open = PedersenOpen::random(&mut OsRng); + let ct = ElGamal::encrypt_with(&pk, msg_0, &open); + let ct_sub = ElGamal::encrypt_with(&pk, msg_0 - msg_1, &open); + + assert_eq!(ct_sub, ct.sub_to_msg(msg_1)); + } + + #[test] + fn test_homomorphic_multiplication() { + let (pk, _) = ElGamal::keygen(); + let msg_0: u64 = 57; + let msg_1: u64 = 77; + + let open = PedersenOpen::random(&mut OsRng); + + let ct = ElGamal::encrypt_with(&pk, msg_0, &open); + let scalar = Scalar::from(msg_1); + + let ct_prod = ElGamal::encrypt_with(&pk, msg_0 * msg_1, &(open * scalar)); + + assert_eq!(ct_prod, ct * scalar); + } + + #[test] + fn test_homomorphic_division() { + let (pk, _) = ElGamal::keygen(); + let msg_0: u64 = 55; + let msg_1: u64 = 5; + + let open = PedersenOpen::random(&mut OsRng); + + let ct = ElGamal::encrypt_with(&pk, msg_0, &open); + let scalar = Scalar::from(msg_1); + + let ct_div = ElGamal::encrypt_with(&pk, msg_0 / msg_1, &(open / scalar)); + + assert_eq!(ct_div, ct / scalar); + } + + #[test] + fn test_serde_ciphertext() { + let (pk, _) = ElGamal::keygen(); + let msg: u64 = 77; + let ct = pk.encrypt(msg); + + let encoded = bincode::serialize(&ct).unwrap(); + let decoded: ElGamalCT = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(ct, decoded); + } + + #[test] + fn test_serde_pubkey() { + let (pk, _) = ElGamal::keygen(); + + let encoded = bincode::serialize(&pk).unwrap(); + let decoded: ElGamalPK = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(pk, decoded); + } + + #[test] + fn test_serde_secretkey() { + let (_, sk) = ElGamal::keygen(); + + let encoded = bincode::serialize(&sk).unwrap(); + let decoded: ElGamalSK = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(sk, decoded); + } +} diff --git a/zk-token-sdk/src/encryption/encode.rs b/zk-token-sdk/src/encryption/encode.rs new file mode 100644 index 000000000..f618afc7f --- /dev/null +++ b/zk-token-sdk/src/encryption/encode.rs @@ -0,0 +1,276 @@ +use core::ops::{Add, Neg, Sub}; + +use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G; +use curve25519_dalek::ristretto::RistrettoPoint; +use curve25519_dalek::scalar::Scalar; +use curve25519_dalek::traits::Identity; + +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; + +use serde::{Deserialize, Serialize}; + +const TWO14: u32 = 16384; // 2^14 +const TWO16: u32 = 65536; // 2^16 +const TWO18: u32 = 262144; // 2^18 + +/// Type that captures a discrete log challenge. +/// +/// The goal of discrete log is to find x such that x * generator = target. +#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)] +pub struct DiscreteLogInstance { + /// Generator point for discrete log + pub generator: RistrettoPoint, + /// Target point for discrete log + pub target: RistrettoPoint, +} + +/// Solves the discrete log instance using a 18/14 bit offline/online split +impl DiscreteLogInstance { + /// Solves the discrete log problem under the assumption that the solution + /// is a 32-bit number. + pub fn decode_u32(self) -> Option { + let hashmap = DiscreteLogInstance::decode_u32_precomputation(self.generator); + self.decode_u32_online(&hashmap) + } + + pub fn decode_u32_precomputation(generator: RistrettoPoint) -> HashMap { + let mut hashmap = HashMap::new(); + + let two16_scalar = Scalar::from(TWO16); + let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G + let generator = HashableRistretto(two16_scalar * generator); // 2^16 * G + + // iterator for 2^16*0G , 2^16*1G, 2^16*2G, ... + let ristretto_iter = RistrettoIterator::new(identity, generator); + ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| { + hashmap.insert(elem, x_hi); + }); + + hashmap + } + + pub fn decode_u32_online(self, hashmap: &HashMap) -> Option { + // iterator for 0G, -1G, -2G, ... + let ristretto_iter = RistrettoIterator::new( + HashableRistretto(self.target), + HashableRistretto(-self.generator), + ); + + let mut decoded = None; + ristretto_iter.zip(0..TWO16).for_each(|(elem, x_lo)| { + if hashmap.contains_key(&elem) { + let x_hi = hashmap[&elem]; + decoded = Some(x_lo + TWO16 * x_hi); + } + }); + decoded + } +} + +/// Solves the discrete log instance using a 18/14 bit offline/online split +impl DiscreteLogInstance { + /// Solves the discrete log problem under the assumption that the solution + /// is a 32-bit number. + pub fn decode_u32_alt(self) -> Option { + let hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(self.generator); + self.decode_u32_online_alt(&hashmap) + } + + pub fn decode_u32_precomputation_alt( + generator: RistrettoPoint, + ) -> HashMap { + let mut hashmap = HashMap::new(); + + let two12_scalar = Scalar::from(TWO14); + let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G + let generator = HashableRistretto(two12_scalar * generator); // 2^12 * G + + // iterator for 2^12*0G , 2^12*1G, 2^12*2G, ... + let ristretto_iter = RistrettoIterator::new(identity, generator); + ristretto_iter.zip(0..TWO18).for_each(|(elem, x_hi)| { + hashmap.insert(elem, x_hi); + }); + + hashmap + } + + pub fn decode_u32_online_alt(self, hashmap: &HashMap) -> Option { + // iterator for 0G, -1G, -2G, ... + let ristretto_iter = RistrettoIterator::new( + HashableRistretto(self.target), + HashableRistretto(-self.generator), + ); + + let mut decoded = None; + ristretto_iter.zip(0..TWO14).for_each(|(elem, x_lo)| { + if hashmap.contains_key(&elem) { + let x_hi = hashmap[&elem]; + decoded = Some(x_lo + TWO14 * x_hi); + } + }); + decoded + } +} + +/// Type wrapper for RistrettoPoint that implements the Hash trait +#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq)] +pub struct HashableRistretto(pub RistrettoPoint); + +impl HashableRistretto { + pub fn encode>(amount: T) -> Self { + HashableRistretto(amount.into() * G) + } +} + +impl Hash for HashableRistretto { + fn hash(&self, state: &mut H) { + bincode::serialize(self).unwrap().hash(state); + } +} + +impl PartialEq for HashableRistretto { + fn eq(&self, other: &Self) -> bool { + self == other + } +} + +/// HashableRistretto iterator. +/// +/// Given an initial point X and a stepping point P, the iterator iterates through +/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ... +struct RistrettoIterator { + pub curr: HashableRistretto, + pub step: HashableRistretto, +} + +impl RistrettoIterator { + fn new(curr: HashableRistretto, step: HashableRistretto) -> Self { + RistrettoIterator { curr, step } + } +} + +impl Iterator for RistrettoIterator { + type Item = HashableRistretto; + + fn next(&mut self) -> Option { + let r = self.curr; + self.curr = self.curr + self.step; + Some(r) + } +} + +impl<'a, 'b> Add<&'b HashableRistretto> for &'a HashableRistretto { + type Output = HashableRistretto; + + fn add(self, other: &HashableRistretto) -> HashableRistretto { + HashableRistretto(self.0 + other.0) + } +} + +define_add_variants!( + LHS = HashableRistretto, + RHS = HashableRistretto, + Output = HashableRistretto +); + +impl<'a, 'b> Sub<&'b HashableRistretto> for &'a HashableRistretto { + type Output = HashableRistretto; + + fn sub(self, other: &HashableRistretto) -> HashableRistretto { + HashableRistretto(self.0 - other.0) + } +} + +define_sub_variants!( + LHS = HashableRistretto, + RHS = HashableRistretto, + Output = HashableRistretto +); + +impl Neg for HashableRistretto { + type Output = HashableRistretto; + + fn neg(self) -> HashableRistretto { + HashableRistretto(-self.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Discrete log test for 16/16 split + /// + /// Very informal measurements on my machine: + /// - 8 sec for precomputation + /// - 3 sec for online computation + #[test] + #[ignore] + fn test_decode_correctness() { + let amount: u32 = 65545; + + let instance = DiscreteLogInstance { + generator: G, + target: Scalar::from(amount) * G, + }; + + // Very informal measurements for now + let start_precomputation = time::precise_time_s(); + let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation(G); + let end_precomputation = time::precise_time_s(); + + let start_online = time::precise_time_s(); + let computed_amount = instance.decode_u32_online(&precomputed_hashmap).unwrap(); + let end_online = time::precise_time_s(); + + assert_eq!(amount, computed_amount); + + println!( + "16/16 Split precomputation: {:?} sec", + end_precomputation - start_precomputation + ); + println!( + "16/16 Split online computation: {:?} sec", + end_online - start_online + ); + } + + /// Discrete log test for 18/14 split + /// + /// Very informal measurements on my machine: + /// - 33 sec for precomputation + /// - 0.8 sec for online computation + #[test] + #[ignore] + fn test_decode_alt_correctness() { + let amount: u32 = 65545; + + let instance = DiscreteLogInstance { + generator: G, + target: Scalar::from(amount) * G, + }; + + // Very informal measurements for now + let start_precomputation = time::precise_time_s(); + let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(G); + let end_precomputation = time::precise_time_s(); + + let start_online = time::precise_time_s(); + let computed_amount = instance + .decode_u32_online_alt(&precomputed_hashmap) + .unwrap(); + let end_online = time::precise_time_s(); + + assert_eq!(amount, computed_amount); + + println!( + "18/14 Split precomputation: {:?} sec", + end_precomputation - start_precomputation + ); + println!( + "18/14 Split online computation: {:?} sec", + end_online - start_online + ); + } +} diff --git a/zk-token-sdk/src/encryption/mod.rs b/zk-token-sdk/src/encryption/mod.rs new file mode 100644 index 000000000..12ed8b341 --- /dev/null +++ b/zk-token-sdk/src/encryption/mod.rs @@ -0,0 +1,3 @@ +pub mod elgamal; +pub mod encode; +pub mod pedersen; diff --git a/zk-token-sdk/src/encryption/pedersen.rs b/zk-token-sdk/src/encryption/pedersen.rs new file mode 100644 index 000000000..830bf8464 --- /dev/null +++ b/zk-token-sdk/src/encryption/pedersen.rs @@ -0,0 +1,484 @@ +#[cfg(not(target_arch = "bpf"))] +use rand::{rngs::OsRng, CryptoRng, RngCore}; +use { + crate::{ + encryption::elgamal::{ElGamalCT, ElGamalPK}, + errors::ProofError, + }, + core::ops::{Add, Div, Mul, Sub}, + curve25519_dalek::{ + constants::{RISTRETTO_BASEPOINT_COMPRESSED, RISTRETTO_BASEPOINT_POINT}, + ristretto::{CompressedRistretto, RistrettoPoint}, + scalar::Scalar, + traits::MultiscalarMul, + }, + serde::{Deserialize, Serialize}, + sha3::Sha3_512, + std::convert::TryInto, + subtle::{Choice, ConstantTimeEq}, + zeroize::Zeroize, +}; + +/// Curve basepoints for which Pedersen commitment is defined over. +/// +/// These points should be fixed for the entire system. +/// TODO: Consider setting these points as constants? +#[allow(non_snake_case)] +#[derive(Serialize, Deserialize, Clone, Copy, Debug, Eq, PartialEq)] +pub struct PedersenBase { + pub G: RistrettoPoint, + pub H: RistrettoPoint, +} +/// Default PedersenBase. This is set arbitrarily for now, but it should be fixed +/// for the entire system. +/// +/// `G` is a constant point in the curve25519_dalek library +/// `H` is the Sha3 hash of `G` interpretted as a RistrettoPoint +impl Default for PedersenBase { + #[allow(non_snake_case)] + fn default() -> PedersenBase { + let G = RISTRETTO_BASEPOINT_POINT; + let H = + RistrettoPoint::hash_from_bytes::(RISTRETTO_BASEPOINT_COMPRESSED.as_bytes()); + + PedersenBase { G, H } + } +} + +/// Handle for the Pedersen commitment scheme +pub struct Pedersen; +impl Pedersen { + /// Given a number as input, the function returns a Pedersen commitment of + /// the number and its corresponding opening. + /// + /// TODO: Interface that takes a random generator as input + #[cfg(not(target_arch = "bpf"))] + pub fn commit>(amount: T) -> (PedersenComm, PedersenOpen) { + let open = PedersenOpen(Scalar::random(&mut OsRng)); + let comm = Pedersen::commit_with(amount, &open); + + (comm, open) + } + + /// Given a number and an opening as inputs, the function returns their + /// Pedersen commitment. + #[allow(non_snake_case)] + pub fn commit_with>(amount: T, open: &PedersenOpen) -> PedersenComm { + let G = PedersenBase::default().G; + let H = PedersenBase::default().H; + + let x: Scalar = amount.into(); + let r = open.get_scalar(); + + PedersenComm(RistrettoPoint::multiscalar_mul(&[x, r], &[G, H])) + } + + /// Given a number, opening, and Pedersen commitment, the function verifies + /// the validity of the commitment with respect to the number and opening. + /// + /// This function is included for completeness and is not used for the + /// c-token program. + #[allow(non_snake_case)] + pub fn verify>( + comm: PedersenComm, + open: PedersenOpen, + amount: T, + ) -> Result<(), ProofError> { + let G = PedersenBase::default().G; + let H = PedersenBase::default().H; + + let x: Scalar = amount.into(); + + let r = open.get_scalar(); + let C = comm.get_point(); + + if C == RistrettoPoint::multiscalar_mul(&[x, r], &[G, H]) { + Ok(()) + } else { + Err(ProofError::VerificationError) + } + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, Zeroize)] +#[zeroize(drop)] +pub struct PedersenOpen(pub(crate) Scalar); +impl PedersenOpen { + pub fn get_scalar(&self) -> Scalar { + self.0 + } + + #[cfg(not(target_arch = "bpf"))] + pub fn random(rng: &mut T) -> Self { + PedersenOpen(Scalar::random(rng)) + } + + #[allow(clippy::wrong_self_convention)] + pub fn to_bytes(&self) -> [u8; 32] { + self.0.to_bytes() + } + + pub fn from_bytes(bytes: &[u8]) -> Option { + match bytes.try_into() { + Ok(bytes) => Scalar::from_canonical_bytes(bytes).map(PedersenOpen), + _ => None, + } + } +} +impl Eq for PedersenOpen {} +impl PartialEq for PedersenOpen { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).unwrap_u8() == 1u8 + } +} +impl ConstantTimeEq for PedersenOpen { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl Default for PedersenOpen { + fn default() -> Self { + PedersenOpen(Scalar::default()) + } +} + +impl<'a, 'b> Add<&'b PedersenOpen> for &'a PedersenOpen { + type Output = PedersenOpen; + + fn add(self, other: &'b PedersenOpen) -> PedersenOpen { + PedersenOpen(self.get_scalar() + other.get_scalar()) + } +} + +define_add_variants!( + LHS = PedersenOpen, + RHS = PedersenOpen, + Output = PedersenOpen +); + +impl<'a, 'b> Sub<&'b PedersenOpen> for &'a PedersenOpen { + type Output = PedersenOpen; + + fn sub(self, other: &'b PedersenOpen) -> PedersenOpen { + PedersenOpen(self.get_scalar() - other.get_scalar()) + } +} + +define_sub_variants!( + LHS = PedersenOpen, + RHS = PedersenOpen, + Output = PedersenOpen +); + +impl<'a, 'b> Mul<&'b Scalar> for &'a PedersenOpen { + type Output = PedersenOpen; + + fn mul(self, other: &'b Scalar) -> PedersenOpen { + PedersenOpen(self.get_scalar() * other) + } +} + +define_mul_variants!(LHS = PedersenOpen, RHS = Scalar, Output = PedersenOpen); + +impl<'a, 'b> Div<&'b Scalar> for &'a PedersenOpen { + type Output = PedersenOpen; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, other: &'b Scalar) -> PedersenOpen { + PedersenOpen(self.get_scalar() * other.invert()) + } +} + +define_div_variants!(LHS = PedersenOpen, RHS = Scalar, Output = PedersenOpen); + +#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)] +pub struct PedersenComm(pub(crate) RistrettoPoint); +impl PedersenComm { + pub fn get_point(&self) -> RistrettoPoint { + self.0 + } + + #[allow(clippy::wrong_self_convention)] + pub fn to_bytes(&self) -> [u8; 32] { + self.0.compress().to_bytes() + } + + pub fn from_bytes(bytes: &[u8]) -> Option { + Some(PedersenComm( + CompressedRistretto::from_slice(bytes).decompress()?, + )) + } +} + +impl<'a, 'b> Add<&'b PedersenComm> for &'a PedersenComm { + type Output = PedersenComm; + + fn add(self, other: &'b PedersenComm) -> PedersenComm { + PedersenComm(self.get_point() + other.get_point()) + } +} + +define_add_variants!( + LHS = PedersenComm, + RHS = PedersenComm, + Output = PedersenComm +); + +impl<'a, 'b> Sub<&'b PedersenComm> for &'a PedersenComm { + type Output = PedersenComm; + + fn sub(self, other: &'b PedersenComm) -> PedersenComm { + PedersenComm(self.get_point() - other.get_point()) + } +} + +define_sub_variants!( + LHS = PedersenComm, + RHS = PedersenComm, + Output = PedersenComm +); + +impl<'a, 'b> Mul<&'b Scalar> for &'a PedersenComm { + type Output = PedersenComm; + + fn mul(self, other: &'b Scalar) -> PedersenComm { + PedersenComm(self.get_point() * other) + } +} + +define_mul_variants!(LHS = PedersenComm, RHS = Scalar, Output = PedersenComm); + +impl<'a, 'b> Div<&'b Scalar> for &'a PedersenComm { + type Output = PedersenComm; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, other: &'b Scalar) -> PedersenComm { + PedersenComm(self.get_point() * other.invert()) + } +} + +define_div_variants!(LHS = PedersenComm, RHS = Scalar, Output = PedersenComm); + +/// Decryption handle for Pedersen commitment. +/// +/// A decryption handle can be combined with Pedersen commitments to form an +/// ElGamal ciphertext. +#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)] +pub struct PedersenDecHandle(pub(crate) RistrettoPoint); +impl PedersenDecHandle { + pub fn get_point(&self) -> RistrettoPoint { + self.0 + } + + pub fn generate_handle(open: &PedersenOpen, pk: &ElGamalPK) -> PedersenDecHandle { + PedersenDecHandle(open.get_scalar() * pk.get_point()) + } + + /// Maps a decryption token and Pedersen commitment to ElGamal ciphertext + pub fn to_elgamal_ctxt(self, comm: PedersenComm) -> ElGamalCT { + ElGamalCT { + message_comm: comm, + decrypt_handle: self, + } + } + + #[allow(clippy::wrong_self_convention)] + pub fn to_bytes(&self) -> [u8; 32] { + self.0.compress().to_bytes() + } + + pub fn from_bytes(bytes: &[u8]) -> Option { + Some(PedersenDecHandle( + CompressedRistretto::from_slice(bytes).decompress()?, + )) + } +} + +impl<'a, 'b> Add<&'b PedersenDecHandle> for &'a PedersenDecHandle { + type Output = PedersenDecHandle; + + fn add(self, other: &'b PedersenDecHandle) -> PedersenDecHandle { + PedersenDecHandle(self.get_point() + other.get_point()) + } +} + +define_add_variants!( + LHS = PedersenDecHandle, + RHS = PedersenDecHandle, + Output = PedersenDecHandle +); + +impl<'a, 'b> Sub<&'b PedersenDecHandle> for &'a PedersenDecHandle { + type Output = PedersenDecHandle; + + fn sub(self, other: &'b PedersenDecHandle) -> PedersenDecHandle { + PedersenDecHandle(self.get_point() - other.get_point()) + } +} + +define_sub_variants!( + LHS = PedersenDecHandle, + RHS = PedersenDecHandle, + Output = PedersenDecHandle +); + +impl<'a, 'b> Mul<&'b Scalar> for &'a PedersenDecHandle { + type Output = PedersenDecHandle; + + fn mul(self, other: &'b Scalar) -> PedersenDecHandle { + PedersenDecHandle(self.get_point() * other) + } +} + +define_mul_variants!( + LHS = PedersenDecHandle, + RHS = Scalar, + Output = PedersenDecHandle +); + +impl<'a, 'b> Div<&'b Scalar> for &'a PedersenDecHandle { + type Output = PedersenDecHandle; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, other: &'b Scalar) -> PedersenDecHandle { + PedersenDecHandle(self.get_point() * other.invert()) + } +} + +define_div_variants!( + LHS = PedersenDecHandle, + RHS = Scalar, + Output = PedersenDecHandle +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_commit_verification_correctness() { + let amt: u64 = 57; + let (comm, open) = Pedersen::commit(amt); + + assert!(Pedersen::verify(comm, open, amt).is_ok()); + } + + #[test] + fn test_homomorphic_addition() { + let amt_0: u64 = 77; + let amt_1: u64 = 57; + + let rng = &mut OsRng; + let open_0 = PedersenOpen(Scalar::random(rng)); + let open_1 = PedersenOpen(Scalar::random(rng)); + + let comm_0 = Pedersen::commit_with(amt_0, &open_0); + let comm_1 = Pedersen::commit_with(amt_1, &open_1); + let comm_addition = Pedersen::commit_with(amt_0 + amt_1, &(open_0 + open_1)); + + assert_eq!(comm_addition, comm_0 + comm_1); + } + + #[test] + fn test_homomorphic_subtraction() { + let amt_0: u64 = 77; + let amt_1: u64 = 57; + + let rng = &mut OsRng; + let open_0 = PedersenOpen(Scalar::random(rng)); + let open_1 = PedersenOpen(Scalar::random(rng)); + + let comm_0 = Pedersen::commit_with(amt_0, &open_0); + let comm_1 = Pedersen::commit_with(amt_1, &open_1); + let comm_addition = Pedersen::commit_with(amt_0 - amt_1, &(open_0 - open_1)); + + assert_eq!(comm_addition, comm_0 - comm_1); + } + + #[test] + fn test_homomorphic_multiplication() { + let amt_0: u64 = 77; + let amt_1: u64 = 57; + + let (comm, open) = Pedersen::commit(amt_0); + let scalar = Scalar::from(amt_1); + let comm_addition = Pedersen::commit_with(amt_0 * amt_1, &(open * scalar)); + + assert_eq!(comm_addition, comm * scalar); + } + + #[test] + fn test_homomorphic_division() { + let amt_0: u64 = 77; + let amt_1: u64 = 7; + + let (comm, open) = Pedersen::commit(amt_0); + let scalar = Scalar::from(amt_1); + let comm_addition = Pedersen::commit_with(amt_0 / amt_1, &(open / scalar)); + + assert_eq!(comm_addition, comm / scalar); + } + + #[test] + fn test_commitment_bytes() { + let amt: u64 = 77; + let (comm, _) = Pedersen::commit(amt); + + let encoded = comm.to_bytes(); + let decoded = PedersenComm::from_bytes(&encoded).unwrap(); + + assert_eq!(comm, decoded); + } + + #[test] + fn test_opening_bytes() { + let open = PedersenOpen(Scalar::random(&mut OsRng)); + + let encoded = open.to_bytes(); + let decoded = PedersenOpen::from_bytes(&encoded).unwrap(); + + assert_eq!(open, decoded); + } + + #[test] + fn test_decrypt_handle_bytes() { + let handle = PedersenDecHandle(RistrettoPoint::default()); + + let encoded = handle.to_bytes(); + let decoded = PedersenDecHandle::from_bytes(&encoded).unwrap(); + + assert_eq!(handle, decoded); + } + + #[test] + fn test_serde_commitment() { + let amt: u64 = 77; + let (comm, _) = Pedersen::commit(amt); + + let encoded = bincode::serialize(&comm).unwrap(); + let decoded: PedersenComm = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(comm, decoded); + } + + #[test] + fn test_serde_opening() { + let open = PedersenOpen(Scalar::random(&mut OsRng)); + + let encoded = bincode::serialize(&open).unwrap(); + let decoded: PedersenOpen = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(open, decoded); + } + + #[test] + fn test_serde_decrypt_handle() { + let handle = PedersenDecHandle(RistrettoPoint::default()); + + let encoded = bincode::serialize(&handle).unwrap(); + let decoded: PedersenDecHandle = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(handle, decoded); + } +} diff --git a/zk-token-sdk/src/errors.rs b/zk-token-sdk/src/errors.rs new file mode 100644 index 000000000..ad7a251e8 --- /dev/null +++ b/zk-token-sdk/src/errors.rs @@ -0,0 +1,19 @@ +//! Errors related to proving and verifying proofs. + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ProofError { + /// This error occurs when a proof failed to verify. + VerificationError, + /// This error occurs when the proof encoding is malformed. + FormatError, + /// This error occurs during proving if the number of blinding + /// factors does not match the number of values. + WrongNumBlindingFactors, + /// This error occurs when attempting to create a proof with + /// bitsize other than \\(8\\), \\(16\\), \\(32\\), or \\(64\\). + InvalidBitsize, + /// This error occurs when there are insufficient generators for the proof. + InvalidGeneratorsLength, + /// This error occurs when TODO + InconsistentCTData, +} diff --git a/zk-token-sdk/src/instruction/close_account.rs b/zk-token-sdk/src/instruction/close_account.rs new file mode 100644 index 000000000..d4cdcd619 --- /dev/null +++ b/zk-token-sdk/src/instruction/close_account.rs @@ -0,0 +1,163 @@ +use { + crate::pod::*, + bytemuck::{Pod, Zeroable}, +}; +#[cfg(not(target_arch = "bpf"))] +use { + crate::{ + encryption::elgamal::{ElGamalCT, ElGamalSK}, + errors::ProofError, + instruction::Verifiable, + transcript::TranscriptProtocol, + }, + curve25519_dalek::{ + ristretto::RistrettoPoint, + scalar::Scalar, + traits::{IsIdentity, MultiscalarMul}, + }, + merlin::Transcript, + rand::rngs::OsRng, + std::convert::TryInto, +}; + +/// This struct includes the cryptographic proof *and* the account data information needed to verify +/// the proof +/// +/// - The pre-instruction should call CloseAccountData::verify_proof(&self) +/// - The actual program should check that `balance` is consistent with what is +/// currently stored in the confidential token account +/// +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct CloseAccountData { + /// The source account available balance in encrypted form + pub balance: PodElGamalCT, // 64 bytes + + /// Proof that the source account available balance is zero + pub proof: CloseAccountProof, // 64 bytes +} + +#[cfg(not(target_arch = "bpf"))] +impl CloseAccountData { + pub fn new(source_sk: &ElGamalSK, balance: ElGamalCT) -> Self { + let proof = CloseAccountProof::new(source_sk, &balance); + + CloseAccountData { + balance: balance.into(), + proof, + } + } +} + +#[cfg(not(target_arch = "bpf"))] +impl Verifiable for CloseAccountData { + fn verify(&self) -> Result<(), ProofError> { + let balance = self.balance.try_into()?; + self.proof.verify(&balance) + } +} + +/// This struct represents the cryptographic proof component that certifies that the encrypted +/// balance is zero +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +#[allow(non_snake_case)] +pub struct CloseAccountProof { + pub R: PodCompressedRistretto, // 32 bytes + pub z: PodScalar, // 32 bytes +} + +#[allow(non_snake_case)] +#[cfg(not(target_arch = "bpf"))] +impl CloseAccountProof { + fn transcript_new() -> Transcript { + Transcript::new(b"CloseAccountProof") + } + + pub fn new(source_sk: &ElGamalSK, balance: &ElGamalCT) -> Self { + let mut transcript = Self::transcript_new(); + + // add a domain separator to record the start of the protocol + transcript.close_account_proof_domain_sep(); + + // extract the relevant scalar and Ristretto points from the input + let s = source_sk.get_scalar(); + let C = balance.decrypt_handle.get_point(); + + // generate a random masking factor that also serves as a nonce + let r = Scalar::random(&mut OsRng); // using OsRng for now + let R = (r * C).compress(); + + // record R on transcript and receive a challenge scalar + transcript.append_point(b"R", &R); + let c = transcript.challenge_scalar(b"c"); + + // compute the masked secret key + let z = c * s + r; + + CloseAccountProof { + R: R.into(), + z: z.into(), + } + } + + pub fn verify(&self, balance: &ElGamalCT) -> Result<(), ProofError> { + let mut transcript = Self::transcript_new(); + + // add a domain separator to record the start of the protocol + transcript.close_account_proof_domain_sep(); + + // extract the relevant scalar and Ristretto points from the input + let C = balance.message_comm.get_point(); + let D = balance.decrypt_handle.get_point(); + + let R = self.R.into(); + let z = self.z.into(); + + // generate a challenge scalar + // + // use `append_point` as opposed to `validate_and_append_point` as the ciphertext is + // already guaranteed to be well-formed + transcript.append_point(b"R", &R); + let c = transcript.challenge_scalar(b"c"); + + // decompress R or return verification error + let R = R.decompress().ok_or(ProofError::VerificationError)?; + + // check the required algebraic relation + let check = RistrettoPoint::multiscalar_mul(vec![z, -c, -Scalar::one()], vec![D, C, R]); + + if check.is_identity() { + Ok(()) + } else { + Err(ProofError::VerificationError) + } + } +} + +#[cfg(test)] +mod test { + use {super::*, crate::encryption::elgamal::ElGamal}; + + #[test] + fn test_close_account_correctness() { + let (source_pk, source_sk) = ElGamal::keygen(); + + // If account balance is 0, then the proof should succeed + let balance = source_pk.encrypt(0_u64); + + let proof = CloseAccountProof::new(&source_sk, &balance); + assert!(proof.verify(&balance).is_ok()); + + // If account balance is not zero, then the proof verification should fail + let balance = source_pk.encrypt(55_u64); + + let proof = CloseAccountProof::new(&source_sk, &balance); + assert!(proof.verify(&balance).is_err()); + + // A zeroed cyphertext should be considered as an account balance of 0 + let zeroed_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap(); + let proof = CloseAccountProof::new(&source_sk, &zeroed_ct); + assert!(proof.verify(&zeroed_ct).is_ok()); + } +} diff --git a/zk-token-sdk/src/instruction/mod.rs b/zk-token-sdk/src/instruction/mod.rs new file mode 100644 index 000000000..7c0672cdf --- /dev/null +++ b/zk-token-sdk/src/instruction/mod.rs @@ -0,0 +1,21 @@ +mod close_account; +pub mod transfer; +mod update_account_pk; +mod withdraw; + +#[cfg(not(target_arch = "bpf"))] +use crate::errors::ProofError; +pub use { + close_account::CloseAccountData, + transfer::{ + TransferComms, TransferData, TransferEphemeralState, TransferPubKeys, + TransferRangeProofData, TransferValidityProofData, + }, + update_account_pk::UpdateAccountPkData, + withdraw::WithdrawData, +}; + +#[cfg(not(target_arch = "bpf"))] +pub trait Verifiable { + fn verify(&self) -> Result<(), ProofError>; +} diff --git a/zk-token-sdk/src/instruction/transfer.rs b/zk-token-sdk/src/instruction/transfer.rs new file mode 100644 index 000000000..8f34878fd --- /dev/null +++ b/zk-token-sdk/src/instruction/transfer.rs @@ -0,0 +1,546 @@ +use { + crate::pod::*, + bytemuck::{Pod, Zeroable}, +}; +#[cfg(not(target_arch = "bpf"))] +use { + crate::{ + encryption::{ + elgamal::{ElGamalCT, ElGamalPK, ElGamalSK}, + pedersen::{Pedersen, PedersenBase, PedersenComm, PedersenDecHandle, PedersenOpen}, + }, + errors::ProofError, + instruction::Verifiable, + range_proof::RangeProof, + transcript::TranscriptProtocol, + }, + curve25519_dalek::{ + ristretto::{CompressedRistretto, RistrettoPoint}, + scalar::Scalar, + traits::{IsIdentity, MultiscalarMul, VartimeMultiscalarMul}, + }, + merlin::Transcript, + rand::rngs::OsRng, + std::convert::TryInto, +}; + +/// Just a grouping struct for the data required for the two transfer instructions. It is +/// convenient to generate the two components jointly as they share common components. +pub struct TransferData { + pub range_proof: TransferRangeProofData, + pub validity_proof: TransferValidityProofData, +} + +#[cfg(not(target_arch = "bpf"))] +impl TransferData { + pub fn new( + transfer_amount: u64, + spendable_balance: u64, + spendable_ct: ElGamalCT, + source_pk: ElGamalPK, + source_sk: &ElGamalSK, + dest_pk: ElGamalPK, + auditor_pk: ElGamalPK, + ) -> Self { + // split and encrypt transfer amount + // + // encryption is a bit more involved since we are generating each components of an ElGamal + // ciphertext separately. + let (amount_lo, amount_hi) = split_u64_into_u32(transfer_amount); + + let (comm_lo, open_lo) = Pedersen::commit(amount_lo); + let (comm_hi, open_hi) = Pedersen::commit(amount_hi); + + let handle_source_lo = source_pk.gen_decrypt_handle(&open_lo); + let handle_dest_lo = dest_pk.gen_decrypt_handle(&open_lo); + let handle_auditor_lo = auditor_pk.gen_decrypt_handle(&open_lo); + + let handle_source_hi = source_pk.gen_decrypt_handle(&open_hi); + let handle_dest_hi = dest_pk.gen_decrypt_handle(&open_hi); + let handle_auditor_hi = auditor_pk.gen_decrypt_handle(&open_hi); + + // message encoding as Pedersen commitments, which will be included in range proof data + let amount_comms = TransferComms { + lo: comm_lo.into(), + hi: comm_hi.into(), + }; + + // decryption handles, which will be included in the validity proof data + let decryption_handles_lo = TransferHandles { + source: handle_source_lo.into(), + dest: handle_dest_lo.into(), + auditor: handle_auditor_lo.into(), + }; + + let decryption_handles_hi = TransferHandles { + source: handle_source_hi.into(), + dest: handle_dest_hi.into(), + auditor: handle_auditor_hi.into(), + }; + + // grouping of the public keys for the transfer + let transfer_public_keys = TransferPubKeys { + source_pk: source_pk.into(), + dest_pk: dest_pk.into(), + auditor_pk: auditor_pk.into(), + }; + + // subtract transfer amount from the spendable ciphertext + let spendable_comm = spendable_ct.message_comm; + let spendable_handle = spendable_ct.decrypt_handle; + + let new_spendable_balance = spendable_balance - transfer_amount; + let new_spendable_comm = spendable_comm - combine_u32_comms(comm_lo, comm_hi); + let new_spendable_handle = + spendable_handle - combine_u32_handles(handle_source_lo, handle_source_hi); + + let new_spendable_ct = ElGamalCT { + message_comm: new_spendable_comm, + decrypt_handle: new_spendable_handle, + }; + + // range_proof and validity_proof should be generated together + let (transfer_proofs, ephemeral_state) = TransferProofs::new( + source_sk, + &source_pk, + &dest_pk, + &auditor_pk, + (amount_lo as u64, amount_hi as u64), + &open_lo, + &open_hi, + new_spendable_balance, + &new_spendable_ct, + ); + + // generate data components + let range_proof = TransferRangeProofData { + amount_comms, + proof: transfer_proofs.range_proof, + ephemeral_state, + }; + + let validity_proof = TransferValidityProofData { + decryption_handles_lo, + decryption_handles_hi, + transfer_public_keys, + new_spendable_ct: new_spendable_ct.into(), + proof: transfer_proofs.validity_proof, + ephemeral_state, + }; + + TransferData { + range_proof, + validity_proof, + } + } +} + +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct TransferRangeProofData { + /// The transfer amount encoded as Pedersen commitments + pub amount_comms: TransferComms, // 64 bytes + + /// Proof that certifies: + /// 1. the source account has enough funds for the transfer (i.e. the final balance is a + /// 64-bit positive number) + /// 2. the transfer amount is a 64-bit positive number + pub proof: PodRangeProof128, // 736 bytes + + /// Ephemeral state between the two transfer instruction data + pub ephemeral_state: TransferEphemeralState, // 128 bytes +} + +#[cfg(not(target_arch = "bpf"))] +impl Verifiable for TransferRangeProofData { + fn verify(&self) -> Result<(), ProofError> { + let mut transcript = Transcript::new(b"TransferRangeProof"); + + // standard range proof verification + let proof: RangeProof = self.proof.try_into()?; + proof.verify_with( + vec![ + &self.ephemeral_state.spendable_comm_verification.into(), + &self.amount_comms.lo.into(), + &self.amount_comms.hi.into(), + ], + vec![64_usize, 32_usize, 32_usize], + Some(self.ephemeral_state.x.into()), + Some(self.ephemeral_state.z.into()), + &mut transcript, + ) + } +} + +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct TransferValidityProofData { + /// The decryption handles that allow decryption of the lo-bits + pub decryption_handles_lo: TransferHandles, // 96 bytes + + /// The decryption handles that allow decryption of the hi-bits + pub decryption_handles_hi: TransferHandles, // 96 bytes + + /// The public encryption keys associated with the transfer: source, dest, and auditor + pub transfer_public_keys: TransferPubKeys, // 96 bytes + + /// The final spendable ciphertext after the transfer + pub new_spendable_ct: PodElGamalCT, // 64 bytes + + /// Proof that certifies that the decryption handles are generated correctly + pub proof: ValidityProof, // 160 bytes + + /// Ephemeral state between the two transfer instruction data + pub ephemeral_state: TransferEphemeralState, // 128 bytes +} + +/// The joint data that is shared between the two transfer instructions. +/// +/// Identical ephemeral data should be included in the two transfer instructions and this should be +/// checked by the ZK Token program. +#[derive(Clone, Copy, Pod, Zeroable, PartialEq)] +#[repr(C)] +pub struct TransferEphemeralState { + pub spendable_comm_verification: PodPedersenComm, // 32 bytes + pub x: PodScalar, // 32 bytes + pub z: PodScalar, // 32 bytes + pub t_x_blinding: PodScalar, // 32 bytes +} + +#[cfg(not(target_arch = "bpf"))] +impl Verifiable for TransferValidityProofData { + fn verify(&self) -> Result<(), ProofError> { + self.proof.verify( + &self.new_spendable_ct.try_into()?, + &self.decryption_handles_lo, + &self.decryption_handles_hi, + &self.transfer_public_keys, + &self.ephemeral_state, + ) + } +} + +/// Just a grouping struct for the two proofs that are needed for a transfer instruction. The two +/// proofs have to be generated together as they share joint data. +pub struct TransferProofs { + pub range_proof: PodRangeProof128, + pub validity_proof: ValidityProof, +} + +#[allow(non_snake_case)] +#[cfg(not(target_arch = "bpf"))] +impl TransferProofs { + #[allow(clippy::too_many_arguments)] + #[allow(clippy::many_single_char_names)] + pub fn new( + source_sk: &ElGamalSK, + source_pk: &ElGamalPK, + dest_pk: &ElGamalPK, + auditor_pk: &ElGamalPK, + transfer_amt: (u64, u64), + lo_open: &PedersenOpen, + hi_open: &PedersenOpen, + new_spendable_balance: u64, + new_spendable_ct: &ElGamalCT, + ) -> (Self, TransferEphemeralState) { + // TODO: should also commit to pubkeys and commitments later + let mut transcript_validity_proof = merlin::Transcript::new(b"TransferValidityProof"); + + let H = PedersenBase::default().H; + let D = new_spendable_ct.decrypt_handle.get_point(); + let s = source_sk.get_scalar(); + + // Generate proof for the new spendable ciphertext + let r_new = Scalar::random(&mut OsRng); + let y = Scalar::random(&mut OsRng); + let R = RistrettoPoint::multiscalar_mul(vec![y, r_new], vec![D, H]).compress(); + + transcript_validity_proof.append_point(b"R", &R); + let c = transcript_validity_proof.challenge_scalar(b"c"); + + let z = s + c * y; + let new_spendable_open = PedersenOpen(c * r_new); + + let spendable_comm_verification = + Pedersen::commit_with(new_spendable_balance, &new_spendable_open); + + // Generate proof for the transfer amounts + let t_1_blinding = PedersenOpen::random(&mut OsRng); + let t_2_blinding = PedersenOpen::random(&mut OsRng); + + let u = transcript_validity_proof.challenge_scalar(b"u"); + let P_joint = RistrettoPoint::multiscalar_mul( + vec![Scalar::one(), u, u * u], + vec![ + source_pk.get_point(), + dest_pk.get_point(), + auditor_pk.get_point(), + ], + ); + let T_joint = (new_spendable_open.get_scalar() * P_joint).compress(); + let T_1 = (t_1_blinding.get_scalar() * P_joint).compress(); + let T_2 = (t_2_blinding.get_scalar() * P_joint).compress(); + + transcript_validity_proof.append_point(b"T_1", &T_1); + transcript_validity_proof.append_point(b"T_2", &T_2); + + // define the validity proof + let validity_proof = ValidityProof { + R: R.into(), + z: z.into(), + T_joint: T_joint.into(), + T_1: T_1.into(), + T_2: T_2.into(), + }; + + // generate the range proof + let mut transcript_range_proof = Transcript::new(b"TransferRangeProof"); + let (range_proof, x, z) = RangeProof::create_with( + vec![new_spendable_balance, transfer_amt.0, transfer_amt.1], + vec![64, 32, 32], + vec![&new_spendable_open, lo_open, hi_open], + &t_1_blinding, + &t_2_blinding, + &mut transcript_range_proof, + ); + + // define ephemeral state + let ephemeral_state = TransferEphemeralState { + spendable_comm_verification: spendable_comm_verification.into(), + x: x.into(), + z: z.into(), + t_x_blinding: range_proof.t_x_blinding.into(), + }; + + ( + Self { + range_proof: range_proof.try_into().expect("valid range_proof"), + validity_proof, + }, + ephemeral_state, + ) + } +} + +/// Proof components for transfer instructions. +/// +/// These two components should be output by a RangeProof creation function. +#[allow(non_snake_case)] +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct ValidityProof { + // Proof component for the spendable ciphertext components: R + pub R: PodCompressedRistretto, // 32 bytes + // Proof component for the spendable ciphertext components: z + pub z: PodScalar, // 32 bytes + // Proof component for the transaction amount components: T_src + pub T_joint: PodCompressedRistretto, // 32 bytes + // Proof component for the transaction amount components: T_1 + pub T_1: PodCompressedRistretto, // 32 bytes + // Proof component for the transaction amount components: T_2 + pub T_2: PodCompressedRistretto, // 32 bytes +} + +#[allow(non_snake_case)] +#[cfg(not(target_arch = "bpf"))] +impl ValidityProof { + pub fn verify( + self, + new_spendable_ct: &ElGamalCT, + decryption_handles_lo: &TransferHandles, + decryption_handles_hi: &TransferHandles, + transfer_public_keys: &TransferPubKeys, + ephemeral_state: &TransferEphemeralState, + ) -> Result<(), ProofError> { + let mut transcript = Transcript::new(b"TransferValidityProof"); + + let source_pk: ElGamalPK = transfer_public_keys.source_pk.try_into()?; + let dest_pk: ElGamalPK = transfer_public_keys.dest_pk.try_into()?; + let auditor_pk: ElGamalPK = transfer_public_keys.auditor_pk.try_into()?; + + // verify Pedersen commitment in the ephemeral state + let C_ephemeral: CompressedRistretto = ephemeral_state.spendable_comm_verification.into(); + + let C = new_spendable_ct.message_comm.get_point(); + let D = new_spendable_ct.decrypt_handle.get_point(); + + let R = self.R.into(); + let z: Scalar = self.z.into(); + + transcript.validate_and_append_point(b"R", &R)?; + let c = transcript.challenge_scalar(b"c"); + + let R = R.decompress().ok_or(ProofError::VerificationError)?; + + let spendable_comm_verification = + RistrettoPoint::multiscalar_mul(vec![Scalar::one(), -z, c], vec![C, D, R]).compress(); + + if C_ephemeral != spendable_comm_verification { + return Err(ProofError::VerificationError); + } + + // derive joint public key + let u = transcript.challenge_scalar(b"u"); + let P_joint = RistrettoPoint::vartime_multiscalar_mul( + vec![Scalar::one(), u, u * u], + vec![ + source_pk.get_point(), + dest_pk.get_point(), + auditor_pk.get_point(), + ], + ); + + // check well-formedness of decryption handles + let t_x_blinding: Scalar = ephemeral_state.t_x_blinding.into(); + let T_1: CompressedRistretto = self.T_1.into(); + let T_2: CompressedRistretto = self.T_2.into(); + + let x = ephemeral_state.x.into(); + let z: Scalar = ephemeral_state.z.into(); + + let handle_source_lo: PedersenDecHandle = decryption_handles_lo.source.try_into()?; + let handle_dest_lo: PedersenDecHandle = decryption_handles_lo.dest.try_into()?; + let handle_auditor_lo: PedersenDecHandle = decryption_handles_lo.auditor.try_into()?; + + let D_joint: CompressedRistretto = self.T_joint.into(); + let D_joint = D_joint.decompress().ok_or(ProofError::VerificationError)?; + + let D_joint_lo = RistrettoPoint::vartime_multiscalar_mul( + vec![Scalar::one(), u, u * u], + vec![ + handle_source_lo.get_point(), + handle_dest_lo.get_point(), + handle_auditor_lo.get_point(), + ], + ); + + let handle_source_hi: PedersenDecHandle = decryption_handles_hi.source.try_into()?; + let handle_dest_hi: PedersenDecHandle = decryption_handles_hi.dest.try_into()?; + let handle_auditor_hi: PedersenDecHandle = decryption_handles_hi.auditor.try_into()?; + + let D_joint_hi = RistrettoPoint::vartime_multiscalar_mul( + vec![Scalar::one(), u, u * u], + vec![ + handle_source_hi.get_point(), + handle_dest_hi.get_point(), + handle_auditor_hi.get_point(), + ], + ); + + // TODO: combine Pedersen commitment verification above to here for efficiency + // TODO: might need to add an additional proof-of-knowledge here (additional 64 byte) + let mega_check = RistrettoPoint::optional_multiscalar_mul( + vec![-t_x_blinding, x, x * x, z * z, z * z * z, z * z * z * z], + vec![ + Some(P_joint), + T_1.decompress(), + T_2.decompress(), + Some(D_joint), + Some(D_joint_lo), + Some(D_joint_hi), + ], + ) + .ok_or(ProofError::VerificationError)?; + + if mega_check.is_identity() { + Ok(()) + } else { + Err(ProofError::VerificationError) + } + } +} + +/// The ElGamal public keys needed for a transfer +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct TransferPubKeys { + pub source_pk: PodElGamalPK, // 32 bytes + pub dest_pk: PodElGamalPK, // 32 bytes + pub auditor_pk: PodElGamalPK, // 32 bytes +} + +/// The transfer amount commitments needed for a transfer +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct TransferComms { + pub lo: PodPedersenComm, // 32 bytes + pub hi: PodPedersenComm, // 32 bytes +} + +/// The decryption handles needed for a transfer +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct TransferHandles { + pub source: PodPedersenDecHandle, // 32 bytes + pub dest: PodPedersenDecHandle, // 32 bytes + pub auditor: PodPedersenDecHandle, // 32 bytes +} + +/// Split u64 number into two u32 numbers +#[cfg(not(target_arch = "bpf"))] +pub fn split_u64_into_u32(amt: u64) -> (u32, u32) { + let lo = amt as u32; + let hi = (amt >> 32) as u32; + + (lo, hi) +} + +/// Constant for 2^32 +#[cfg(not(target_arch = "bpf"))] +const TWO_32: u64 = 4294967296; + +#[cfg(not(target_arch = "bpf"))] +pub fn combine_u32_comms(comm_lo: PedersenComm, comm_hi: PedersenComm) -> PedersenComm { + comm_lo + comm_hi * Scalar::from(TWO_32) +} + +#[cfg(not(target_arch = "bpf"))] +pub fn combine_u32_handles( + handle_lo: PedersenDecHandle, + handle_hi: PedersenDecHandle, +) -> PedersenDecHandle { + handle_lo + handle_hi * Scalar::from(TWO_32) +} + +#[cfg(not(target_arch = "bpf"))] +pub fn combine_u32_ciphertexts(ct_lo: ElGamalCT, ct_hi: ElGamalCT) -> ElGamalCT { + ct_lo + ct_hi * Scalar::from(TWO_32) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::encryption::elgamal::ElGamal; + + #[test] + fn test_transfer_correctness() { + // ElGamal keys for source, destination, and auditor accounts + let (source_pk, source_sk) = ElGamal::keygen(); + let (dest_pk, _) = ElGamal::keygen(); + let (auditor_pk, _) = ElGamal::keygen(); + + // create source account spendable ciphertext + let spendable_balance: u64 = 77; + let spendable_ct = source_pk.encrypt(spendable_balance); + + // transfer amount + let transfer_amount: u64 = 55; + + // create transfer data + let transfer_data = TransferData::new( + transfer_amount, + spendable_balance, + spendable_ct, + source_pk, + &source_sk, + dest_pk, + auditor_pk, + ); + + // verify range proof + assert!(transfer_data.range_proof.verify().is_ok()); + + // verify ciphertext validity proof + assert!(transfer_data.validity_proof.verify().is_ok()); + } +} diff --git a/zk-token-sdk/src/instruction/update_account_pk.rs b/zk-token-sdk/src/instruction/update_account_pk.rs new file mode 100644 index 000000000..0a1d1c571 --- /dev/null +++ b/zk-token-sdk/src/instruction/update_account_pk.rs @@ -0,0 +1,271 @@ +use { + crate::pod::*, + bytemuck::{Pod, Zeroable}, +}; +#[cfg(not(target_arch = "bpf"))] +use { + crate::{ + encryption::{ + elgamal::{ElGamalCT, ElGamalPK, ElGamalSK}, + pedersen::PedersenBase, + }, + errors::ProofError, + instruction::Verifiable, + transcript::TranscriptProtocol, + }, + curve25519_dalek::{ + ristretto::RistrettoPoint, + scalar::Scalar, + traits::{IsIdentity, MultiscalarMul}, + }, + merlin::Transcript, + rand::rngs::OsRng, + std::convert::TryInto, +}; + +/// This struct includes the cryptographic proof *and* the account data information needed to verify +/// the proof +/// +/// - The pre-instruction should call UpdateAccountPKData::verify(&self) +/// - The actual program should check that `current_ct` is consistent with what is +/// currently stored in the confidential token account +/// +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct UpdateAccountPkData { + /// Current ElGamal encryption key + pub current_pk: PodElGamalPK, // 32 bytes + + /// Current encrypted available balance + pub current_ct: PodElGamalCT, // 64 bytes + + /// New ElGamal encryption key + pub new_pk: PodElGamalPK, // 32 bytes + + /// New encrypted available balance + pub new_ct: PodElGamalCT, // 64 bytes + + /// Proof that the current and new ciphertexts are consistent + pub proof: UpdateAccountPkProof, // 160 bytes +} + +impl UpdateAccountPkData { + #[cfg(not(target_arch = "bpf"))] + pub fn new( + current_balance: u64, + current_ct: ElGamalCT, + current_pk: ElGamalPK, + current_sk: &ElGamalSK, + new_pk: ElGamalPK, + new_sk: &ElGamalSK, + ) -> Self { + let new_ct = new_pk.encrypt(current_balance); + + let proof = + UpdateAccountPkProof::new(current_balance, current_sk, new_sk, ¤t_ct, &new_ct); + + Self { + current_pk: current_pk.into(), + current_ct: current_ct.into(), + new_ct: new_ct.into(), + new_pk: new_pk.into(), + proof, + } + } +} + +#[cfg(not(target_arch = "bpf"))] +impl Verifiable for UpdateAccountPkData { + fn verify(&self) -> Result<(), ProofError> { + let current_ct = self.current_ct.try_into()?; + let new_ct = self.new_ct.try_into()?; + self.proof.verify(¤t_ct, &new_ct) + } +} + +/// This struct represents the cryptographic proof component that certifies that the current_ct and +/// new_ct encrypt equal values +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +#[allow(non_snake_case)] +pub struct UpdateAccountPkProof { + pub R_0: PodCompressedRistretto, // 32 bytes + pub R_1: PodCompressedRistretto, // 32 bytes + pub z_sk_0: PodScalar, // 32 bytes + pub z_sk_1: PodScalar, // 32 bytes + pub z_x: PodScalar, // 32 bytes +} + +#[allow(non_snake_case)] +#[cfg(not(target_arch = "bpf"))] +impl UpdateAccountPkProof { + fn transcript_new() -> Transcript { + Transcript::new(b"UpdateAccountPkProof") + } + + fn new( + current_balance: u64, + current_sk: &ElGamalSK, + new_sk: &ElGamalSK, + current_ct: &ElGamalCT, + new_ct: &ElGamalCT, + ) -> Self { + let mut transcript = Self::transcript_new(); + + // add a domain separator to record the start of the protocol + transcript.update_account_public_key_proof_domain_sep(); + + // extract the relevant scalar and Ristretto points from the input + let s_0 = current_sk.get_scalar(); + let s_1 = new_sk.get_scalar(); + let x = Scalar::from(current_balance); + + let D_0 = current_ct.decrypt_handle.get_point(); + let D_1 = new_ct.decrypt_handle.get_point(); + + let G = PedersenBase::default().G; + + // generate a random masking factor that also serves as a nonce + let r_sk_0 = Scalar::random(&mut OsRng); + let r_sk_1 = Scalar::random(&mut OsRng); + let r_x = Scalar::random(&mut OsRng); + + let R_0 = (r_sk_0 * D_0 + r_x * G).compress(); + let R_1 = (r_sk_1 * D_1 + r_x * G).compress(); + + // record R_0, R_1 on transcript and receive a challenge scalar + transcript.append_point(b"R_0", &R_0); + transcript.append_point(b"R_1", &R_1); + let c = transcript.challenge_scalar(b"c"); + let _w = transcript.challenge_scalar(b"w"); // for consistency of transcript + + // compute the masked secret keys and amount + let z_sk_0 = c * s_0 + r_sk_0; + let z_sk_1 = c * s_1 + r_sk_1; + let z_x = c * x + r_x; + + UpdateAccountPkProof { + R_0: R_0.into(), + R_1: R_1.into(), + z_sk_0: z_sk_0.into(), + z_sk_1: z_sk_1.into(), + z_x: z_x.into(), + } + } + + fn verify(&self, current_ct: &ElGamalCT, new_ct: &ElGamalCT) -> Result<(), ProofError> { + let mut transcript = Self::transcript_new(); + + // add a domain separator to record the start of the protocol + transcript.update_account_public_key_proof_domain_sep(); + + // extract the relevant scalar and Ristretto points from the input + let C_0 = current_ct.message_comm.get_point(); + let D_0 = current_ct.decrypt_handle.get_point(); + + let C_1 = new_ct.message_comm.get_point(); + let D_1 = new_ct.decrypt_handle.get_point(); + + let R_0 = self.R_0.into(); + let R_1 = self.R_1.into(); + let z_sk_0 = self.z_sk_0.into(); + let z_sk_1: Scalar = self.z_sk_1.into(); + let z_x = self.z_x.into(); + + let G = PedersenBase::default().G; + + // generate a challenge scalar + transcript.validate_and_append_point(b"R_0", &R_0)?; + transcript.validate_and_append_point(b"R_1", &R_1)?; + let c = transcript.challenge_scalar(b"c"); + let w = transcript.challenge_scalar(b"w"); + + // decompress R_0, R_1 or return verification error + let R_0 = R_0.decompress().ok_or(ProofError::VerificationError)?; + let R_1 = R_1.decompress().ok_or(ProofError::VerificationError)?; + + // check the required algebraic relation + let check = RistrettoPoint::multiscalar_mul( + vec![ + z_sk_0, + z_x, + -c, + -Scalar::one(), + w * z_sk_1, + w * z_x, + -w * c, + -w * Scalar::one(), + ], + vec![D_0, G, C_0, R_0, D_1, G, C_1, R_1], + ); + + if check.is_identity() { + Ok(()) + } else { + Err(ProofError::VerificationError) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::encryption::elgamal::ElGamal; + + #[test] + fn test_update_account_public_key_correctness() { + let (current_pk, current_sk) = ElGamal::keygen(); + let (new_pk, new_sk) = ElGamal::keygen(); + + // If current_ct and new_ct encrypt same values, then the proof verification should succeed + let balance: u64 = 77; + let current_ct = current_pk.encrypt(balance); + let new_ct = new_pk.encrypt(balance); + + let proof = UpdateAccountPkProof::new(balance, ¤t_sk, &new_sk, ¤t_ct, &new_ct); + assert!(proof.verify(¤t_ct, &new_ct).is_ok()); + + // If current_ct and new_ct encrypt different values, then the proof verification should fail + let new_ct = new_pk.encrypt(55_u64); + + let proof = UpdateAccountPkProof::new(balance, ¤t_sk, &new_sk, ¤t_ct, &new_ct); + assert!(proof.verify(¤t_ct, &new_ct).is_err()); + + // A zeroed cipehrtext should be considered as an account balance of 0 + let balance: u64 = 0; + let zeroed_ct_as_current_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap(); + let new_ct: ElGamalCT = new_pk.encrypt(balance); + let proof = UpdateAccountPkProof::new( + balance, + ¤t_sk, + &new_sk, + &zeroed_ct_as_current_ct, + &new_ct, + ); + assert!(proof.verify(&zeroed_ct_as_current_ct, &new_ct).is_ok()); + + let current_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap(); + let zeroed_ct_as_new_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap(); + let proof = UpdateAccountPkProof::new( + balance, + ¤t_sk, + &new_sk, + ¤t_ct, + &zeroed_ct_as_new_ct, + ); + assert!(proof.verify(¤t_ct, &zeroed_ct_as_new_ct).is_ok()); + + let zeroed_ct_as_current_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap(); + let zeroed_ct_as_new_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap(); + let proof = UpdateAccountPkProof::new( + balance, + ¤t_sk, + &new_sk, + &zeroed_ct_as_current_ct, + &zeroed_ct_as_new_ct, + ); + assert!(proof + .verify(&zeroed_ct_as_current_ct, &zeroed_ct_as_new_ct) + .is_ok()); + } +} diff --git a/zk-token-sdk/src/instruction/withdraw.rs b/zk-token-sdk/src/instruction/withdraw.rs new file mode 100644 index 000000000..26fcd451a --- /dev/null +++ b/zk-token-sdk/src/instruction/withdraw.rs @@ -0,0 +1,207 @@ +use { + crate::pod::*, + bytemuck::{Pod, Zeroable}, +}; +#[cfg(not(target_arch = "bpf"))] +use { + crate::{ + encryption::{ + elgamal::{ElGamalCT, ElGamalPK, ElGamalSK}, + pedersen::{PedersenBase, PedersenOpen}, + }, + errors::ProofError, + instruction::Verifiable, + range_proof::RangeProof, + transcript::TranscriptProtocol, + }, + curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::MultiscalarMul}, + merlin::Transcript, + rand::rngs::OsRng, + std::convert::TryInto, +}; + +/// This struct includes the cryptographic proof *and* the account data information needed to verify +/// the proof +/// +/// - The pre-instruction should call WithdrawData::verify_proof(&self) +/// - The actual program should check that `current_ct` is consistent with what is +/// currently stored in the confidential token account TODO: update this statement +/// +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct WithdrawData { + /// The source account available balance *after* the withdraw (encrypted by + /// `source_pk` + pub final_balance_ct: PodElGamalCT, // 64 bytes + + /// Proof that the account is solvent + pub proof: WithdrawProof, // 736 bytes +} + +impl WithdrawData { + #[cfg(not(target_arch = "bpf"))] + pub fn new( + amount: u64, + source_pk: ElGamalPK, + source_sk: &ElGamalSK, + current_balance: u64, + current_balance_ct: ElGamalCT, + ) -> Self { + // subtract withdraw amount from current balance + // + // panics if current_balance < amount + let final_balance = current_balance - amount; + + // encode withdraw amount as an ElGamal ciphertext and subtract it from + // current source balance + let amount_encoded = source_pk.encrypt_with(amount, &PedersenOpen::default()); + let final_balance_ct = current_balance_ct - amount_encoded; + + let proof = WithdrawProof::new(source_sk, final_balance, &final_balance_ct); + + Self { + final_balance_ct: final_balance_ct.into(), + proof, + } + } +} + +#[cfg(not(target_arch = "bpf"))] +impl Verifiable for WithdrawData { + fn verify(&self) -> Result<(), ProofError> { + let final_balance_ct = self.final_balance_ct.try_into()?; + self.proof.verify(&final_balance_ct) + } +} + +/// This struct represents the cryptographic proof component that certifies the account's solvency +/// for withdrawal +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(C)] +#[allow(non_snake_case)] +pub struct WithdrawProof { + /// Wrapper for range proof: R component + pub R: PodCompressedRistretto, // 32 bytes + /// Wrapper for range proof: z component + pub z: PodScalar, // 32 bytes + /// Associated range proof + pub range_proof: PodRangeProof64, // 672 bytes +} + +#[allow(non_snake_case)] +#[cfg(not(target_arch = "bpf"))] +impl WithdrawProof { + fn transcript_new() -> Transcript { + Transcript::new(b"WithdrawProof") + } + + pub fn new(source_sk: &ElGamalSK, final_balance: u64, final_balance_ct: &ElGamalCT) -> Self { + let mut transcript = Self::transcript_new(); + + // add a domain separator to record the start of the protocol + transcript.withdraw_proof_domain_sep(); + + // extract the relevant scalar and Ristretto points from the input + let H = PedersenBase::default().H; + let D = final_balance_ct.decrypt_handle.get_point(); + let s = source_sk.get_scalar(); + + // new pedersen opening + let r_new = Scalar::random(&mut OsRng); + + // generate a random masking factor that also serves as a nonce + let y = Scalar::random(&mut OsRng); + + let R = RistrettoPoint::multiscalar_mul(vec![y, r_new], vec![D, H]).compress(); + + // record R on transcript and receive a challenge scalar + transcript.append_point(b"R", &R); + let c = transcript.challenge_scalar(b"c"); + + // compute the masked secret key + let z = s + c * y; + + // compute the new Pedersen commitment and opening + let new_open = PedersenOpen(c * r_new); + + let range_proof = RangeProof::create( + vec![final_balance], + vec![64], + vec![&new_open], + &mut transcript, + ); + + WithdrawProof { + R: R.into(), + z: z.into(), + range_proof: range_proof.try_into().expect("range proof"), + } + } + + pub fn verify(&self, final_balance_ct: &ElGamalCT) -> Result<(), ProofError> { + let mut transcript = Self::transcript_new(); + + // Add a domain separator to record the start of the protocol + transcript.withdraw_proof_domain_sep(); + + // Extract the relevant scalar and Ristretto points from the input + let C = final_balance_ct.message_comm.get_point(); + let D = final_balance_ct.decrypt_handle.get_point(); + + let R = self.R.into(); + let z: Scalar = self.z.into(); + + // generate a challenge scalar + transcript.validate_and_append_point(b"R", &R)?; + let c = transcript.challenge_scalar(b"c"); + + // decompress R or return verification error + let R = R.decompress().ok_or(ProofError::VerificationError)?; + + // compute new Pedersen commitment to verify range proof with + let new_comm = RistrettoPoint::multiscalar_mul(vec![Scalar::one(), -z, c], vec![C, D, R]); + + let range_proof: RangeProof = self.range_proof.try_into()?; + range_proof.verify(vec![&new_comm.compress()], vec![64_usize], &mut transcript) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::encryption::elgamal::ElGamal; + + #[test] + #[ignore] + fn test_withdraw_correctness() { + // generate and verify proof for the proper setting + let (source_pk, source_sk) = ElGamal::keygen(); + + let current_balance: u64 = 77; + let current_balance_ct = source_pk.encrypt(current_balance); + + let withdraw_amount: u64 = 55; + + let data = WithdrawData::new( + withdraw_amount, + source_pk, + &source_sk, + current_balance, + current_balance_ct, + ); + assert!(data.verify().is_ok()); + + // generate and verify proof with wrong balance + let wrong_balance: u64 = 99; + let data = WithdrawData::new( + withdraw_amount, + source_pk, + &source_sk, + wrong_balance, + current_balance_ct, + ); + assert!(data.verify().is_err()); + + // TODO: test for ciphertexts that encrypt numbers outside the 0, 2^64 range + } +} diff --git a/zk-token-sdk/src/lib.rs b/zk-token-sdk/src/lib.rs new file mode 100644 index 000000000..0942ceaae --- /dev/null +++ b/zk-token-sdk/src/lib.rs @@ -0,0 +1,19 @@ +#[cfg(not(target_arch = "bpf"))] +#[macro_use] +pub(crate) mod macros; + +#[cfg(not(target_arch = "bpf"))] +pub mod encryption; + +#[cfg(not(target_arch = "bpf"))] +mod errors; + +#[cfg(not(target_arch = "bpf"))] +mod range_proof; +#[cfg(not(target_arch = "bpf"))] +mod transcript; + +mod instruction; +pub mod pod; +pub mod zk_token_proof_instruction; +pub mod zk_token_proof_program; diff --git a/zk-token-sdk/src/macros.rs b/zk-token-sdk/src/macros.rs new file mode 100644 index 000000000..1b6c48b8d --- /dev/null +++ b/zk-token-sdk/src/macros.rs @@ -0,0 +1,101 @@ +// Internal macros, used to defined the 'main' impls +#[macro_export] +macro_rules! define_add_variants { + (LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => { + impl<'b> Add<&'b $rhs> for $lhs { + type Output = $out; + fn add(self, rhs: &'b $rhs) -> $out { + &self + rhs + } + } + + impl<'a> Add<$rhs> for &'a $lhs { + type Output = $out; + fn add(self, rhs: $rhs) -> $out { + self + &rhs + } + } + + impl Add<$rhs> for $lhs { + type Output = $out; + fn add(self, rhs: $rhs) -> $out { + &self + &rhs + } + } + }; +} + +macro_rules! define_sub_variants { + (LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => { + impl<'b> Sub<&'b $rhs> for $lhs { + type Output = $out; + fn sub(self, rhs: &'b $rhs) -> $out { + &self - rhs + } + } + + impl<'a> Sub<$rhs> for &'a $lhs { + type Output = $out; + fn sub(self, rhs: $rhs) -> $out { + self - &rhs + } + } + + impl Sub<$rhs> for $lhs { + type Output = $out; + fn sub(self, rhs: $rhs) -> $out { + &self - &rhs + } + } + }; +} + +macro_rules! define_mul_variants { + (LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => { + impl<'b> Mul<&'b $rhs> for $lhs { + type Output = $out; + fn mul(self, rhs: &'b $rhs) -> $out { + &self * rhs + } + } + + impl<'a> Mul<$rhs> for &'a $lhs { + type Output = $out; + fn mul(self, rhs: $rhs) -> $out { + self * &rhs + } + } + + impl Mul<$rhs> for $lhs { + type Output = $out; + fn mul(self, rhs: $rhs) -> $out { + &self * &rhs + } + } + }; +} + +macro_rules! define_div_variants { + (LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => { + impl<'b> Div<&'b $rhs> for $lhs { + type Output = $out; + fn div(self, rhs: &'b $rhs) -> $out { + &self / rhs + } + } + + impl<'a> Div<$rhs> for &'a $lhs { + type Output = $out; + fn div(self, rhs: $rhs) -> $out { + self / &rhs + } + } + + impl Div<$rhs> for $lhs { + type Output = $out; + fn div(self, rhs: $rhs) -> $out { + &self / &rhs + } + } + }; +} diff --git a/zk-token-sdk/src/pod.rs b/zk-token-sdk/src/pod.rs new file mode 100644 index 000000000..d614be783 --- /dev/null +++ b/zk-token-sdk/src/pod.rs @@ -0,0 +1,596 @@ +//! Plain Old Data wrappers for types that need to be sent over the wire + +use bytemuck::{Pod, Zeroable}; +#[cfg(not(target_arch = "bpf"))] +use { + crate::{ + encryption::elgamal::{ElGamalCT, ElGamalPK}, + encryption::pedersen::{PedersenComm, PedersenDecHandle}, + errors::ProofError, + range_proof::RangeProof, + }, + curve25519_dalek::{ + constants::RISTRETTO_BASEPOINT_COMPRESSED, ristretto::CompressedRistretto, scalar::Scalar, + }, + std::{ + convert::{TryFrom, TryInto}, + fmt, + }, +}; + +#[derive(Clone, Copy, Pod, Zeroable, PartialEq)] +#[repr(transparent)] +pub struct PodScalar([u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +impl From for PodScalar { + fn from(scalar: Scalar) -> Self { + Self(scalar.to_bytes()) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl From for Scalar { + fn from(pod: PodScalar) -> Self { + Scalar::from_bits(pod.0) + } +} + +#[derive(Clone, Copy, Pod, Zeroable)] +#[repr(transparent)] +pub struct PodCompressedRistretto([u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +impl From for PodCompressedRistretto { + fn from(cr: CompressedRistretto) -> Self { + Self(cr.to_bytes()) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl From for CompressedRistretto { + fn from(pod: PodCompressedRistretto) -> Self { + Self(pod.0) + } +} + +#[derive(Clone, Copy, Pod, Zeroable, PartialEq)] +#[repr(transparent)] +pub struct PodElGamalCT([u8; 64]); + +#[cfg(not(target_arch = "bpf"))] +impl From for PodElGamalCT { + fn from(ct: ElGamalCT) -> Self { + Self(ct.to_bytes()) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for ElGamalCT { + type Error = ProofError; + + fn try_from(pod: PodElGamalCT) -> Result { + Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData) + } +} + +impl From<(PodPedersenComm, PodPedersenDecHandle)> for PodElGamalCT { + fn from((pod_comm, pod_decrypt_handle): (PodPedersenComm, PodPedersenDecHandle)) -> Self { + let mut buf = [0_u8; 64]; + buf[..32].copy_from_slice(&pod_comm.0); + buf[32..].copy_from_slice(&pod_decrypt_handle.0); + PodElGamalCT(buf) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl fmt::Debug for PodElGamalCT { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +#[derive(Clone, Copy, Pod, Zeroable, PartialEq)] +#[repr(transparent)] +pub struct PodElGamalPK([u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +impl From for PodElGamalPK { + fn from(pk: ElGamalPK) -> Self { + Self(pk.to_bytes()) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for ElGamalPK { + type Error = ProofError; + + fn try_from(pod: PodElGamalPK) -> Result { + Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl fmt::Debug for PodElGamalPK { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +#[derive(Clone, Copy, Pod, Zeroable, PartialEq)] +#[repr(transparent)] +pub struct PodPedersenComm([u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +impl From for PodPedersenComm { + fn from(comm: PedersenComm) -> Self { + Self(comm.to_bytes()) + } +} + +// For proof verification, interpret PodPedersenComm directly as CompressedRistretto +#[cfg(not(target_arch = "bpf"))] +impl From for CompressedRistretto { + fn from(pod: PodPedersenComm) -> Self { + Self(pod.0) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for PedersenComm { + type Error = ProofError; + + fn try_from(pod: PodPedersenComm) -> Result { + Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl fmt::Debug for PodPedersenComm { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +#[derive(Clone, Copy, Pod, Zeroable, PartialEq)] +#[repr(transparent)] +pub struct PodPedersenDecHandle([u8; 32]); + +#[cfg(not(target_arch = "bpf"))] +impl From for PodPedersenDecHandle { + fn from(handle: PedersenDecHandle) -> Self { + Self(handle.to_bytes()) + } +} + +// For proof verification, interpret PodPedersenDecHandle as CompressedRistretto +#[cfg(not(target_arch = "bpf"))] +impl From for CompressedRistretto { + fn from(pod: PodPedersenDecHandle) -> Self { + Self(pod.0) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for PedersenDecHandle { + type Error = ProofError; + + fn try_from(pod: PodPedersenDecHandle) -> Result { + Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl fmt::Debug for PodPedersenDecHandle { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +/// Serialization of range proofs for 64-bit numbers (for `Withdraw` instruction) +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PodRangeProof64([u8; 672]); + +// `PodRangeProof64` is a Pod and Zeroable. +// Add the marker traits manually because `bytemuck` only adds them for some `u8` arrays +unsafe impl Zeroable for PodRangeProof64 {} +unsafe impl Pod for PodRangeProof64 {} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for PodRangeProof64 { + type Error = ProofError; + + fn try_from(proof: RangeProof) -> Result { + if proof.ipp_proof.serialized_size() != 448 { + return Err(ProofError::VerificationError); + } + + let mut buf = [0_u8; 672]; + buf[..32].copy_from_slice(proof.A.as_bytes()); + buf[32..64].copy_from_slice(proof.S.as_bytes()); + buf[64..96].copy_from_slice(proof.T_1.as_bytes()); + buf[96..128].copy_from_slice(proof.T_2.as_bytes()); + buf[128..160].copy_from_slice(proof.t_x.as_bytes()); + buf[160..192].copy_from_slice(proof.t_x_blinding.as_bytes()); + buf[192..224].copy_from_slice(proof.e_blinding.as_bytes()); + buf[224..672].copy_from_slice(&proof.ipp_proof.to_bytes()); + Ok(PodRangeProof64(buf)) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for RangeProof { + type Error = ProofError; + + fn try_from(pod: PodRangeProof64) -> Result { + Self::from_bytes(&pod.0) + } +} + +/// Serialization of range proofs for 128-bit numbers (for `TransferRangeProof` instruction) +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct PodRangeProof128([u8; 736]); + +// `PodRangeProof128` is a Pod and Zeroable. +// Add the marker traits manually because `bytemuck` only adds them for some `u8` arrays +unsafe impl Zeroable for PodRangeProof128 {} +unsafe impl Pod for PodRangeProof128 {} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for PodRangeProof128 { + type Error = ProofError; + + fn try_from(proof: RangeProof) -> Result { + if proof.ipp_proof.serialized_size() != 512 { + return Err(ProofError::VerificationError); + } + + let mut buf = [0_u8; 736]; + buf[..32].copy_from_slice(proof.A.as_bytes()); + buf[32..64].copy_from_slice(proof.S.as_bytes()); + buf[64..96].copy_from_slice(proof.T_1.as_bytes()); + buf[96..128].copy_from_slice(proof.T_2.as_bytes()); + buf[128..160].copy_from_slice(proof.t_x.as_bytes()); + buf[160..192].copy_from_slice(proof.t_x_blinding.as_bytes()); + buf[192..224].copy_from_slice(proof.e_blinding.as_bytes()); + buf[224..736].copy_from_slice(&proof.ipp_proof.to_bytes()); + Ok(PodRangeProof128(buf)) + } +} + +#[cfg(not(target_arch = "bpf"))] +impl TryFrom for RangeProof { + type Error = ProofError; + + fn try_from(pod: PodRangeProof128) -> Result { + Self::from_bytes(&pod.0) + } +} + +pub struct PodElGamalArithmetic; + +#[cfg(not(target_arch = "bpf"))] +impl PodElGamalArithmetic { + const TWO_32: u64 = 4294967296; + + // On input two scalars x0, x1 and two ciphertexts ct0, ct1, + // returns `Some(x0*ct0 + x1*ct1)` or `None` if the input was invalid + fn add_pod_ciphertexts( + scalar_0: Scalar, + pod_ct_0: PodElGamalCT, + scalar_1: Scalar, + pod_ct_1: PodElGamalCT, + ) -> Option { + let ct_0: ElGamalCT = pod_ct_0.try_into().ok()?; + let ct_1: ElGamalCT = pod_ct_1.try_into().ok()?; + + let ct_sum = ct_0 * scalar_0 + ct_1 * scalar_1; + Some(PodElGamalCT::from(ct_sum)) + } + + fn combine_lo_hi(pod_ct_lo: PodElGamalCT, pod_ct_hi: PodElGamalCT) -> Option { + Self::add_pod_ciphertexts( + Scalar::one(), + pod_ct_lo, + Scalar::from(Self::TWO_32), + pod_ct_hi, + ) + } + + pub fn add(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option { + Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, Scalar::one(), pod_ct_1) + } + + pub fn add_with_lo_hi( + pod_ct_0: PodElGamalCT, + pod_ct_1_lo: PodElGamalCT, + pod_ct_1_hi: PodElGamalCT, + ) -> Option { + let pod_ct_1 = Self::combine_lo_hi(pod_ct_1_lo, pod_ct_1_hi)?; + Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, Scalar::one(), pod_ct_1) + } + + pub fn subtract(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option { + Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, -Scalar::one(), pod_ct_1) + } + + pub fn subtract_with_lo_hi( + pod_ct_0: PodElGamalCT, + pod_ct_1_lo: PodElGamalCT, + pod_ct_1_hi: PodElGamalCT, + ) -> Option { + let pod_ct_1 = Self::combine_lo_hi(pod_ct_1_lo, pod_ct_1_hi)?; + Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, -Scalar::one(), pod_ct_1) + } + + pub fn add_to(pod_ct: PodElGamalCT, amount: u64) -> Option { + let mut amount_as_pod_ct = [0_u8; 64]; + amount_as_pod_ct[..32].copy_from_slice(RISTRETTO_BASEPOINT_COMPRESSED.as_bytes()); + Self::add_pod_ciphertexts( + Scalar::one(), + pod_ct, + Scalar::from(amount), + PodElGamalCT(amount_as_pod_ct), + ) + } + + pub fn subtract_from(pod_ct: PodElGamalCT, amount: u64) -> Option { + let mut amount_as_pod_ct = [0_u8; 64]; + amount_as_pod_ct[..32].copy_from_slice(RISTRETTO_BASEPOINT_COMPRESSED.as_bytes()); + Self::add_pod_ciphertexts( + Scalar::one(), + pod_ct, + -Scalar::from(amount), + PodElGamalCT(amount_as_pod_ct), + ) + } +} +#[cfg(target_arch = "bpf")] +#[allow(unused_variables)] +impl PodElGamalArithmetic { + pub fn add(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option { + None + } + + pub fn add_with_lo_hi( + pod_ct_0: PodElGamalCT, + pod_ct_1_lo: PodElGamalCT, + pod_ct_1_hi: PodElGamalCT, + ) -> Option { + None + } + + pub fn subtract(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option { + None + } + + pub fn subtract_with_lo_hi( + pod_ct_0: PodElGamalCT, + pod_ct_1_lo: PodElGamalCT, + pod_ct_1_hi: PodElGamalCT, + ) -> Option { + None + } + + pub fn add_to(pod_ct: PodElGamalCT, amount: u64) -> Option { + None + } + + pub fn subtract_from(pod_ct: PodElGamalCT, amount: u64) -> Option { + None + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::encryption::{ + elgamal::{ElGamal, ElGamalCT}, + pedersen::{Pedersen, PedersenOpen}, + }, + crate::instruction::transfer::split_u64_into_u32, + merlin::Transcript, + rand::rngs::OsRng, + std::convert::TryInto, + }; + + #[test] + fn test_zero_ct() { + let spendable_balance = PodElGamalCT::zeroed(); + let spendable_ct: ElGamalCT = spendable_balance.try_into().unwrap(); + + // spendable_ct should be an encryption of 0 for any public key when + // `PedersenOpen::default()` is used + let (pk, _) = ElGamal::keygen(); + let balance: u64 = 0; + assert_eq!( + spendable_ct, + pk.encrypt_with(balance, &PedersenOpen::default()) + ); + + // homomorphism should work like any other ciphertext + let open = PedersenOpen::random(&mut OsRng); + let transfer_amount_ct = pk.encrypt_with(55_u64, &open); + let transfer_amount_pod: PodElGamalCT = transfer_amount_ct.into(); + + let sum = PodElGamalArithmetic::add(spendable_balance, transfer_amount_pod).unwrap(); + + let expected: PodElGamalCT = pk.encrypt_with(55_u64, &open).into(); + assert_eq!(expected, sum); + } + + #[test] + fn test_add_to() { + let spendable_balance = PodElGamalCT::zeroed(); + + let added_ct = PodElGamalArithmetic::add_to(spendable_balance, 55).unwrap(); + + let (pk, _) = ElGamal::keygen(); + let expected: PodElGamalCT = pk.encrypt_with(55_u64, &PedersenOpen::default()).into(); + + assert_eq!(expected, added_ct); + } + + #[test] + fn test_subtract_from() { + let amount = 77_u64; + let (pk, _) = ElGamal::keygen(); + let open = PedersenOpen::random(&mut OsRng); + let encrypted_amount: PodElGamalCT = pk.encrypt_with(amount, &open).into(); + + let subtracted_ct = PodElGamalArithmetic::subtract_from(encrypted_amount, 55).unwrap(); + + let expected: PodElGamalCT = pk.encrypt_with(22_u64, &open).into(); + + assert_eq!(expected, subtracted_ct); + } + + #[test] + fn test_pod_range_proof_64() { + let (comm, open) = Pedersen::commit(55_u64); + + let mut transcript_create = Transcript::new(b"Test"); + let mut transcript_verify = Transcript::new(b"Test"); + + let proof = RangeProof::create(vec![55], vec![64], vec![&open], &mut transcript_create); + + let proof_serialized: PodRangeProof64 = proof.try_into().unwrap(); + let proof_deserialized: RangeProof = proof_serialized.try_into().unwrap(); + + assert!(proof_deserialized + .verify( + vec![&comm.get_point().compress()], + vec![64], + &mut transcript_verify + ) + .is_ok()); + + // should fail to serialize to PodRangeProof128 + let proof = RangeProof::create(vec![55], vec![64], vec![&open], &mut transcript_create); + + assert!(TryInto::::try_into(proof).is_err()); + } + + #[test] + fn test_pod_range_proof_128() { + let (comm_1, open_1) = Pedersen::commit(55_u64); + let (comm_2, open_2) = Pedersen::commit(77_u64); + let (comm_3, open_3) = Pedersen::commit(99_u64); + + let mut transcript_create = Transcript::new(b"Test"); + let mut transcript_verify = Transcript::new(b"Test"); + + let proof = RangeProof::create( + vec![55, 77, 99], + vec![64, 32, 32], + vec![&open_1, &open_2, &open_3], + &mut transcript_create, + ); + + let comm_1_point = comm_1.get_point().compress(); + let comm_2_point = comm_2.get_point().compress(); + let comm_3_point = comm_3.get_point().compress(); + + let proof_serialized: PodRangeProof128 = proof.try_into().unwrap(); + let proof_deserialized: RangeProof = proof_serialized.try_into().unwrap(); + + assert!(proof_deserialized + .verify( + vec![&comm_1_point, &comm_2_point, &comm_3_point], + vec![64, 32, 32], + &mut transcript_verify, + ) + .is_ok()); + + // should fail to serialize to PodRangeProof64 + let proof = RangeProof::create( + vec![55, 77, 99], + vec![64, 32, 32], + vec![&open_1, &open_2, &open_3], + &mut transcript_create, + ); + + assert!(TryInto::::try_into(proof).is_err()); + } + + #[test] + fn test_transfer_arithmetic() { + // setup + + // transfer amount + let transfer_amount: u64 = 55; + let (amount_lo, amount_hi) = split_u64_into_u32(transfer_amount); + + // generate public keys + let (source_pk, _) = ElGamal::keygen(); + let (dest_pk, _) = ElGamal::keygen(); + let (auditor_pk, _) = ElGamal::keygen(); + + // commitments associated with TransferRangeProof + let (comm_lo, open_lo) = Pedersen::commit(amount_lo); + let (comm_hi, open_hi) = Pedersen::commit(amount_hi); + + let comm_lo: PodPedersenComm = comm_lo.into(); + let comm_hi: PodPedersenComm = comm_hi.into(); + + // decryption handles associated with TransferValidityProof + let handle_source_lo: PodPedersenDecHandle = source_pk.gen_decrypt_handle(&open_lo).into(); + let handle_dest_lo: PodPedersenDecHandle = dest_pk.gen_decrypt_handle(&open_lo).into(); + let _handle_auditor_lo: PodPedersenDecHandle = + auditor_pk.gen_decrypt_handle(&open_lo).into(); + + let handle_source_hi: PodPedersenDecHandle = source_pk.gen_decrypt_handle(&open_hi).into(); + let handle_dest_hi: PodPedersenDecHandle = dest_pk.gen_decrypt_handle(&open_hi).into(); + let _handle_auditor_hi: PodPedersenDecHandle = + auditor_pk.gen_decrypt_handle(&open_hi).into(); + + // source spendable and recipient pending + let source_open = PedersenOpen::random(&mut OsRng); + let dest_open = PedersenOpen::random(&mut OsRng); + + let source_spendable_ct: PodElGamalCT = source_pk.encrypt_with(77_u64, &source_open).into(); + let dest_pending_ct: PodElGamalCT = dest_pk.encrypt_with(77_u64, &dest_open).into(); + + // program arithmetic for the source account + + // 1. Combine commitments and handles + let source_lo_ct: PodElGamalCT = (comm_lo, handle_source_lo).into(); + let source_hi_ct: PodElGamalCT = (comm_hi, handle_source_hi).into(); + + // 2. Combine lo and hi ciphertexts + let source_combined_ct = + PodElGamalArithmetic::combine_lo_hi(source_lo_ct, source_hi_ct).unwrap(); + + // 3. Subtract from available balance + let final_source_spendable = + PodElGamalArithmetic::subtract(source_spendable_ct, source_combined_ct).unwrap(); + + // test + let final_source_open = source_open + - (open_lo.clone() + open_hi.clone() * Scalar::from(PodElGamalArithmetic::TWO_32)); + let expected_source: PodElGamalCT = + source_pk.encrypt_with(22_u64, &final_source_open).into(); + assert_eq!(expected_source, final_source_spendable); + + // same for the destination account + + // 1. Combine commitments and handles + let dest_lo_ct: PodElGamalCT = (comm_lo, handle_dest_lo).into(); + let dest_hi_ct: PodElGamalCT = (comm_hi, handle_dest_hi).into(); + + // 2. Combine lo and hi ciphertexts + let dest_combined_ct = PodElGamalArithmetic::combine_lo_hi(dest_lo_ct, dest_hi_ct).unwrap(); + + // 3. Add to pending balance + let final_dest_pending = + PodElGamalArithmetic::add(dest_pending_ct, dest_combined_ct).unwrap(); + + let final_dest_open = + dest_open + (open_lo + open_hi * Scalar::from(PodElGamalArithmetic::TWO_32)); + let expected_dest_ct: PodElGamalCT = dest_pk.encrypt_with(132_u64, &final_dest_open).into(); + assert_eq!(expected_dest_ct, final_dest_pending); + } +} diff --git a/zk-token-sdk/src/range_proof/generators.rs b/zk-token-sdk/src/range_proof/generators.rs new file mode 100644 index 000000000..f41f6fcd0 --- /dev/null +++ b/zk-token-sdk/src/range_proof/generators.rs @@ -0,0 +1,156 @@ +use { + curve25519_dalek::{ + digest::{ExtendableOutput, Update, XofReader}, + ristretto::RistrettoPoint, + }, + sha3::{Sha3XofReader, Shake256}, +}; + +/// Generators for Pedersen vector commitments. +/// +/// The code is copied from https://github.com/dalek-cryptography/bulletproofs for now... + +struct GeneratorsChain { + reader: Sha3XofReader, +} + +impl GeneratorsChain { + /// Creates a chain of generators, determined by the hash of `label`. + fn new(label: &[u8]) -> Self { + let mut shake = Shake256::default(); + shake.update(b"GeneratorsChain"); + shake.update(label); + + GeneratorsChain { + reader: shake.finalize_xof(), + } + } + + /// Advances the reader n times, squeezing and discarding + /// the result. + fn fast_forward(mut self, n: usize) -> Self { + for _ in 0..n { + let mut buf = [0u8; 64]; + self.reader.read(&mut buf); + } + self + } +} + +impl Default for GeneratorsChain { + fn default() -> Self { + Self::new(&[]) + } +} + +impl Iterator for GeneratorsChain { + type Item = RistrettoPoint; + + fn next(&mut self) -> Option { + let mut uniform_bytes = [0u8; 64]; + self.reader.read(&mut uniform_bytes); + + Some(RistrettoPoint::from_uniform_bytes(&uniform_bytes)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), None) + } +} + +#[allow(non_snake_case)] +#[derive(Clone)] +pub struct BulletproofGens { + /// The maximum number of usable generators. + pub gens_capacity: usize, + /// Precomputed \\(\mathbf G\\) generators. + G_vec: Vec, + /// Precomputed \\(\mathbf H\\) generators. + H_vec: Vec, +} + +impl BulletproofGens { + pub fn new(gens_capacity: usize) -> Self { + let mut gens = BulletproofGens { + gens_capacity: 0, + G_vec: Vec::new(), + H_vec: Vec::new(), + }; + gens.increase_capacity(gens_capacity); + gens + } + + // pub fn new_aggregate(gens_capacities: Vec) -> Vec { + // let mut gens_vector = Vec::new(); + // for (capacity, i) in gens_capacities.iter().enumerate() { + // gens_vector.push(BulletproofGens::new(capacity, &i.to_le_bytes())); + // } + // gens_vector + // } + + /// Increases the generators' capacity to the amount specified. + /// If less than or equal to the current capacity, does nothing. + pub fn increase_capacity(&mut self, new_capacity: usize) { + if self.gens_capacity >= new_capacity { + return; + } + + let label = [b'G']; + self.G_vec.extend( + &mut GeneratorsChain::new(&[label, [b'G']].concat()) + .fast_forward(self.gens_capacity) + .take(new_capacity - self.gens_capacity), + ); + + self.H_vec.extend( + &mut GeneratorsChain::new(&[label, [b'H']].concat()) + .fast_forward(self.gens_capacity) + .take(new_capacity - self.gens_capacity), + ); + + self.gens_capacity = new_capacity; + } + + #[allow(non_snake_case)] + pub(crate) fn G(&self, n: usize) -> impl Iterator { + GensIter { + array: &self.G_vec, + n, + gen_idx: 0, + } + } + + #[allow(non_snake_case)] + pub(crate) fn H(&self, n: usize) -> impl Iterator { + GensIter { + array: &self.H_vec, + n, + gen_idx: 0, + } + } +} + +struct GensIter<'a> { + array: &'a Vec, + n: usize, + gen_idx: usize, +} + +impl<'a> Iterator for GensIter<'a> { + type Item = &'a RistrettoPoint; + + fn next(&mut self) -> Option { + if self.gen_idx >= self.n { + None + } else { + let cur_gen = self.gen_idx; + self.gen_idx += 1; + Some(&self.array[cur_gen]) + } + } + + fn size_hint(&self) -> (usize, Option) { + let size = self.n - self.gen_idx; + (size, Some(size)) + } +} diff --git a/zk-token-sdk/src/range_proof/inner_product.rs b/zk-token-sdk/src/range_proof/inner_product.rs new file mode 100644 index 000000000..f23633221 --- /dev/null +++ b/zk-token-sdk/src/range_proof/inner_product.rs @@ -0,0 +1,505 @@ +use core::iter; +use std::borrow::Borrow; + +use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint}; +use curve25519_dalek::scalar::Scalar; +use curve25519_dalek::traits::VartimeMultiscalarMul; + +use crate::errors::ProofError; +use crate::range_proof::util; +use crate::transcript::TranscriptProtocol; + +use merlin::Transcript; + +#[allow(non_snake_case)] +#[derive(Clone)] +pub struct InnerProductProof { + pub L_vec: Vec, // 32 * log(bit_length) + pub R_vec: Vec, // 32 * log(bit_length) + pub a: Scalar, // 32 bytes + pub b: Scalar, // 32 bytes +} + +#[allow(non_snake_case)] +impl InnerProductProof { + /// Create an inner-product proof. + /// + /// The proof is created with respect to the bases \\(G\\), \\(H'\\), + /// where \\(H'\_i = H\_i \cdot \texttt{Hprime\\_factors}\_i\\). + /// + /// The `verifier` is passed in as a parameter so that the + /// challenges depend on the *entire* transcript (including parent + /// protocols). + /// + /// The lengths of the vectors must all be the same, and must all be + /// either 0 or a power of 2. + #[allow(clippy::too_many_arguments)] + pub fn create( + Q: &RistrettoPoint, + G_factors: &[Scalar], + H_factors: &[Scalar], + mut G_vec: Vec, + mut H_vec: Vec, + mut a_vec: Vec, + mut b_vec: Vec, + transcript: &mut Transcript, + ) -> InnerProductProof { + // Create slices G, H, a, b backed by their respective + // vectors. This lets us reslice as we compress the lengths + // of the vectors in the main loop below. + let mut G = &mut G_vec[..]; + let mut H = &mut H_vec[..]; + let mut a = &mut a_vec[..]; + let mut b = &mut b_vec[..]; + + let mut n = G.len(); + + // All of the input vectors must have the same length. + assert_eq!(G.len(), n); + assert_eq!(H.len(), n); + assert_eq!(a.len(), n); + assert_eq!(b.len(), n); + assert_eq!(G_factors.len(), n); + assert_eq!(H_factors.len(), n); + + // All of the input vectors must have a length that is a power of two. + assert!(n.is_power_of_two()); + + transcript.innerproduct_domain_sep(n as u64); + + let lg_n = n.next_power_of_two().trailing_zeros() as usize; + let mut L_vec = Vec::with_capacity(lg_n); + let mut R_vec = Vec::with_capacity(lg_n); + + // If it's the first iteration, unroll the Hprime = H*y_inv scalar mults + // into multiscalar muls, for performance. + if n != 1 { + n /= 2; + let (a_L, a_R) = a.split_at_mut(n); + let (b_L, b_R) = b.split_at_mut(n); + let (G_L, G_R) = G.split_at_mut(n); + let (H_L, H_R) = H.split_at_mut(n); + + let c_L = util::inner_product(a_L, b_R); + let c_R = util::inner_product(a_R, b_L); + + let L = RistrettoPoint::vartime_multiscalar_mul( + a_L.iter() + .zip(G_factors[n..2 * n].iter()) + .map(|(a_L_i, g)| a_L_i * g) + .chain( + b_R.iter() + .zip(H_factors[0..n].iter()) + .map(|(b_R_i, h)| b_R_i * h), + ) + .chain(iter::once(c_L)), + G_R.iter().chain(H_L.iter()).chain(iter::once(Q)), + ) + .compress(); + + let R = RistrettoPoint::vartime_multiscalar_mul( + a_R.iter() + .zip(G_factors[0..n].iter()) + .map(|(a_R_i, g)| a_R_i * g) + .chain( + b_L.iter() + .zip(H_factors[n..2 * n].iter()) + .map(|(b_L_i, h)| b_L_i * h), + ) + .chain(iter::once(c_R)), + G_L.iter().chain(H_R.iter()).chain(iter::once(Q)), + ) + .compress(); + + L_vec.push(L); + R_vec.push(R); + + transcript.append_point(b"L", &L); + transcript.append_point(b"R", &R); + + let u = transcript.challenge_scalar(b"u"); + let u_inv = u.invert(); + + for i in 0..n { + a_L[i] = a_L[i] * u + u_inv * a_R[i]; + b_L[i] = b_L[i] * u_inv + u * b_R[i]; + G_L[i] = RistrettoPoint::vartime_multiscalar_mul( + &[u_inv * G_factors[i], u * G_factors[n + i]], + &[G_L[i], G_R[i]], + ); + H_L[i] = RistrettoPoint::vartime_multiscalar_mul( + &[u * H_factors[i], u_inv * H_factors[n + i]], + &[H_L[i], H_R[i]], + ) + } + + a = a_L; + b = b_L; + G = G_L; + H = H_L; + } + + while n != 1 { + n /= 2; + let (a_L, a_R) = a.split_at_mut(n); + let (b_L, b_R) = b.split_at_mut(n); + let (G_L, G_R) = G.split_at_mut(n); + let (H_L, H_R) = H.split_at_mut(n); + + let c_L = util::inner_product(a_L, b_R); + let c_R = util::inner_product(a_R, b_L); + + let L = RistrettoPoint::vartime_multiscalar_mul( + a_L.iter().chain(b_R.iter()).chain(iter::once(&c_L)), + G_R.iter().chain(H_L.iter()).chain(iter::once(Q)), + ) + .compress(); + + let R = RistrettoPoint::vartime_multiscalar_mul( + a_R.iter().chain(b_L.iter()).chain(iter::once(&c_R)), + G_L.iter().chain(H_R.iter()).chain(iter::once(Q)), + ) + .compress(); + + L_vec.push(L); + R_vec.push(R); + + transcript.append_point(b"L", &L); + transcript.append_point(b"R", &R); + + let u = transcript.challenge_scalar(b"u"); + let u_inv = u.invert(); + + for i in 0..n { + a_L[i] = a_L[i] * u + u_inv * a_R[i]; + b_L[i] = b_L[i] * u_inv + u * b_R[i]; + G_L[i] = RistrettoPoint::vartime_multiscalar_mul(&[u_inv, u], &[G_L[i], G_R[i]]); + H_L[i] = RistrettoPoint::vartime_multiscalar_mul(&[u, u_inv], &[H_L[i], H_R[i]]); + } + + a = a_L; + b = b_L; + G = G_L; + H = H_L; + } + + InnerProductProof { + L_vec, + R_vec, + a: a[0], + b: b[0], + } + } + + /// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and + /// \\([s\_{i}]\\) for combined multiscalar multiplication in a parent protocol. See [inner + /// product protocol notes](index.html#verification-equation) for details. The verifier must + /// provide the input length \\(n\\) explicitly to avoid unbounded allocation within the inner + /// product proof. + #[allow(clippy::type_complexity)] + pub(crate) fn verification_scalars( + &self, + n: usize, + transcript: &mut Transcript, + ) -> Result<(Vec, Vec, Vec), ProofError> { + let lg_n = self.L_vec.len(); + if lg_n >= 32 { + // 4 billion multiplications should be enough for anyone + // and this check prevents overflow in 1<( + &self, + n: usize, + G_factors: IG, + H_factors: IH, + P: &RistrettoPoint, + Q: &RistrettoPoint, + G: &[RistrettoPoint], + H: &[RistrettoPoint], + transcript: &mut Transcript, + ) -> Result<(), ProofError> + where + IG: IntoIterator, + IG::Item: Borrow, + IH: IntoIterator, + IH::Item: Borrow, + { + let (u_sq, u_inv_sq, s) = self.verification_scalars(n, transcript)?; + + let g_times_a_times_s = G_factors + .into_iter() + .zip(s.iter()) + .map(|(g_i, s_i)| (self.a * s_i) * g_i.borrow()) + .take(G.len()); + + // 1/s[i] is s[!i], and !i runs from n-1 to 0 as i runs from 0 to n-1 + let inv_s = s.iter().rev(); + + let h_times_b_div_s = H_factors + .into_iter() + .zip(inv_s) + .map(|(h_i, s_i_inv)| (self.b * s_i_inv) * h_i.borrow()); + + let neg_u_sq = u_sq.iter().map(|ui| -ui); + let neg_u_inv_sq = u_inv_sq.iter().map(|ui| -ui); + + let Ls = self + .L_vec + .iter() + .map(|p| p.decompress().ok_or(ProofError::VerificationError)) + .collect::, _>>()?; + + let Rs = self + .R_vec + .iter() + .map(|p| p.decompress().ok_or(ProofError::VerificationError)) + .collect::, _>>()?; + + let expect_P = RistrettoPoint::vartime_multiscalar_mul( + iter::once(self.a * self.b) + .chain(g_times_a_times_s) + .chain(h_times_b_div_s) + .chain(neg_u_sq) + .chain(neg_u_inv_sq), + iter::once(Q) + .chain(G.iter()) + .chain(H.iter()) + .chain(Ls.iter()) + .chain(Rs.iter()), + ); + + if expect_P == *P { + Ok(()) + } else { + Err(ProofError::VerificationError) + } + } + + /// Returns the size in bytes required to serialize the inner + /// product proof. + /// + /// For vectors of length `n` the proof size is + /// \\(32 \cdot (2\lg n+2)\\) bytes. + pub fn serialized_size(&self) -> usize { + (self.L_vec.len() * 2 + 2) * 32 + } + + /// Serializes the proof into a byte array of \\(2n+2\\) 32-byte elements. + /// The layout of the inner product proof is: + /// * \\(n\\) pairs of compressed Ristretto points \\(L_0, R_0 \dots, L_{n-1}, R_{n-1}\\), + /// * two scalars \\(a, b\\). + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::with_capacity(self.serialized_size()); + for (l, r) in self.L_vec.iter().zip(self.R_vec.iter()) { + buf.extend_from_slice(l.as_bytes()); + buf.extend_from_slice(r.as_bytes()); + } + buf.extend_from_slice(self.a.as_bytes()); + buf.extend_from_slice(self.b.as_bytes()); + buf + } + + // pub fn to_bytes_64(&self) -> Result { + // let mut bytes = [0u8; 448]; + + // self.L_vec.iter().chain(self.R_vec.iter()).enumerate().for_each( + // |(i, x)| bytes[i*32..(i+1)*32].copy_from_slice(x.as_bytes()) + // ); + // bytes[384..416].copy_from_slice(self.a.as_bytes()); + // bytes[416..448].copy_from_slice(self.a.as_bytes()); + // Ok(InnerProductProof64(bytes)) + // } + + /* + /// Converts the proof into a byte iterator over serialized view of the proof. + /// The layout of the inner product proof is: + /// * \\(n\\) pairs of compressed Ristretto points \\(L_0, R_0 \dots, L_{n-1}, R_{n-1}\\), + /// * two scalars \\(a, b\\). + #[inline] + pub(crate) fn to_bytes_iter(&self) -> impl Iterator + '_ { + self.L_vec + .iter() + .zip(self.R_vec.iter()) + .flat_map(|(l, r)| l.as_bytes().iter().chain(r.as_bytes())) + .chain(self.a.as_bytes()) + .chain(self.b.as_bytes()) + .copied() + } + */ + + /// Deserializes the proof from a byte slice. + /// Returns an error in the following cases: + /// * the slice does not have \\(2n+2\\) 32-byte elements, + /// * \\(n\\) is larger or equal to 32 (proof is too big), + /// * any of \\(2n\\) points are not valid compressed Ristretto points, + /// * any of 2 scalars are not canonical scalars modulo Ristretto group order. + pub fn from_bytes(slice: &[u8]) -> Result { + let b = slice.len(); + if b % 32 != 0 { + return Err(ProofError::FormatError); + } + let num_elements = b / 32; + if num_elements < 2 { + return Err(ProofError::FormatError); + } + if (num_elements - 2) % 2 != 0 { + return Err(ProofError::FormatError); + } + let lg_n = (num_elements - 2) / 2; + if lg_n >= 32 { + return Err(ProofError::FormatError); + } + + let mut L_vec: Vec = Vec::with_capacity(lg_n); + let mut R_vec: Vec = Vec::with_capacity(lg_n); + for i in 0..lg_n { + let pos = 2 * i * 32; + L_vec.push(CompressedRistretto(util::read32(&slice[pos..]))); + R_vec.push(CompressedRistretto(util::read32(&slice[pos + 32..]))); + } + + let pos = 2 * lg_n * 32; + let a = Scalar::from_canonical_bytes(util::read32(&slice[pos..])) + .ok_or(ProofError::FormatError)?; + let b = Scalar::from_canonical_bytes(util::read32(&slice[pos + 32..])) + .ok_or(ProofError::FormatError)?; + + Ok(InnerProductProof { L_vec, R_vec, a, b }) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, crate::range_proof::generators::BulletproofGens, rand::rngs::OsRng, + sha3::Sha3_512, + }; + + #[test] + #[ignore] + #[allow(non_snake_case)] + fn test_basic_correctness() { + let n = 32; + + let bp_gens = BulletproofGens::new(n); + let G: Vec = bp_gens.G(n).cloned().collect(); + let H: Vec = bp_gens.H(n).cloned().collect(); + + let Q = RistrettoPoint::hash_from_bytes::(b"test point"); + + let a: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect(); + let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect(); + let c = util::inner_product(&a, &b); + + let G_factors: Vec = iter::repeat(Scalar::one()).take(n).collect(); + + let y_inv = Scalar::random(&mut OsRng); + let H_factors: Vec = util::exp_iter(y_inv).take(n).collect(); + + // P would be determined upstream, but we need a correct P to check the proof. + // + // To generate P = + + Q, compute + // P = + + Q, + // where b' = b \circ y^(-n) + let b_prime = b.iter().zip(util::exp_iter(y_inv)).map(|(bi, yi)| bi * yi); + // a.iter() has Item=&Scalar, need Item=Scalar to chain with b_prime + let a_prime = a.iter().cloned(); + + let P = RistrettoPoint::vartime_multiscalar_mul( + a_prime.chain(b_prime).chain(iter::once(c)), + G.iter().chain(H.iter()).chain(iter::once(&Q)), + ); + + let mut prover_transcript = Transcript::new(b"innerproducttest"); + let mut verifier_transcript = Transcript::new(b"innerproducttest"); + + let proof = InnerProductProof::create( + &Q, + &G_factors, + &H_factors, + G.clone(), + H.clone(), + a.clone(), + b.clone(), + &mut prover_transcript, + ); + + assert!(proof + .verify( + n, + iter::repeat(Scalar::one()).take(n), + util::exp_iter(y_inv).take(n), + &P, + &Q, + &G, + &H, + &mut verifier_transcript, + ) + .is_ok()); + + let proof = InnerProductProof::from_bytes(proof.to_bytes().as_slice()).unwrap(); + let mut verifier_transcript = Transcript::new(b"innerproducttest"); + assert!(proof + .verify( + n, + iter::repeat(Scalar::one()).take(n), + util::exp_iter(y_inv).take(n), + &P, + &Q, + &G, + &H, + &mut verifier_transcript, + ) + .is_ok()); + } +} diff --git a/zk-token-sdk/src/range_proof/mod.rs b/zk-token-sdk/src/range_proof/mod.rs new file mode 100644 index 000000000..c377e2910 --- /dev/null +++ b/zk-token-sdk/src/range_proof/mod.rs @@ -0,0 +1,456 @@ +#[cfg(not(target_arch = "bpf"))] +use { + crate::encryption::pedersen::{Pedersen, PedersenOpen}, + curve25519_dalek::traits::MultiscalarMul, + rand::rngs::OsRng, + subtle::{Choice, ConditionallySelectable}, +}; +use { + crate::{ + encryption::pedersen::PedersenBase, errors::ProofError, + range_proof::generators::BulletproofGens, range_proof::inner_product::InnerProductProof, + transcript::TranscriptProtocol, + }, + core::iter, + curve25519_dalek::{ + ristretto::{CompressedRistretto, RistrettoPoint}, + scalar::Scalar, + traits::{IsIdentity, VartimeMultiscalarMul}, + }, + merlin::Transcript, +}; + +pub mod generators; +pub mod inner_product; +pub mod util; + +#[allow(non_snake_case)] +#[derive(Clone)] +pub struct RangeProof { + pub A: CompressedRistretto, // 32 bytes + pub S: CompressedRistretto, // 32 bytes + pub T_1: CompressedRistretto, // 32 bytes + pub T_2: CompressedRistretto, // 32 bytes + pub t_x: Scalar, // 32 bytes + pub t_x_blinding: Scalar, // 32 bytes + pub e_blinding: Scalar, // 32 bytes + pub ipp_proof: InnerProductProof, // 448 bytes for withdraw; 512 for transfer +} + +#[allow(non_snake_case)] +impl RangeProof { + #[allow(clippy::many_single_char_names)] + #[cfg(not(target_arch = "bpf"))] + pub fn create( + amounts: Vec, + bit_lengths: Vec, + opens: Vec<&PedersenOpen>, + transcript: &mut Transcript, + ) -> Self { + let t_1_blinding = PedersenOpen::random(&mut OsRng); + let t_2_blinding = PedersenOpen::random(&mut OsRng); + + let (range_proof, _, _) = Self::create_with( + amounts, + bit_lengths, + opens, + &t_1_blinding, + &t_2_blinding, + transcript, + ); + + range_proof + } + + #[allow(clippy::many_single_char_names)] + #[cfg(not(target_arch = "bpf"))] + pub fn create_with( + amounts: Vec, + bit_lengths: Vec, + opens: Vec<&PedersenOpen>, + t_1_blinding: &PedersenOpen, + t_2_blinding: &PedersenOpen, + transcript: &mut Transcript, + ) -> (Self, Scalar, Scalar) { + let nm = bit_lengths.iter().sum(); + + // Computing the generators online for now. It should ultimately be precomputed. + let bp_gens = BulletproofGens::new(nm); + let G = PedersenBase::default().G; + let H = PedersenBase::default().H; + + // bit-decompose values and commit to the bits + let a_blinding = Scalar::random(&mut OsRng); + let mut A = a_blinding * H; + + let mut gens_iter = bp_gens.G(nm).zip(bp_gens.H(nm)); + for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) { + for j in 0..(*n_i) { + let (G_ij, H_ij) = gens_iter.next().unwrap(); + let v_ij = Choice::from(((amount_i >> j) & 1) as u8); + let mut point = -H_ij; + point.conditional_assign(G_ij, v_ij); + A += point; + } + } + + // generate blinding factors and commit as vectors + let s_blinding = Scalar::random(&mut OsRng); + + let s_L: Vec = (0..nm).map(|_| Scalar::random(&mut OsRng)).collect(); + let s_R: Vec = (0..nm).map(|_| Scalar::random(&mut OsRng)).collect(); + + let S = RistrettoPoint::multiscalar_mul( + iter::once(&s_blinding).chain(s_L.iter()).chain(s_R.iter()), + iter::once(&H).chain(bp_gens.G(nm)).chain(bp_gens.H(nm)), + ); + + transcript.append_point(b"A", &A.compress()); + transcript.append_point(b"S", &S.compress()); + + // commit to T1 and T2 + let y = transcript.challenge_scalar(b"y"); + let z = transcript.challenge_scalar(b"z"); + + let mut l_poly = util::VecPoly1::zero(nm); + let mut r_poly = util::VecPoly1::zero(nm); + + let mut i = 0; + let mut exp_z = z * z; + let mut exp_y = Scalar::one(); + for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) { + let mut exp_2 = Scalar::one(); + + for j in 0..(*n_i) { + let a_L_j = Scalar::from((amount_i >> j) & 1); + let a_R_j = a_L_j - Scalar::one(); + + l_poly.0[i] = a_L_j - z; + l_poly.1[i] = s_L[i]; + r_poly.0[i] = exp_y * (a_R_j + z) + exp_z * exp_2; + r_poly.1[i] = exp_y * s_R[i]; + + exp_y *= y; + exp_2 = exp_2 + exp_2; + i += 1; + } + exp_z *= z; + } + + let t_poly = l_poly.inner_product(&r_poly); + + let T_1 = Pedersen::commit_with(t_poly.1, t_1_blinding) + .get_point() + .compress(); + let T_2 = Pedersen::commit_with(t_poly.2, t_2_blinding) + .get_point() + .compress(); + + transcript.append_point(b"T_1", &T_1); + transcript.append_point(b"T_2", &T_2); + + let x = transcript.challenge_scalar(b"x"); + + let mut agg_open = Scalar::zero(); + let mut exp_z = z * z; + for open in opens { + agg_open += exp_z * open.get_scalar(); + exp_z *= z; + } + + let t_blinding_poly = util::Poly2( + agg_open, + t_1_blinding.get_scalar(), + t_2_blinding.get_scalar(), + ); + + // compute t_x + let t_x = t_poly.eval(x); + let t_x_blinding = t_blinding_poly.eval(x); + + let e_blinding = a_blinding + s_blinding * x; + let l_vec = l_poly.eval(x); + let r_vec = r_poly.eval(x); + + transcript.append_scalar(b"t_x", &t_x); + transcript.append_scalar(b"t_x_blinding", &t_x_blinding); + transcript.append_scalar(b"e_blinding", &e_blinding); + + let w = transcript.challenge_scalar(b"w"); + let Q = w * G; + + transcript.challenge_scalar(b"c"); + + let G_factors: Vec = iter::repeat(Scalar::one()).take(nm).collect(); + let H_factors: Vec = util::exp_iter(y.invert()).take(nm).collect(); + + let ipp_proof = InnerProductProof::create( + &Q, + &G_factors, + &H_factors, + bp_gens.G(nm).cloned().collect(), + bp_gens.H(nm).cloned().collect(), + l_vec, + r_vec, + transcript, + ); + + let range_proof = RangeProof { + A: A.compress(), + S: S.compress(), + T_1, + T_2, + t_x, + t_x_blinding, + e_blinding, + ipp_proof, + }; + + (range_proof, x, z) + } + + #[allow(clippy::many_single_char_names)] + pub fn verify( + &self, + comms: Vec<&CompressedRistretto>, + bit_lengths: Vec, + transcript: &mut Transcript, + ) -> Result<(), ProofError> { + self.verify_with(comms, bit_lengths, None, None, transcript) + } + + #[allow(clippy::many_single_char_names)] + pub fn verify_with( + &self, + comms: Vec<&CompressedRistretto>, + bit_lengths: Vec, + x_ver: Option, + z_ver: Option, + transcript: &mut Transcript, + ) -> Result<(), ProofError> { + let G = PedersenBase::default().G; + let H = PedersenBase::default().H; + + let m = bit_lengths.len(); + let nm: usize = bit_lengths.iter().sum(); + let bp_gens = BulletproofGens::new(nm); + + if !(nm == 8 || nm == 16 || nm == 32 || nm == 64 || nm == 128) { + return Err(ProofError::InvalidBitsize); + } + + transcript.validate_and_append_point(b"A", &self.A)?; + transcript.validate_and_append_point(b"S", &self.S)?; + + let y = transcript.challenge_scalar(b"y"); + let z = transcript.challenge_scalar(b"z"); + + if z_ver.is_some() && z_ver.unwrap() != z { + return Err(ProofError::VerificationError); + } + + let zz = z * z; + let minus_z = -z; + + transcript.validate_and_append_point(b"T_1", &self.T_1)?; + transcript.validate_and_append_point(b"T_2", &self.T_2)?; + + let x = transcript.challenge_scalar(b"x"); + + if x_ver.is_some() && x_ver.unwrap() != x { + return Err(ProofError::VerificationError); + } + + transcript.append_scalar(b"t_x", &self.t_x); + transcript.append_scalar(b"t_x_blinding", &self.t_x_blinding); + transcript.append_scalar(b"e_blinding", &self.e_blinding); + + let w = transcript.challenge_scalar(b"w"); + + // Challenge value for batching statements to be verified + let c = transcript.challenge_scalar(b"c"); + + let (x_sq, x_inv_sq, s) = self.ipp_proof.verification_scalars(nm, transcript)?; + let s_inv = s.iter().rev(); + + let a = self.ipp_proof.a; + let b = self.ipp_proof.b; + + // Construct concat_z_and_2, an iterator of the values of + // z^0 * \vec(2)^n || z^1 * \vec(2)^n || ... || z^(m-1) * \vec(2)^n + let concat_z_and_2: Vec = util::exp_iter(z) + .zip(bit_lengths.iter()) + .flat_map(|(exp_z, n_i)| { + util::exp_iter(Scalar::from(2u64)) + .take(*n_i) + .map(move |exp_2| exp_2 * exp_z) + }) + .collect(); + + let gs = s.iter().map(|s_i| minus_z - a * s_i); + let hs = s_inv + .clone() + .zip(util::exp_iter(y.invert())) + .zip(concat_z_and_2.iter()) + .map(|((s_i_inv, exp_y_inv), z_and_2)| z + exp_y_inv * (zz * z_and_2 - b * s_i_inv)); + + let basepoint_scalar = + w * (self.t_x - a * b) + c * (delta(&bit_lengths, &y, &z) - self.t_x); + let value_commitment_scalars = util::exp_iter(z).take(m).map(|z_exp| c * zz * z_exp); + + let mega_check = RistrettoPoint::optional_multiscalar_mul( + iter::once(Scalar::one()) + .chain(iter::once(x)) + .chain(iter::once(c * x)) + .chain(iter::once(c * x * x)) + .chain(iter::once(-self.e_blinding - c * self.t_x_blinding)) + .chain(iter::once(basepoint_scalar)) + .chain(x_sq.iter().cloned()) + .chain(x_inv_sq.iter().cloned()) + .chain(gs) + .chain(hs) + .chain(value_commitment_scalars), + iter::once(self.A.decompress()) + .chain(iter::once(self.S.decompress())) + .chain(iter::once(self.T_1.decompress())) + .chain(iter::once(self.T_2.decompress())) + .chain(iter::once(Some(H))) + .chain(iter::once(Some(G))) + .chain(self.ipp_proof.L_vec.iter().map(|L| L.decompress())) + .chain(self.ipp_proof.R_vec.iter().map(|R| R.decompress())) + .chain(bp_gens.G(nm).map(|&x| Some(x))) + .chain(bp_gens.H(nm).map(|&x| Some(x))) + .chain(comms.iter().map(|V| V.decompress())), + ) + .ok_or(ProofError::VerificationError)?; + + if mega_check.is_identity() { + Ok(()) + } else { + Err(ProofError::VerificationError) + } + } + + // Following the dalek rangeproof library signature for now. The exact method signature can be + // changed. + pub fn to_bytes(&self) -> Vec { + let mut buf = Vec::with_capacity(7 * 32 + self.ipp_proof.serialized_size()); + buf.extend_from_slice(self.A.as_bytes()); + buf.extend_from_slice(self.S.as_bytes()); + buf.extend_from_slice(self.T_1.as_bytes()); + buf.extend_from_slice(self.T_2.as_bytes()); + buf.extend_from_slice(self.t_x.as_bytes()); + buf.extend_from_slice(self.t_x_blinding.as_bytes()); + buf.extend_from_slice(self.e_blinding.as_bytes()); + buf.extend_from_slice(&self.ipp_proof.to_bytes()); + buf + } + + // Following the dalek rangeproof library signature for now. The exact method signature can be + // changed. + pub fn from_bytes(slice: &[u8]) -> Result { + if slice.len() % 32 != 0 { + return Err(ProofError::FormatError); + } + if slice.len() < 7 * 32 { + return Err(ProofError::FormatError); + } + + let A = CompressedRistretto(util::read32(&slice[0..])); + let S = CompressedRistretto(util::read32(&slice[32..])); + let T_1 = CompressedRistretto(util::read32(&slice[2 * 32..])); + let T_2 = CompressedRistretto(util::read32(&slice[3 * 32..])); + + let t_x = Scalar::from_canonical_bytes(util::read32(&slice[4 * 32..])) + .ok_or(ProofError::FormatError)?; + let t_x_blinding = Scalar::from_canonical_bytes(util::read32(&slice[5 * 32..])) + .ok_or(ProofError::FormatError)?; + let e_blinding = Scalar::from_canonical_bytes(util::read32(&slice[6 * 32..])) + .ok_or(ProofError::FormatError)?; + + let ipp_proof = InnerProductProof::from_bytes(&slice[7 * 32..])?; + + Ok(RangeProof { + A, + S, + T_1, + T_2, + t_x, + t_x_blinding, + e_blinding, + ipp_proof, + }) + } +} + +/// Compute +/// \\[ +/// \delta(y,z) = (z - z^{2}) \langle \mathbf{1}, {\mathbf{y}}^{n \cdot m} \rangle - \sum_{j=0}^{m-1} z^{j+3} \cdot \langle \mathbf{1}, {\mathbf{2}}^{n \cdot m} \rangle +/// \\] +fn delta(bit_lengths: &[usize], y: &Scalar, z: &Scalar) -> Scalar { + let nm: usize = bit_lengths.iter().sum(); + let sum_y = util::sum_of_powers(y, nm); + + let mut agg_delta = (z - z * z) * sum_y; + let mut exp_z = z * z * z; + for n_i in bit_lengths.iter() { + let sum_2 = util::sum_of_powers(&Scalar::from(2u64), *n_i); + agg_delta -= exp_z * sum_2; + exp_z *= z; + } + agg_delta +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_single_rangeproof() { + let (comm, open) = Pedersen::commit(55_u64); + + let mut transcript_create = Transcript::new(b"Test"); + let mut transcript_verify = Transcript::new(b"Test"); + + let proof = RangeProof::create(vec![55], vec![32], vec![&open], &mut transcript_create); + + assert!(proof + .verify( + vec![&comm.get_point().compress()], + vec![32], + &mut transcript_verify + ) + .is_ok()); + } + + #[test] + fn test_aggregated_rangeproof() { + let (comm_1, open_1) = Pedersen::commit(55_u64); + let (comm_2, open_2) = Pedersen::commit(77_u64); + let (comm_3, open_3) = Pedersen::commit(99_u64); + + let mut transcript_create = Transcript::new(b"Test"); + let mut transcript_verify = Transcript::new(b"Test"); + + let proof = RangeProof::create( + vec![55, 77, 99], + vec![64, 32, 32], + vec![&open_1, &open_2, &open_3], + &mut transcript_create, + ); + + let comm_1_point = comm_1.get_point().compress(); + let comm_2_point = comm_2.get_point().compress(); + let comm_3_point = comm_3.get_point().compress(); + + assert!(proof + .verify( + vec![&comm_1_point, &comm_2_point, &comm_3_point], + vec![64, 32, 32], + &mut transcript_verify, + ) + .is_ok()); + } + + // TODO: write test for serialization/deserialization +} diff --git a/zk-token-sdk/src/range_proof/util.rs b/zk-token-sdk/src/range_proof/util.rs new file mode 100644 index 000000000..c551abd8f --- /dev/null +++ b/zk-token-sdk/src/range_proof/util.rs @@ -0,0 +1,138 @@ +/// Utility functions for Bulletproofs. +/// +/// The code is copied from https://github.com/dalek-cryptography/bulletproofs for now... +use curve25519_dalek::scalar::Scalar; + +/// Represents a degree-1 vector polynomial \\(\mathbf{a} + \mathbf{b} \cdot x\\). +pub struct VecPoly1(pub Vec, pub Vec); + +impl VecPoly1 { + pub fn zero(n: usize) -> Self { + VecPoly1(vec![Scalar::zero(); n], vec![Scalar::zero(); n]) + } + + pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 { + // Uses Karatsuba's method + let l = self; + let r = rhs; + + let t0 = inner_product(&l.0, &r.0); + let t2 = inner_product(&l.1, &r.1); + + let l0_plus_l1 = add_vec(&l.0, &l.1); + let r0_plus_r1 = add_vec(&r.0, &r.1); + + let t1 = inner_product(&l0_plus_l1, &r0_plus_r1) - t0 - t2; + + Poly2(t0, t1, t2) + } + + pub fn eval(&self, x: Scalar) -> Vec { + let n = self.0.len(); + let mut out = vec![Scalar::zero(); n]; + #[allow(clippy::needless_range_loop)] + for i in 0..n { + out[i] = self.0[i] + self.1[i] * x; + } + out + } +} + +/// Represents a degree-2 scalar polynomial \\(a + b \cdot x + c \cdot x^2\\) +pub struct Poly2(pub Scalar, pub Scalar, pub Scalar); + +impl Poly2 { + pub fn eval(&self, x: Scalar) -> Scalar { + self.0 + x * (self.1 + x * self.2) + } +} + +/// Provides an iterator over the powers of a `Scalar`. +/// +/// This struct is created by the `exp_iter` function. +pub struct ScalarExp { + x: Scalar, + next_exp_x: Scalar, +} + +impl Iterator for ScalarExp { + type Item = Scalar; + + fn next(&mut self) -> Option { + let exp_x = self.next_exp_x; + self.next_exp_x *= self.x; + Some(exp_x) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), None) + } +} + +/// Return an iterator of the powers of `x`. +pub fn exp_iter(x: Scalar) -> ScalarExp { + let next_exp_x = Scalar::one(); + ScalarExp { x, next_exp_x } +} + +pub fn add_vec(a: &[Scalar], b: &[Scalar]) -> Vec { + if a.len() != b.len() { + // throw some error + //println!("lengths of vectors don't match for vector addition"); + } + let mut out = vec![Scalar::zero(); b.len()]; + for i in 0..a.len() { + out[i] = a[i] + b[i]; + } + out +} + +/// Given `data` with `len >= 32`, return the first 32 bytes. +pub fn read32(data: &[u8]) -> [u8; 32] { + let mut buf32 = [0u8; 32]; + buf32[..].copy_from_slice(&data[..32]); + buf32 +} + +/// Computes an inner product of two vectors +/// \\[ +/// {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i. +/// \\] +/// Panics if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal. +pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar { + let mut out = Scalar::zero(); + if a.len() != b.len() { + panic!("inner_product(a,b): lengths of vectors do not match"); + } + for i in 0..a.len() { + out += a[i] * b[i]; + } + out +} + +/// Takes the sum of all the powers of `x`, up to `n` +/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplications and additions. +/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions. +/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2. +pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar { + if !n.is_power_of_two() { + return sum_of_powers_slow(x, n); + } + if n == 0 || n == 1 { + return Scalar::from(n as u64); + } + let mut m = n; + let mut result = Scalar::one() + x; + let mut factor = *x; + while m > 2 { + factor = factor * factor; + result = result + factor * result; + m /= 2; + } + result +} + +// takes the sum of all of the powers of x, up to n +fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar { + exp_iter(*x).take(n).sum() +} diff --git a/zk-token-sdk/src/transcript.rs b/zk-token-sdk/src/transcript.rs new file mode 100644 index 000000000..7e3108d1b --- /dev/null +++ b/zk-token-sdk/src/transcript.rs @@ -0,0 +1,117 @@ +use curve25519_dalek::ristretto::CompressedRistretto; +use curve25519_dalek::scalar::Scalar; +use curve25519_dalek::traits::IsIdentity; + +use merlin::Transcript; + +use crate::errors::ProofError; + +pub trait TranscriptProtocol { + /// Append a domain separator for an `n`-bit rangeproof for ElGamal + /// ciphertext using a decryption key + fn rangeproof_from_key_domain_sep(&mut self, n: u64); + + /// Append a domain separator for an `n`-bit rangeproof for ElGamal + /// ciphertext using an opening + fn rangeproof_from_opening_domain_sep(&mut self, n: u64); + + /// Append a domain separator for a length-`n` inner product proof. + fn innerproduct_domain_sep(&mut self, n: u64); + + /// Append a domain separator for close account proof. + fn close_account_proof_domain_sep(&mut self); + + /// Append a domain separator for update account public key proof. + fn update_account_public_key_proof_domain_sep(&mut self); + + /// Append a domain separator for withdraw proof. + fn withdraw_proof_domain_sep(&mut self); + + /// Append a domain separator for transfer with range proof. + fn transfer_range_proof_sep(&mut self); + + /// Append a domain separator for transfer with validity proof. + fn transfer_validity_proof_sep(&mut self); + + /// Append a `scalar` with the given `label`. + fn append_scalar(&mut self, label: &'static [u8], scalar: &Scalar); + + /// Append a `point` with the given `label`. + fn append_point(&mut self, label: &'static [u8], point: &CompressedRistretto); + + /// Check that a point is not the identity, then append it to the + /// transcript. Otherwise, return an error. + fn validate_and_append_point( + &mut self, + label: &'static [u8], + point: &CompressedRistretto, + ) -> Result<(), ProofError>; + + /// Compute a `label`ed challenge variable. + fn challenge_scalar(&mut self, label: &'static [u8]) -> Scalar; +} + +impl TranscriptProtocol for Transcript { + fn rangeproof_from_key_domain_sep(&mut self, n: u64) { + self.append_message(b"dom-sep", b"rangeproof from opening v1"); + self.append_u64(b"n", n); + } + + fn rangeproof_from_opening_domain_sep(&mut self, n: u64) { + self.append_message(b"dom-sep", b"rangeproof from opening v1"); + self.append_u64(b"n", n); + } + + fn innerproduct_domain_sep(&mut self, n: u64) { + self.append_message(b"dom-sep", b"ipp v1"); + self.append_u64(b"n", n); + } + + fn close_account_proof_domain_sep(&mut self) { + self.append_message(b"dom_sep", b"CloseAccountProof"); + } + + fn update_account_public_key_proof_domain_sep(&mut self) { + self.append_message(b"dom_sep", b"UpdateAccountPublicKeyProof"); + } + + fn withdraw_proof_domain_sep(&mut self) { + self.append_message(b"dom_sep", b"WithdrawProof"); + } + + fn transfer_range_proof_sep(&mut self) { + self.append_message(b"dom_sep", b"TransferRangeProof"); + } + + fn transfer_validity_proof_sep(&mut self) { + self.append_message(b"dom_sep", b"TransferValidityProof"); + } + + fn append_scalar(&mut self, label: &'static [u8], scalar: &Scalar) { + self.append_message(label, scalar.as_bytes()); + } + + fn append_point(&mut self, label: &'static [u8], point: &CompressedRistretto) { + self.append_message(label, point.as_bytes()); + } + + fn validate_and_append_point( + &mut self, + label: &'static [u8], + point: &CompressedRistretto, + ) -> Result<(), ProofError> { + if point.is_identity() { + Err(ProofError::VerificationError) + } else { + self.append_message(label, point.as_bytes()); + Ok(()) + } + } + + fn challenge_scalar(&mut self, label: &'static [u8]) -> Scalar { + let mut buf = [0u8; 64]; + self.challenge_bytes(label, &mut buf); + + Scalar::from_bytes_mod_order_wide(&buf) + } +} diff --git a/zk-token-sdk/src/zk_token_proof_instruction.rs b/zk-token-sdk/src/zk_token_proof_instruction.rs new file mode 100644 index 000000000..e27cebc6a --- /dev/null +++ b/zk-token-sdk/src/zk_token_proof_instruction.rs @@ -0,0 +1,91 @@ +///! Instructions provided by the ZkToken Proof program +pub use crate::instruction::*; + +use { + bytemuck::{bytes_of, Pod}, + num_derive::{FromPrimitive, ToPrimitive}, + num_traits::{FromPrimitive, ToPrimitive}, + solana_program::{instruction::Instruction, pubkey::Pubkey}, +}; + +#[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive, PartialEq)] +#[repr(u8)] +pub enum ProofInstruction { + /// Verify a `UpdateAccountPkData` struct + /// + /// Accounts expected by this instruction: + /// None + /// + /// Data expected by this instruction: + /// `UpdateAccountPkData` + /// + VerifyUpdateAccountPk, + + /// Verify a `CloseAccountData` struct + /// + /// Accounts expected by this instruction: + /// None + /// + /// Data expected by this instruction: + /// `CloseAccountData` + /// + VerifyCloseAccount, + + /// Verify a `WithdrawData` struct + /// + /// Accounts expected by this instruction: + /// None + /// + /// Data expected by this instruction: + /// `WithdrawData` + /// + VerifyWithdraw, + + /// Verify a `TransferRangeProofData` struct + /// + /// Accounts expected by this instruction: + /// None + /// + /// Data expected by this instruction: + /// `TransferRangeProofData` + /// + VerifyTransferRangeProofData, + + /// Verify a `TransferValidityProofData` struct + /// + /// Accounts expected by this instruction: + /// None + /// + /// Data expected by this instruction: + /// `TransferValidityProofData` + /// + VerifyTransferValidityProofData, +} + +impl ProofInstruction { + pub fn encode(&self, proof: &T) -> Instruction { + let mut data = vec![ToPrimitive::to_u8(self).unwrap()]; + data.extend_from_slice(bytes_of(proof)); + Instruction { + program_id: crate::zk_token_proof_program::id(), + accounts: vec![], + data, + } + } + + pub fn decode_type(program_id: &Pubkey, input: &[u8]) -> Option { + if *program_id != crate::zk_token_proof_program::id() || input.is_empty() { + None + } else { + FromPrimitive::from_u8(input[0]) + } + } + + pub fn decode_data(input: &[u8]) -> Option<&T> { + if input.is_empty() { + None + } else { + bytemuck::try_from_bytes(&input[1..]).ok() + } + } +} diff --git a/zk-token-sdk/src/zk_token_proof_program.rs b/zk-token-sdk/src/zk_token_proof_program.rs new file mode 100644 index 000000000..498c4cce8 --- /dev/null +++ b/zk-token-sdk/src/zk_token_proof_program.rs @@ -0,0 +1,2 @@ +// Program Id of the ZkToken Proof program +solana_program::declare_id!("ZkTokenProof1111111111111111111111111111111");