Rename crypto crate to sdk

This commit is contained in:
Michael Vines 2021-09-29 21:45:35 -07:00
parent 7da620f0b4
commit d01d425e4b
21 changed files with 4672 additions and 0 deletions

33
zk-token-sdk/Cargo.toml Normal file
View File

@ -0,0 +1,33 @@
[package]
name = "spl-zk-token-sdk"
description = "Solana Program Library ZkToken SDK"
authors = ["Solana Maintainers <maintainers@solana.foundation>"]
repository = "https://github.com/solana-labs/solana-program-library"
version = "0.1.0"
license = "Apache-2.0"
edition = "2018"
publish = false
[dependencies]
bytemuck = { version = "1.7.2", features = ["derive"] }
num-derive = "0.3"
num-traits = "0.2"
solana-program = "=1.7.13"
[target.'cfg(not(target_arch = "bpf"))'.dependencies]
arrayref = "0.3.6"
bincode = "1"
byteorder = "1"
clear_on_drop = "0.2"
curve25519-dalek = { version = "3.2.0", features = ["serde"]}
getrandom = { version = "0.1", features = ["dummy"] }
merlin = "2"
rand = "0.7"
serde = { version = "1.0", features = ["derive"] }
sha3 = "0.9"
subtle = "2"
thiserror = "1"
zeroize = { version = "1.2.0", default-features = false, features = ["zeroize_derive"] }
[dev-dependencies]
time = "0.1.40"

View File

@ -0,0 +1,468 @@
#[cfg(not(target_arch = "bpf"))]
use rand::{rngs::OsRng, CryptoRng, RngCore};
use {
crate::encryption::{
encode::DiscreteLogInstance,
pedersen::{Pedersen, PedersenBase, PedersenComm, PedersenDecHandle, PedersenOpen},
},
arrayref::{array_ref, array_refs},
core::ops::{Add, Div, Mul, Sub},
curve25519_dalek::{
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
},
serde::{Deserialize, Serialize},
std::convert::TryInto,
subtle::{Choice, ConstantTimeEq},
zeroize::Zeroize,
};
/// Handle for the (twisted) ElGamal encryption scheme
pub struct ElGamal;
impl ElGamal {
/// Generates the public and secret keys for ElGamal encryption.
#[cfg(not(target_arch = "bpf"))]
pub fn keygen() -> (ElGamalPK, ElGamalSK) {
ElGamal::keygen_with(&mut OsRng) // using OsRng for now
}
/// On input a randomness generator, the function generates the public and
/// secret keys for ElGamal encryption.
#[cfg(not(target_arch = "bpf"))]
#[allow(non_snake_case)]
pub fn keygen_with<T: RngCore + CryptoRng>(rng: &mut T) -> (ElGamalPK, ElGamalSK) {
// sample a non-zero scalar
let mut s: Scalar;
loop {
s = Scalar::random(rng);
if s != Scalar::zero() {
break;
}
}
let H = PedersenBase::default().H;
let P = s.invert() * H;
(ElGamalPK(P), ElGamalSK(s))
}
/// On input a public key and a message to be encrypted, the function
/// returns an ElGamal ciphertext of the message under the public key.
#[cfg(not(target_arch = "bpf"))]
pub fn encrypt<T: Into<Scalar>>(pk: &ElGamalPK, amount: T) -> ElGamalCT {
let (message_comm, open) = Pedersen::commit(amount);
let decrypt_handle = pk.gen_decrypt_handle(&open);
ElGamalCT {
message_comm,
decrypt_handle,
}
}
/// On input a public key, message, and Pedersen opening, the function
/// returns an ElGamal ciphertext of the message under the public key using
/// the opening.
pub fn encrypt_with<T: Into<Scalar>>(
pk: &ElGamalPK,
amount: T,
open: &PedersenOpen,
) -> ElGamalCT {
let message_comm = Pedersen::commit_with(amount, open);
let decrypt_handle = pk.gen_decrypt_handle(open);
ElGamalCT {
message_comm,
decrypt_handle,
}
}
/// On input a secret key and a ciphertext, the function decrypts the ciphertext.
///
/// The output of the function is of type `DiscreteLogInstance`. The exact message
/// can be recovered via the DiscreteLogInstance's decode method.
pub fn decrypt(sk: &ElGamalSK, ct: &ElGamalCT) -> DiscreteLogInstance {
let ElGamalSK(s) = sk;
let ElGamalCT {
message_comm,
decrypt_handle,
} = ct;
DiscreteLogInstance {
generator: PedersenBase::default().G,
target: message_comm.get_point() - s * decrypt_handle.get_point(),
}
}
/// On input a secret key and a ciphertext, the function decrypts the
/// ciphertext for a u32 value.
pub fn decrypt_u32(sk: &ElGamalSK, ct: &ElGamalCT) -> Option<u32> {
let discrete_log_instance = ElGamal::decrypt(sk, ct);
discrete_log_instance.decode_u32()
}
}
/// Public key for the ElGamal encryption scheme.
#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)]
pub struct ElGamalPK(RistrettoPoint);
impl ElGamalPK {
pub fn get_point(&self) -> RistrettoPoint {
self.0
}
#[allow(clippy::wrong_self_convention)]
pub fn to_bytes(&self) -> [u8; 32] {
self.0.compress().to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Option<ElGamalPK> {
Some(ElGamalPK(
CompressedRistretto::from_slice(bytes).decompress()?,
))
}
/// Utility method for code ergonomics.
#[cfg(not(target_arch = "bpf"))]
pub fn encrypt<T: Into<Scalar>>(&self, msg: T) -> ElGamalCT {
ElGamal::encrypt(self, msg)
}
/// Utility method for code ergonomics.
pub fn encrypt_with<T: Into<Scalar>>(&self, msg: T, open: &PedersenOpen) -> ElGamalCT {
ElGamal::encrypt_with(self, msg, open)
}
/// Generate a decryption token from an ElGamal public key and a Pedersen
/// opening.
pub fn gen_decrypt_handle(self, open: &PedersenOpen) -> PedersenDecHandle {
PedersenDecHandle::generate_handle(open, &self)
}
}
impl From<RistrettoPoint> for ElGamalPK {
fn from(point: RistrettoPoint) -> ElGamalPK {
ElGamalPK(point)
}
}
/// Secret key for the ElGamal encryption scheme.
#[derive(Serialize, Deserialize, Debug, Zeroize)]
#[zeroize(drop)]
pub struct ElGamalSK(Scalar);
impl ElGamalSK {
pub fn get_scalar(&self) -> Scalar {
self.0
}
/// Utility method for code ergonomics.
pub fn decrypt(&self, ct: &ElGamalCT) -> DiscreteLogInstance {
ElGamal::decrypt(self, ct)
}
/// Utility method for code ergonomics.
pub fn decrypt_u32(&self, ct: &ElGamalCT) -> Option<u32> {
ElGamal::decrypt_u32(self, ct)
}
pub fn to_bytes(&self) -> [u8; 32] {
self.0.to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Option<ElGamalSK> {
match bytes.try_into() {
Ok(bytes) => Scalar::from_canonical_bytes(bytes).map(ElGamalSK),
_ => None,
}
}
}
impl From<Scalar> for ElGamalSK {
fn from(scalar: Scalar) -> ElGamalSK {
ElGamalSK(scalar)
}
}
impl Eq for ElGamalSK {}
impl PartialEq for ElGamalSK {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).unwrap_u8() == 1u8
}
}
impl ConstantTimeEq for ElGamalSK {
fn ct_eq(&self, other: &Self) -> Choice {
self.0.ct_eq(&other.0)
}
}
/// Ciphertext for the ElGamal encryption scheme.
#[allow(non_snake_case)]
#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)]
pub struct ElGamalCT {
pub message_comm: PedersenComm,
pub decrypt_handle: PedersenDecHandle,
}
impl ElGamalCT {
pub fn add_to_msg<T: Into<Scalar>>(&self, message: T) -> Self {
let diff_comm = Pedersen::commit_with(message, &PedersenOpen::default());
ElGamalCT {
message_comm: self.message_comm + diff_comm,
decrypt_handle: self.decrypt_handle,
}
}
pub fn sub_to_msg<T: Into<Scalar>>(&self, message: T) -> Self {
let diff_comm = Pedersen::commit_with(message, &PedersenOpen::default());
ElGamalCT {
message_comm: self.message_comm - diff_comm,
decrypt_handle: self.decrypt_handle,
}
}
#[allow(clippy::wrong_self_convention)]
pub fn to_bytes(&self) -> [u8; 64] {
let mut bytes = [0u8; 64];
bytes[..32].copy_from_slice(self.message_comm.get_point().compress().as_bytes());
bytes[32..].copy_from_slice(self.decrypt_handle.get_point().compress().as_bytes());
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Option<ElGamalCT> {
let bytes = array_ref![bytes, 0, 64];
let (message_comm, decrypt_handle) = array_refs![bytes, 32, 32];
let message_comm = CompressedRistretto::from_slice(message_comm).decompress()?;
let decrypt_handle = CompressedRistretto::from_slice(decrypt_handle).decompress()?;
Some(ElGamalCT {
message_comm: PedersenComm(message_comm),
decrypt_handle: PedersenDecHandle(decrypt_handle),
})
}
/// Utility method for code ergonomics.
pub fn decrypt(&self, sk: &ElGamalSK) -> DiscreteLogInstance {
ElGamal::decrypt(sk, self)
}
/// Utility method for code ergonomics.
pub fn decrypt_u32(&self, sk: &ElGamalSK) -> Option<u32> {
ElGamal::decrypt_u32(sk, self)
}
}
impl<'a, 'b> Add<&'b ElGamalCT> for &'a ElGamalCT {
type Output = ElGamalCT;
fn add(self, other: &'b ElGamalCT) -> ElGamalCT {
ElGamalCT {
message_comm: self.message_comm + other.message_comm,
decrypt_handle: self.decrypt_handle + other.decrypt_handle,
}
}
}
define_add_variants!(LHS = ElGamalCT, RHS = ElGamalCT, Output = ElGamalCT);
impl<'a, 'b> Sub<&'b ElGamalCT> for &'a ElGamalCT {
type Output = ElGamalCT;
fn sub(self, other: &'b ElGamalCT) -> ElGamalCT {
ElGamalCT {
message_comm: self.message_comm - other.message_comm,
decrypt_handle: self.decrypt_handle - other.decrypt_handle,
}
}
}
define_sub_variants!(LHS = ElGamalCT, RHS = ElGamalCT, Output = ElGamalCT);
impl<'a, 'b> Mul<&'b Scalar> for &'a ElGamalCT {
type Output = ElGamalCT;
fn mul(self, other: &'b Scalar) -> ElGamalCT {
ElGamalCT {
message_comm: self.message_comm * other,
decrypt_handle: self.decrypt_handle * other,
}
}
}
define_mul_variants!(LHS = ElGamalCT, RHS = Scalar, Output = ElGamalCT);
impl<'a, 'b> Div<&'b Scalar> for &'a ElGamalCT {
type Output = ElGamalCT;
fn div(self, other: &'b Scalar) -> ElGamalCT {
ElGamalCT {
message_comm: self.message_comm * other.invert(),
decrypt_handle: self.decrypt_handle * other.invert(),
}
}
}
define_div_variants!(LHS = ElGamalCT, RHS = Scalar, Output = ElGamalCT);
#[cfg(test)]
mod tests {
use super::*;
use crate::encryption::pedersen::Pedersen;
#[test]
fn test_encrypt_decrypt_correctness() {
let (pk, sk) = ElGamal::keygen();
let msg: u32 = 57;
let ct = ElGamal::encrypt(&pk, msg);
let expected_instance = DiscreteLogInstance {
generator: PedersenBase::default().G,
target: Scalar::from(msg) * PedersenBase::default().G,
};
assert_eq!(expected_instance, ElGamal::decrypt(&sk, &ct));
// Commenting out for faster testing
// assert_eq!(msg, ElGamal::decrypt_u32(&sk, &ct).unwrap());
}
#[test]
fn test_decrypt_handle() {
let (pk_1, sk_1) = ElGamal::keygen();
let (pk_2, sk_2) = ElGamal::keygen();
let msg: u32 = 77;
let (comm, open) = Pedersen::commit(msg);
let decrypt_handle_1 = pk_1.gen_decrypt_handle(&open);
let decrypt_handle_2 = pk_2.gen_decrypt_handle(&open);
let ct_1 = decrypt_handle_1.to_elgamal_ctxt(comm);
let ct_2 = decrypt_handle_2.to_elgamal_ctxt(comm);
let expected_instance = DiscreteLogInstance {
generator: PedersenBase::default().G,
target: Scalar::from(msg) * PedersenBase::default().G,
};
assert_eq!(expected_instance, sk_1.decrypt(&ct_1));
assert_eq!(expected_instance, sk_2.decrypt(&ct_2));
// Commenting out for faster testing
// assert_eq!(msg, sk_1.decrypt_u32(&ct_1).unwrap());
// assert_eq!(msg, sk_2.decrypt_u32(&ct_2).unwrap());
}
#[test]
fn test_homomorphic_addition() {
let (pk, _) = ElGamal::keygen();
let msg_0: u64 = 57;
let msg_1: u64 = 77;
// Add two ElGamal ciphertexts
let open_0 = PedersenOpen::random(&mut OsRng);
let open_1 = PedersenOpen::random(&mut OsRng);
let ct_0 = ElGamal::encrypt_with(&pk, msg_0, &open_0);
let ct_1 = ElGamal::encrypt_with(&pk, msg_1, &open_1);
let ct_sum = ElGamal::encrypt_with(&pk, msg_0 + msg_1, &(open_0 + open_1));
assert_eq!(ct_sum, ct_0 + ct_1);
// Add to ElGamal ciphertext
let open = PedersenOpen::random(&mut OsRng);
let ct = ElGamal::encrypt_with(&pk, msg_0, &open);
let ct_sum = ElGamal::encrypt_with(&pk, msg_0 + msg_1, &open);
assert_eq!(ct_sum, ct.add_to_msg(msg_1));
}
#[test]
fn test_homomorphic_subtraction() {
let (pk, _) = ElGamal::keygen();
let msg_0: u64 = 77;
let msg_1: u64 = 55;
// Subtract two ElGamal ciphertexts
let open_0 = PedersenOpen::random(&mut OsRng);
let open_1 = PedersenOpen::random(&mut OsRng);
let ct_0 = ElGamal::encrypt_with(&pk, msg_0, &open_0);
let ct_1 = ElGamal::encrypt_with(&pk, msg_1, &open_1);
let ct_sub = ElGamal::encrypt_with(&pk, msg_0 - msg_1, &(open_0 - open_1));
assert_eq!(ct_sub, ct_0 - ct_1);
// Subtract to ElGamal ciphertext
let open = PedersenOpen::random(&mut OsRng);
let ct = ElGamal::encrypt_with(&pk, msg_0, &open);
let ct_sub = ElGamal::encrypt_with(&pk, msg_0 - msg_1, &open);
assert_eq!(ct_sub, ct.sub_to_msg(msg_1));
}
#[test]
fn test_homomorphic_multiplication() {
let (pk, _) = ElGamal::keygen();
let msg_0: u64 = 57;
let msg_1: u64 = 77;
let open = PedersenOpen::random(&mut OsRng);
let ct = ElGamal::encrypt_with(&pk, msg_0, &open);
let scalar = Scalar::from(msg_1);
let ct_prod = ElGamal::encrypt_with(&pk, msg_0 * msg_1, &(open * scalar));
assert_eq!(ct_prod, ct * scalar);
}
#[test]
fn test_homomorphic_division() {
let (pk, _) = ElGamal::keygen();
let msg_0: u64 = 55;
let msg_1: u64 = 5;
let open = PedersenOpen::random(&mut OsRng);
let ct = ElGamal::encrypt_with(&pk, msg_0, &open);
let scalar = Scalar::from(msg_1);
let ct_div = ElGamal::encrypt_with(&pk, msg_0 / msg_1, &(open / scalar));
assert_eq!(ct_div, ct / scalar);
}
#[test]
fn test_serde_ciphertext() {
let (pk, _) = ElGamal::keygen();
let msg: u64 = 77;
let ct = pk.encrypt(msg);
let encoded = bincode::serialize(&ct).unwrap();
let decoded: ElGamalCT = bincode::deserialize(&encoded).unwrap();
assert_eq!(ct, decoded);
}
#[test]
fn test_serde_pubkey() {
let (pk, _) = ElGamal::keygen();
let encoded = bincode::serialize(&pk).unwrap();
let decoded: ElGamalPK = bincode::deserialize(&encoded).unwrap();
assert_eq!(pk, decoded);
}
#[test]
fn test_serde_secretkey() {
let (_, sk) = ElGamal::keygen();
let encoded = bincode::serialize(&sk).unwrap();
let decoded: ElGamalSK = bincode::deserialize(&encoded).unwrap();
assert_eq!(sk, decoded);
}
}

View File

@ -0,0 +1,276 @@
use core::ops::{Add, Neg, Sub};
use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G;
use curve25519_dalek::ristretto::RistrettoPoint;
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::Identity;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use serde::{Deserialize, Serialize};
const TWO14: u32 = 16384; // 2^14
const TWO16: u32 = 65536; // 2^16
const TWO18: u32 = 262144; // 2^18
/// Type that captures a discrete log challenge.
///
/// The goal of discrete log is to find x such that x * generator = target.
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
pub struct DiscreteLogInstance {
/// Generator point for discrete log
pub generator: RistrettoPoint,
/// Target point for discrete log
pub target: RistrettoPoint,
}
/// Solves the discrete log instance using a 18/14 bit offline/online split
impl DiscreteLogInstance {
/// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number.
pub fn decode_u32(self) -> Option<u32> {
let hashmap = DiscreteLogInstance::decode_u32_precomputation(self.generator);
self.decode_u32_online(&hashmap)
}
pub fn decode_u32_precomputation(generator: RistrettoPoint) -> HashMap<HashableRistretto, u32> {
let mut hashmap = HashMap::new();
let two16_scalar = Scalar::from(TWO16);
let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G
let generator = HashableRistretto(two16_scalar * generator); // 2^16 * G
// iterator for 2^16*0G , 2^16*1G, 2^16*2G, ...
let ristretto_iter = RistrettoIterator::new(identity, generator);
ristretto_iter.zip(0..TWO16).for_each(|(elem, x_hi)| {
hashmap.insert(elem, x_hi);
});
hashmap
}
pub fn decode_u32_online(self, hashmap: &HashMap<HashableRistretto, u32>) -> Option<u32> {
// iterator for 0G, -1G, -2G, ...
let ristretto_iter = RistrettoIterator::new(
HashableRistretto(self.target),
HashableRistretto(-self.generator),
);
let mut decoded = None;
ristretto_iter.zip(0..TWO16).for_each(|(elem, x_lo)| {
if hashmap.contains_key(&elem) {
let x_hi = hashmap[&elem];
decoded = Some(x_lo + TWO16 * x_hi);
}
});
decoded
}
}
/// Solves the discrete log instance using a 18/14 bit offline/online split
impl DiscreteLogInstance {
/// Solves the discrete log problem under the assumption that the solution
/// is a 32-bit number.
pub fn decode_u32_alt(self) -> Option<u32> {
let hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(self.generator);
self.decode_u32_online_alt(&hashmap)
}
pub fn decode_u32_precomputation_alt(
generator: RistrettoPoint,
) -> HashMap<HashableRistretto, u32> {
let mut hashmap = HashMap::new();
let two12_scalar = Scalar::from(TWO14);
let identity = HashableRistretto(RistrettoPoint::identity()); // 0 * G
let generator = HashableRistretto(two12_scalar * generator); // 2^12 * G
// iterator for 2^12*0G , 2^12*1G, 2^12*2G, ...
let ristretto_iter = RistrettoIterator::new(identity, generator);
ristretto_iter.zip(0..TWO18).for_each(|(elem, x_hi)| {
hashmap.insert(elem, x_hi);
});
hashmap
}
pub fn decode_u32_online_alt(self, hashmap: &HashMap<HashableRistretto, u32>) -> Option<u32> {
// iterator for 0G, -1G, -2G, ...
let ristretto_iter = RistrettoIterator::new(
HashableRistretto(self.target),
HashableRistretto(-self.generator),
);
let mut decoded = None;
ristretto_iter.zip(0..TWO14).for_each(|(elem, x_lo)| {
if hashmap.contains_key(&elem) {
let x_hi = hashmap[&elem];
decoded = Some(x_lo + TWO14 * x_hi);
}
});
decoded
}
}
/// Type wrapper for RistrettoPoint that implements the Hash trait
#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq)]
pub struct HashableRistretto(pub RistrettoPoint);
impl HashableRistretto {
pub fn encode<T: Into<Scalar>>(amount: T) -> Self {
HashableRistretto(amount.into() * G)
}
}
impl Hash for HashableRistretto {
fn hash<H: Hasher>(&self, state: &mut H) {
bincode::serialize(self).unwrap().hash(state);
}
}
impl PartialEq for HashableRistretto {
fn eq(&self, other: &Self) -> bool {
self == other
}
}
/// HashableRistretto iterator.
///
/// Given an initial point X and a stepping point P, the iterator iterates through
/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
struct RistrettoIterator {
pub curr: HashableRistretto,
pub step: HashableRistretto,
}
impl RistrettoIterator {
fn new(curr: HashableRistretto, step: HashableRistretto) -> Self {
RistrettoIterator { curr, step }
}
}
impl Iterator for RistrettoIterator {
type Item = HashableRistretto;
fn next(&mut self) -> Option<Self::Item> {
let r = self.curr;
self.curr = self.curr + self.step;
Some(r)
}
}
impl<'a, 'b> Add<&'b HashableRistretto> for &'a HashableRistretto {
type Output = HashableRistretto;
fn add(self, other: &HashableRistretto) -> HashableRistretto {
HashableRistretto(self.0 + other.0)
}
}
define_add_variants!(
LHS = HashableRistretto,
RHS = HashableRistretto,
Output = HashableRistretto
);
impl<'a, 'b> Sub<&'b HashableRistretto> for &'a HashableRistretto {
type Output = HashableRistretto;
fn sub(self, other: &HashableRistretto) -> HashableRistretto {
HashableRistretto(self.0 - other.0)
}
}
define_sub_variants!(
LHS = HashableRistretto,
RHS = HashableRistretto,
Output = HashableRistretto
);
impl Neg for HashableRistretto {
type Output = HashableRistretto;
fn neg(self) -> HashableRistretto {
HashableRistretto(-self.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Discrete log test for 16/16 split
///
/// Very informal measurements on my machine:
/// - 8 sec for precomputation
/// - 3 sec for online computation
#[test]
#[ignore]
fn test_decode_correctness() {
let amount: u32 = 65545;
let instance = DiscreteLogInstance {
generator: G,
target: Scalar::from(amount) * G,
};
// Very informal measurements for now
let start_precomputation = time::precise_time_s();
let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation(G);
let end_precomputation = time::precise_time_s();
let start_online = time::precise_time_s();
let computed_amount = instance.decode_u32_online(&precomputed_hashmap).unwrap();
let end_online = time::precise_time_s();
assert_eq!(amount, computed_amount);
println!(
"16/16 Split precomputation: {:?} sec",
end_precomputation - start_precomputation
);
println!(
"16/16 Split online computation: {:?} sec",
end_online - start_online
);
}
/// Discrete log test for 18/14 split
///
/// Very informal measurements on my machine:
/// - 33 sec for precomputation
/// - 0.8 sec for online computation
#[test]
#[ignore]
fn test_decode_alt_correctness() {
let amount: u32 = 65545;
let instance = DiscreteLogInstance {
generator: G,
target: Scalar::from(amount) * G,
};
// Very informal measurements for now
let start_precomputation = time::precise_time_s();
let precomputed_hashmap = DiscreteLogInstance::decode_u32_precomputation_alt(G);
let end_precomputation = time::precise_time_s();
let start_online = time::precise_time_s();
let computed_amount = instance
.decode_u32_online_alt(&precomputed_hashmap)
.unwrap();
let end_online = time::precise_time_s();
assert_eq!(amount, computed_amount);
println!(
"18/14 Split precomputation: {:?} sec",
end_precomputation - start_precomputation
);
println!(
"18/14 Split online computation: {:?} sec",
end_online - start_online
);
}
}

View File

@ -0,0 +1,3 @@
pub mod elgamal;
pub mod encode;
pub mod pedersen;

View File

@ -0,0 +1,484 @@
#[cfg(not(target_arch = "bpf"))]
use rand::{rngs::OsRng, CryptoRng, RngCore};
use {
crate::{
encryption::elgamal::{ElGamalCT, ElGamalPK},
errors::ProofError,
},
core::ops::{Add, Div, Mul, Sub},
curve25519_dalek::{
constants::{RISTRETTO_BASEPOINT_COMPRESSED, RISTRETTO_BASEPOINT_POINT},
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
traits::MultiscalarMul,
},
serde::{Deserialize, Serialize},
sha3::Sha3_512,
std::convert::TryInto,
subtle::{Choice, ConstantTimeEq},
zeroize::Zeroize,
};
/// Curve basepoints for which Pedersen commitment is defined over.
///
/// These points should be fixed for the entire system.
/// TODO: Consider setting these points as constants?
#[allow(non_snake_case)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug, Eq, PartialEq)]
pub struct PedersenBase {
pub G: RistrettoPoint,
pub H: RistrettoPoint,
}
/// Default PedersenBase. This is set arbitrarily for now, but it should be fixed
/// for the entire system.
///
/// `G` is a constant point in the curve25519_dalek library
/// `H` is the Sha3 hash of `G` interpretted as a RistrettoPoint
impl Default for PedersenBase {
#[allow(non_snake_case)]
fn default() -> PedersenBase {
let G = RISTRETTO_BASEPOINT_POINT;
let H =
RistrettoPoint::hash_from_bytes::<Sha3_512>(RISTRETTO_BASEPOINT_COMPRESSED.as_bytes());
PedersenBase { G, H }
}
}
/// Handle for the Pedersen commitment scheme
pub struct Pedersen;
impl Pedersen {
/// Given a number as input, the function returns a Pedersen commitment of
/// the number and its corresponding opening.
///
/// TODO: Interface that takes a random generator as input
#[cfg(not(target_arch = "bpf"))]
pub fn commit<T: Into<Scalar>>(amount: T) -> (PedersenComm, PedersenOpen) {
let open = PedersenOpen(Scalar::random(&mut OsRng));
let comm = Pedersen::commit_with(amount, &open);
(comm, open)
}
/// Given a number and an opening as inputs, the function returns their
/// Pedersen commitment.
#[allow(non_snake_case)]
pub fn commit_with<T: Into<Scalar>>(amount: T, open: &PedersenOpen) -> PedersenComm {
let G = PedersenBase::default().G;
let H = PedersenBase::default().H;
let x: Scalar = amount.into();
let r = open.get_scalar();
PedersenComm(RistrettoPoint::multiscalar_mul(&[x, r], &[G, H]))
}
/// Given a number, opening, and Pedersen commitment, the function verifies
/// the validity of the commitment with respect to the number and opening.
///
/// This function is included for completeness and is not used for the
/// c-token program.
#[allow(non_snake_case)]
pub fn verify<T: Into<Scalar>>(
comm: PedersenComm,
open: PedersenOpen,
amount: T,
) -> Result<(), ProofError> {
let G = PedersenBase::default().G;
let H = PedersenBase::default().H;
let x: Scalar = amount.into();
let r = open.get_scalar();
let C = comm.get_point();
if C == RistrettoPoint::multiscalar_mul(&[x, r], &[G, H]) {
Ok(())
} else {
Err(ProofError::VerificationError)
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, Zeroize)]
#[zeroize(drop)]
pub struct PedersenOpen(pub(crate) Scalar);
impl PedersenOpen {
pub fn get_scalar(&self) -> Scalar {
self.0
}
#[cfg(not(target_arch = "bpf"))]
pub fn random<T: RngCore + CryptoRng>(rng: &mut T) -> Self {
PedersenOpen(Scalar::random(rng))
}
#[allow(clippy::wrong_self_convention)]
pub fn to_bytes(&self) -> [u8; 32] {
self.0.to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Option<PedersenOpen> {
match bytes.try_into() {
Ok(bytes) => Scalar::from_canonical_bytes(bytes).map(PedersenOpen),
_ => None,
}
}
}
impl Eq for PedersenOpen {}
impl PartialEq for PedersenOpen {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).unwrap_u8() == 1u8
}
}
impl ConstantTimeEq for PedersenOpen {
fn ct_eq(&self, other: &Self) -> Choice {
self.0.ct_eq(&other.0)
}
}
impl Default for PedersenOpen {
fn default() -> Self {
PedersenOpen(Scalar::default())
}
}
impl<'a, 'b> Add<&'b PedersenOpen> for &'a PedersenOpen {
type Output = PedersenOpen;
fn add(self, other: &'b PedersenOpen) -> PedersenOpen {
PedersenOpen(self.get_scalar() + other.get_scalar())
}
}
define_add_variants!(
LHS = PedersenOpen,
RHS = PedersenOpen,
Output = PedersenOpen
);
impl<'a, 'b> Sub<&'b PedersenOpen> for &'a PedersenOpen {
type Output = PedersenOpen;
fn sub(self, other: &'b PedersenOpen) -> PedersenOpen {
PedersenOpen(self.get_scalar() - other.get_scalar())
}
}
define_sub_variants!(
LHS = PedersenOpen,
RHS = PedersenOpen,
Output = PedersenOpen
);
impl<'a, 'b> Mul<&'b Scalar> for &'a PedersenOpen {
type Output = PedersenOpen;
fn mul(self, other: &'b Scalar) -> PedersenOpen {
PedersenOpen(self.get_scalar() * other)
}
}
define_mul_variants!(LHS = PedersenOpen, RHS = Scalar, Output = PedersenOpen);
impl<'a, 'b> Div<&'b Scalar> for &'a PedersenOpen {
type Output = PedersenOpen;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, other: &'b Scalar) -> PedersenOpen {
PedersenOpen(self.get_scalar() * other.invert())
}
}
define_div_variants!(LHS = PedersenOpen, RHS = Scalar, Output = PedersenOpen);
#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)]
pub struct PedersenComm(pub(crate) RistrettoPoint);
impl PedersenComm {
pub fn get_point(&self) -> RistrettoPoint {
self.0
}
#[allow(clippy::wrong_self_convention)]
pub fn to_bytes(&self) -> [u8; 32] {
self.0.compress().to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Option<PedersenComm> {
Some(PedersenComm(
CompressedRistretto::from_slice(bytes).decompress()?,
))
}
}
impl<'a, 'b> Add<&'b PedersenComm> for &'a PedersenComm {
type Output = PedersenComm;
fn add(self, other: &'b PedersenComm) -> PedersenComm {
PedersenComm(self.get_point() + other.get_point())
}
}
define_add_variants!(
LHS = PedersenComm,
RHS = PedersenComm,
Output = PedersenComm
);
impl<'a, 'b> Sub<&'b PedersenComm> for &'a PedersenComm {
type Output = PedersenComm;
fn sub(self, other: &'b PedersenComm) -> PedersenComm {
PedersenComm(self.get_point() - other.get_point())
}
}
define_sub_variants!(
LHS = PedersenComm,
RHS = PedersenComm,
Output = PedersenComm
);
impl<'a, 'b> Mul<&'b Scalar> for &'a PedersenComm {
type Output = PedersenComm;
fn mul(self, other: &'b Scalar) -> PedersenComm {
PedersenComm(self.get_point() * other)
}
}
define_mul_variants!(LHS = PedersenComm, RHS = Scalar, Output = PedersenComm);
impl<'a, 'b> Div<&'b Scalar> for &'a PedersenComm {
type Output = PedersenComm;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, other: &'b Scalar) -> PedersenComm {
PedersenComm(self.get_point() * other.invert())
}
}
define_div_variants!(LHS = PedersenComm, RHS = Scalar, Output = PedersenComm);
/// Decryption handle for Pedersen commitment.
///
/// A decryption handle can be combined with Pedersen commitments to form an
/// ElGamal ciphertext.
#[derive(Serialize, Deserialize, Default, Clone, Copy, Debug, Eq, PartialEq)]
pub struct PedersenDecHandle(pub(crate) RistrettoPoint);
impl PedersenDecHandle {
pub fn get_point(&self) -> RistrettoPoint {
self.0
}
pub fn generate_handle(open: &PedersenOpen, pk: &ElGamalPK) -> PedersenDecHandle {
PedersenDecHandle(open.get_scalar() * pk.get_point())
}
/// Maps a decryption token and Pedersen commitment to ElGamal ciphertext
pub fn to_elgamal_ctxt(self, comm: PedersenComm) -> ElGamalCT {
ElGamalCT {
message_comm: comm,
decrypt_handle: self,
}
}
#[allow(clippy::wrong_self_convention)]
pub fn to_bytes(&self) -> [u8; 32] {
self.0.compress().to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Option<PedersenDecHandle> {
Some(PedersenDecHandle(
CompressedRistretto::from_slice(bytes).decompress()?,
))
}
}
impl<'a, 'b> Add<&'b PedersenDecHandle> for &'a PedersenDecHandle {
type Output = PedersenDecHandle;
fn add(self, other: &'b PedersenDecHandle) -> PedersenDecHandle {
PedersenDecHandle(self.get_point() + other.get_point())
}
}
define_add_variants!(
LHS = PedersenDecHandle,
RHS = PedersenDecHandle,
Output = PedersenDecHandle
);
impl<'a, 'b> Sub<&'b PedersenDecHandle> for &'a PedersenDecHandle {
type Output = PedersenDecHandle;
fn sub(self, other: &'b PedersenDecHandle) -> PedersenDecHandle {
PedersenDecHandle(self.get_point() - other.get_point())
}
}
define_sub_variants!(
LHS = PedersenDecHandle,
RHS = PedersenDecHandle,
Output = PedersenDecHandle
);
impl<'a, 'b> Mul<&'b Scalar> for &'a PedersenDecHandle {
type Output = PedersenDecHandle;
fn mul(self, other: &'b Scalar) -> PedersenDecHandle {
PedersenDecHandle(self.get_point() * other)
}
}
define_mul_variants!(
LHS = PedersenDecHandle,
RHS = Scalar,
Output = PedersenDecHandle
);
impl<'a, 'b> Div<&'b Scalar> for &'a PedersenDecHandle {
type Output = PedersenDecHandle;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, other: &'b Scalar) -> PedersenDecHandle {
PedersenDecHandle(self.get_point() * other.invert())
}
}
define_div_variants!(
LHS = PedersenDecHandle,
RHS = Scalar,
Output = PedersenDecHandle
);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_commit_verification_correctness() {
let amt: u64 = 57;
let (comm, open) = Pedersen::commit(amt);
assert!(Pedersen::verify(comm, open, amt).is_ok());
}
#[test]
fn test_homomorphic_addition() {
let amt_0: u64 = 77;
let amt_1: u64 = 57;
let rng = &mut OsRng;
let open_0 = PedersenOpen(Scalar::random(rng));
let open_1 = PedersenOpen(Scalar::random(rng));
let comm_0 = Pedersen::commit_with(amt_0, &open_0);
let comm_1 = Pedersen::commit_with(amt_1, &open_1);
let comm_addition = Pedersen::commit_with(amt_0 + amt_1, &(open_0 + open_1));
assert_eq!(comm_addition, comm_0 + comm_1);
}
#[test]
fn test_homomorphic_subtraction() {
let amt_0: u64 = 77;
let amt_1: u64 = 57;
let rng = &mut OsRng;
let open_0 = PedersenOpen(Scalar::random(rng));
let open_1 = PedersenOpen(Scalar::random(rng));
let comm_0 = Pedersen::commit_with(amt_0, &open_0);
let comm_1 = Pedersen::commit_with(amt_1, &open_1);
let comm_addition = Pedersen::commit_with(amt_0 - amt_1, &(open_0 - open_1));
assert_eq!(comm_addition, comm_0 - comm_1);
}
#[test]
fn test_homomorphic_multiplication() {
let amt_0: u64 = 77;
let amt_1: u64 = 57;
let (comm, open) = Pedersen::commit(amt_0);
let scalar = Scalar::from(amt_1);
let comm_addition = Pedersen::commit_with(amt_0 * amt_1, &(open * scalar));
assert_eq!(comm_addition, comm * scalar);
}
#[test]
fn test_homomorphic_division() {
let amt_0: u64 = 77;
let amt_1: u64 = 7;
let (comm, open) = Pedersen::commit(amt_0);
let scalar = Scalar::from(amt_1);
let comm_addition = Pedersen::commit_with(amt_0 / amt_1, &(open / scalar));
assert_eq!(comm_addition, comm / scalar);
}
#[test]
fn test_commitment_bytes() {
let amt: u64 = 77;
let (comm, _) = Pedersen::commit(amt);
let encoded = comm.to_bytes();
let decoded = PedersenComm::from_bytes(&encoded).unwrap();
assert_eq!(comm, decoded);
}
#[test]
fn test_opening_bytes() {
let open = PedersenOpen(Scalar::random(&mut OsRng));
let encoded = open.to_bytes();
let decoded = PedersenOpen::from_bytes(&encoded).unwrap();
assert_eq!(open, decoded);
}
#[test]
fn test_decrypt_handle_bytes() {
let handle = PedersenDecHandle(RistrettoPoint::default());
let encoded = handle.to_bytes();
let decoded = PedersenDecHandle::from_bytes(&encoded).unwrap();
assert_eq!(handle, decoded);
}
#[test]
fn test_serde_commitment() {
let amt: u64 = 77;
let (comm, _) = Pedersen::commit(amt);
let encoded = bincode::serialize(&comm).unwrap();
let decoded: PedersenComm = bincode::deserialize(&encoded).unwrap();
assert_eq!(comm, decoded);
}
#[test]
fn test_serde_opening() {
let open = PedersenOpen(Scalar::random(&mut OsRng));
let encoded = bincode::serialize(&open).unwrap();
let decoded: PedersenOpen = bincode::deserialize(&encoded).unwrap();
assert_eq!(open, decoded);
}
#[test]
fn test_serde_decrypt_handle() {
let handle = PedersenDecHandle(RistrettoPoint::default());
let encoded = bincode::serialize(&handle).unwrap();
let decoded: PedersenDecHandle = bincode::deserialize(&encoded).unwrap();
assert_eq!(handle, decoded);
}
}

View File

@ -0,0 +1,19 @@
//! Errors related to proving and verifying proofs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProofError {
/// This error occurs when a proof failed to verify.
VerificationError,
/// This error occurs when the proof encoding is malformed.
FormatError,
/// This error occurs during proving if the number of blinding
/// factors does not match the number of values.
WrongNumBlindingFactors,
/// This error occurs when attempting to create a proof with
/// bitsize other than \\(8\\), \\(16\\), \\(32\\), or \\(64\\).
InvalidBitsize,
/// This error occurs when there are insufficient generators for the proof.
InvalidGeneratorsLength,
/// This error occurs when TODO
InconsistentCTData,
}

View File

@ -0,0 +1,163 @@
use {
crate::pod::*,
bytemuck::{Pod, Zeroable},
};
#[cfg(not(target_arch = "bpf"))]
use {
crate::{
encryption::elgamal::{ElGamalCT, ElGamalSK},
errors::ProofError,
instruction::Verifiable,
transcript::TranscriptProtocol,
},
curve25519_dalek::{
ristretto::RistrettoPoint,
scalar::Scalar,
traits::{IsIdentity, MultiscalarMul},
},
merlin::Transcript,
rand::rngs::OsRng,
std::convert::TryInto,
};
/// This struct includes the cryptographic proof *and* the account data information needed to verify
/// the proof
///
/// - The pre-instruction should call CloseAccountData::verify_proof(&self)
/// - The actual program should check that `balance` is consistent with what is
/// currently stored in the confidential token account
///
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct CloseAccountData {
/// The source account available balance in encrypted form
pub balance: PodElGamalCT, // 64 bytes
/// Proof that the source account available balance is zero
pub proof: CloseAccountProof, // 64 bytes
}
#[cfg(not(target_arch = "bpf"))]
impl CloseAccountData {
pub fn new(source_sk: &ElGamalSK, balance: ElGamalCT) -> Self {
let proof = CloseAccountProof::new(source_sk, &balance);
CloseAccountData {
balance: balance.into(),
proof,
}
}
}
#[cfg(not(target_arch = "bpf"))]
impl Verifiable for CloseAccountData {
fn verify(&self) -> Result<(), ProofError> {
let balance = self.balance.try_into()?;
self.proof.verify(&balance)
}
}
/// This struct represents the cryptographic proof component that certifies that the encrypted
/// balance is zero
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
#[allow(non_snake_case)]
pub struct CloseAccountProof {
pub R: PodCompressedRistretto, // 32 bytes
pub z: PodScalar, // 32 bytes
}
#[allow(non_snake_case)]
#[cfg(not(target_arch = "bpf"))]
impl CloseAccountProof {
fn transcript_new() -> Transcript {
Transcript::new(b"CloseAccountProof")
}
pub fn new(source_sk: &ElGamalSK, balance: &ElGamalCT) -> Self {
let mut transcript = Self::transcript_new();
// add a domain separator to record the start of the protocol
transcript.close_account_proof_domain_sep();
// extract the relevant scalar and Ristretto points from the input
let s = source_sk.get_scalar();
let C = balance.decrypt_handle.get_point();
// generate a random masking factor that also serves as a nonce
let r = Scalar::random(&mut OsRng); // using OsRng for now
let R = (r * C).compress();
// record R on transcript and receive a challenge scalar
transcript.append_point(b"R", &R);
let c = transcript.challenge_scalar(b"c");
// compute the masked secret key
let z = c * s + r;
CloseAccountProof {
R: R.into(),
z: z.into(),
}
}
pub fn verify(&self, balance: &ElGamalCT) -> Result<(), ProofError> {
let mut transcript = Self::transcript_new();
// add a domain separator to record the start of the protocol
transcript.close_account_proof_domain_sep();
// extract the relevant scalar and Ristretto points from the input
let C = balance.message_comm.get_point();
let D = balance.decrypt_handle.get_point();
let R = self.R.into();
let z = self.z.into();
// generate a challenge scalar
//
// use `append_point` as opposed to `validate_and_append_point` as the ciphertext is
// already guaranteed to be well-formed
transcript.append_point(b"R", &R);
let c = transcript.challenge_scalar(b"c");
// decompress R or return verification error
let R = R.decompress().ok_or(ProofError::VerificationError)?;
// check the required algebraic relation
let check = RistrettoPoint::multiscalar_mul(vec![z, -c, -Scalar::one()], vec![D, C, R]);
if check.is_identity() {
Ok(())
} else {
Err(ProofError::VerificationError)
}
}
}
#[cfg(test)]
mod test {
use {super::*, crate::encryption::elgamal::ElGamal};
#[test]
fn test_close_account_correctness() {
let (source_pk, source_sk) = ElGamal::keygen();
// If account balance is 0, then the proof should succeed
let balance = source_pk.encrypt(0_u64);
let proof = CloseAccountProof::new(&source_sk, &balance);
assert!(proof.verify(&balance).is_ok());
// If account balance is not zero, then the proof verification should fail
let balance = source_pk.encrypt(55_u64);
let proof = CloseAccountProof::new(&source_sk, &balance);
assert!(proof.verify(&balance).is_err());
// A zeroed cyphertext should be considered as an account balance of 0
let zeroed_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap();
let proof = CloseAccountProof::new(&source_sk, &zeroed_ct);
assert!(proof.verify(&zeroed_ct).is_ok());
}
}

View File

@ -0,0 +1,21 @@
mod close_account;
pub mod transfer;
mod update_account_pk;
mod withdraw;
#[cfg(not(target_arch = "bpf"))]
use crate::errors::ProofError;
pub use {
close_account::CloseAccountData,
transfer::{
TransferComms, TransferData, TransferEphemeralState, TransferPubKeys,
TransferRangeProofData, TransferValidityProofData,
},
update_account_pk::UpdateAccountPkData,
withdraw::WithdrawData,
};
#[cfg(not(target_arch = "bpf"))]
pub trait Verifiable {
fn verify(&self) -> Result<(), ProofError>;
}

View File

@ -0,0 +1,546 @@
use {
crate::pod::*,
bytemuck::{Pod, Zeroable},
};
#[cfg(not(target_arch = "bpf"))]
use {
crate::{
encryption::{
elgamal::{ElGamalCT, ElGamalPK, ElGamalSK},
pedersen::{Pedersen, PedersenBase, PedersenComm, PedersenDecHandle, PedersenOpen},
},
errors::ProofError,
instruction::Verifiable,
range_proof::RangeProof,
transcript::TranscriptProtocol,
},
curve25519_dalek::{
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
traits::{IsIdentity, MultiscalarMul, VartimeMultiscalarMul},
},
merlin::Transcript,
rand::rngs::OsRng,
std::convert::TryInto,
};
/// Just a grouping struct for the data required for the two transfer instructions. It is
/// convenient to generate the two components jointly as they share common components.
pub struct TransferData {
pub range_proof: TransferRangeProofData,
pub validity_proof: TransferValidityProofData,
}
#[cfg(not(target_arch = "bpf"))]
impl TransferData {
pub fn new(
transfer_amount: u64,
spendable_balance: u64,
spendable_ct: ElGamalCT,
source_pk: ElGamalPK,
source_sk: &ElGamalSK,
dest_pk: ElGamalPK,
auditor_pk: ElGamalPK,
) -> Self {
// split and encrypt transfer amount
//
// encryption is a bit more involved since we are generating each components of an ElGamal
// ciphertext separately.
let (amount_lo, amount_hi) = split_u64_into_u32(transfer_amount);
let (comm_lo, open_lo) = Pedersen::commit(amount_lo);
let (comm_hi, open_hi) = Pedersen::commit(amount_hi);
let handle_source_lo = source_pk.gen_decrypt_handle(&open_lo);
let handle_dest_lo = dest_pk.gen_decrypt_handle(&open_lo);
let handle_auditor_lo = auditor_pk.gen_decrypt_handle(&open_lo);
let handle_source_hi = source_pk.gen_decrypt_handle(&open_hi);
let handle_dest_hi = dest_pk.gen_decrypt_handle(&open_hi);
let handle_auditor_hi = auditor_pk.gen_decrypt_handle(&open_hi);
// message encoding as Pedersen commitments, which will be included in range proof data
let amount_comms = TransferComms {
lo: comm_lo.into(),
hi: comm_hi.into(),
};
// decryption handles, which will be included in the validity proof data
let decryption_handles_lo = TransferHandles {
source: handle_source_lo.into(),
dest: handle_dest_lo.into(),
auditor: handle_auditor_lo.into(),
};
let decryption_handles_hi = TransferHandles {
source: handle_source_hi.into(),
dest: handle_dest_hi.into(),
auditor: handle_auditor_hi.into(),
};
// grouping of the public keys for the transfer
let transfer_public_keys = TransferPubKeys {
source_pk: source_pk.into(),
dest_pk: dest_pk.into(),
auditor_pk: auditor_pk.into(),
};
// subtract transfer amount from the spendable ciphertext
let spendable_comm = spendable_ct.message_comm;
let spendable_handle = spendable_ct.decrypt_handle;
let new_spendable_balance = spendable_balance - transfer_amount;
let new_spendable_comm = spendable_comm - combine_u32_comms(comm_lo, comm_hi);
let new_spendable_handle =
spendable_handle - combine_u32_handles(handle_source_lo, handle_source_hi);
let new_spendable_ct = ElGamalCT {
message_comm: new_spendable_comm,
decrypt_handle: new_spendable_handle,
};
// range_proof and validity_proof should be generated together
let (transfer_proofs, ephemeral_state) = TransferProofs::new(
source_sk,
&source_pk,
&dest_pk,
&auditor_pk,
(amount_lo as u64, amount_hi as u64),
&open_lo,
&open_hi,
new_spendable_balance,
&new_spendable_ct,
);
// generate data components
let range_proof = TransferRangeProofData {
amount_comms,
proof: transfer_proofs.range_proof,
ephemeral_state,
};
let validity_proof = TransferValidityProofData {
decryption_handles_lo,
decryption_handles_hi,
transfer_public_keys,
new_spendable_ct: new_spendable_ct.into(),
proof: transfer_proofs.validity_proof,
ephemeral_state,
};
TransferData {
range_proof,
validity_proof,
}
}
}
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct TransferRangeProofData {
/// The transfer amount encoded as Pedersen commitments
pub amount_comms: TransferComms, // 64 bytes
/// Proof that certifies:
/// 1. the source account has enough funds for the transfer (i.e. the final balance is a
/// 64-bit positive number)
/// 2. the transfer amount is a 64-bit positive number
pub proof: PodRangeProof128, // 736 bytes
/// Ephemeral state between the two transfer instruction data
pub ephemeral_state: TransferEphemeralState, // 128 bytes
}
#[cfg(not(target_arch = "bpf"))]
impl Verifiable for TransferRangeProofData {
fn verify(&self) -> Result<(), ProofError> {
let mut transcript = Transcript::new(b"TransferRangeProof");
// standard range proof verification
let proof: RangeProof = self.proof.try_into()?;
proof.verify_with(
vec![
&self.ephemeral_state.spendable_comm_verification.into(),
&self.amount_comms.lo.into(),
&self.amount_comms.hi.into(),
],
vec![64_usize, 32_usize, 32_usize],
Some(self.ephemeral_state.x.into()),
Some(self.ephemeral_state.z.into()),
&mut transcript,
)
}
}
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct TransferValidityProofData {
/// The decryption handles that allow decryption of the lo-bits
pub decryption_handles_lo: TransferHandles, // 96 bytes
/// The decryption handles that allow decryption of the hi-bits
pub decryption_handles_hi: TransferHandles, // 96 bytes
/// The public encryption keys associated with the transfer: source, dest, and auditor
pub transfer_public_keys: TransferPubKeys, // 96 bytes
/// The final spendable ciphertext after the transfer
pub new_spendable_ct: PodElGamalCT, // 64 bytes
/// Proof that certifies that the decryption handles are generated correctly
pub proof: ValidityProof, // 160 bytes
/// Ephemeral state between the two transfer instruction data
pub ephemeral_state: TransferEphemeralState, // 128 bytes
}
/// The joint data that is shared between the two transfer instructions.
///
/// Identical ephemeral data should be included in the two transfer instructions and this should be
/// checked by the ZK Token program.
#[derive(Clone, Copy, Pod, Zeroable, PartialEq)]
#[repr(C)]
pub struct TransferEphemeralState {
pub spendable_comm_verification: PodPedersenComm, // 32 bytes
pub x: PodScalar, // 32 bytes
pub z: PodScalar, // 32 bytes
pub t_x_blinding: PodScalar, // 32 bytes
}
#[cfg(not(target_arch = "bpf"))]
impl Verifiable for TransferValidityProofData {
fn verify(&self) -> Result<(), ProofError> {
self.proof.verify(
&self.new_spendable_ct.try_into()?,
&self.decryption_handles_lo,
&self.decryption_handles_hi,
&self.transfer_public_keys,
&self.ephemeral_state,
)
}
}
/// Just a grouping struct for the two proofs that are needed for a transfer instruction. The two
/// proofs have to be generated together as they share joint data.
pub struct TransferProofs {
pub range_proof: PodRangeProof128,
pub validity_proof: ValidityProof,
}
#[allow(non_snake_case)]
#[cfg(not(target_arch = "bpf"))]
impl TransferProofs {
#[allow(clippy::too_many_arguments)]
#[allow(clippy::many_single_char_names)]
pub fn new(
source_sk: &ElGamalSK,
source_pk: &ElGamalPK,
dest_pk: &ElGamalPK,
auditor_pk: &ElGamalPK,
transfer_amt: (u64, u64),
lo_open: &PedersenOpen,
hi_open: &PedersenOpen,
new_spendable_balance: u64,
new_spendable_ct: &ElGamalCT,
) -> (Self, TransferEphemeralState) {
// TODO: should also commit to pubkeys and commitments later
let mut transcript_validity_proof = merlin::Transcript::new(b"TransferValidityProof");
let H = PedersenBase::default().H;
let D = new_spendable_ct.decrypt_handle.get_point();
let s = source_sk.get_scalar();
// Generate proof for the new spendable ciphertext
let r_new = Scalar::random(&mut OsRng);
let y = Scalar::random(&mut OsRng);
let R = RistrettoPoint::multiscalar_mul(vec![y, r_new], vec![D, H]).compress();
transcript_validity_proof.append_point(b"R", &R);
let c = transcript_validity_proof.challenge_scalar(b"c");
let z = s + c * y;
let new_spendable_open = PedersenOpen(c * r_new);
let spendable_comm_verification =
Pedersen::commit_with(new_spendable_balance, &new_spendable_open);
// Generate proof for the transfer amounts
let t_1_blinding = PedersenOpen::random(&mut OsRng);
let t_2_blinding = PedersenOpen::random(&mut OsRng);
let u = transcript_validity_proof.challenge_scalar(b"u");
let P_joint = RistrettoPoint::multiscalar_mul(
vec![Scalar::one(), u, u * u],
vec![
source_pk.get_point(),
dest_pk.get_point(),
auditor_pk.get_point(),
],
);
let T_joint = (new_spendable_open.get_scalar() * P_joint).compress();
let T_1 = (t_1_blinding.get_scalar() * P_joint).compress();
let T_2 = (t_2_blinding.get_scalar() * P_joint).compress();
transcript_validity_proof.append_point(b"T_1", &T_1);
transcript_validity_proof.append_point(b"T_2", &T_2);
// define the validity proof
let validity_proof = ValidityProof {
R: R.into(),
z: z.into(),
T_joint: T_joint.into(),
T_1: T_1.into(),
T_2: T_2.into(),
};
// generate the range proof
let mut transcript_range_proof = Transcript::new(b"TransferRangeProof");
let (range_proof, x, z) = RangeProof::create_with(
vec![new_spendable_balance, transfer_amt.0, transfer_amt.1],
vec![64, 32, 32],
vec![&new_spendable_open, lo_open, hi_open],
&t_1_blinding,
&t_2_blinding,
&mut transcript_range_proof,
);
// define ephemeral state
let ephemeral_state = TransferEphemeralState {
spendable_comm_verification: spendable_comm_verification.into(),
x: x.into(),
z: z.into(),
t_x_blinding: range_proof.t_x_blinding.into(),
};
(
Self {
range_proof: range_proof.try_into().expect("valid range_proof"),
validity_proof,
},
ephemeral_state,
)
}
}
/// Proof components for transfer instructions.
///
/// These two components should be output by a RangeProof creation function.
#[allow(non_snake_case)]
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct ValidityProof {
// Proof component for the spendable ciphertext components: R
pub R: PodCompressedRistretto, // 32 bytes
// Proof component for the spendable ciphertext components: z
pub z: PodScalar, // 32 bytes
// Proof component for the transaction amount components: T_src
pub T_joint: PodCompressedRistretto, // 32 bytes
// Proof component for the transaction amount components: T_1
pub T_1: PodCompressedRistretto, // 32 bytes
// Proof component for the transaction amount components: T_2
pub T_2: PodCompressedRistretto, // 32 bytes
}
#[allow(non_snake_case)]
#[cfg(not(target_arch = "bpf"))]
impl ValidityProof {
pub fn verify(
self,
new_spendable_ct: &ElGamalCT,
decryption_handles_lo: &TransferHandles,
decryption_handles_hi: &TransferHandles,
transfer_public_keys: &TransferPubKeys,
ephemeral_state: &TransferEphemeralState,
) -> Result<(), ProofError> {
let mut transcript = Transcript::new(b"TransferValidityProof");
let source_pk: ElGamalPK = transfer_public_keys.source_pk.try_into()?;
let dest_pk: ElGamalPK = transfer_public_keys.dest_pk.try_into()?;
let auditor_pk: ElGamalPK = transfer_public_keys.auditor_pk.try_into()?;
// verify Pedersen commitment in the ephemeral state
let C_ephemeral: CompressedRistretto = ephemeral_state.spendable_comm_verification.into();
let C = new_spendable_ct.message_comm.get_point();
let D = new_spendable_ct.decrypt_handle.get_point();
let R = self.R.into();
let z: Scalar = self.z.into();
transcript.validate_and_append_point(b"R", &R)?;
let c = transcript.challenge_scalar(b"c");
let R = R.decompress().ok_or(ProofError::VerificationError)?;
let spendable_comm_verification =
RistrettoPoint::multiscalar_mul(vec![Scalar::one(), -z, c], vec![C, D, R]).compress();
if C_ephemeral != spendable_comm_verification {
return Err(ProofError::VerificationError);
}
// derive joint public key
let u = transcript.challenge_scalar(b"u");
let P_joint = RistrettoPoint::vartime_multiscalar_mul(
vec![Scalar::one(), u, u * u],
vec![
source_pk.get_point(),
dest_pk.get_point(),
auditor_pk.get_point(),
],
);
// check well-formedness of decryption handles
let t_x_blinding: Scalar = ephemeral_state.t_x_blinding.into();
let T_1: CompressedRistretto = self.T_1.into();
let T_2: CompressedRistretto = self.T_2.into();
let x = ephemeral_state.x.into();
let z: Scalar = ephemeral_state.z.into();
let handle_source_lo: PedersenDecHandle = decryption_handles_lo.source.try_into()?;
let handle_dest_lo: PedersenDecHandle = decryption_handles_lo.dest.try_into()?;
let handle_auditor_lo: PedersenDecHandle = decryption_handles_lo.auditor.try_into()?;
let D_joint: CompressedRistretto = self.T_joint.into();
let D_joint = D_joint.decompress().ok_or(ProofError::VerificationError)?;
let D_joint_lo = RistrettoPoint::vartime_multiscalar_mul(
vec![Scalar::one(), u, u * u],
vec![
handle_source_lo.get_point(),
handle_dest_lo.get_point(),
handle_auditor_lo.get_point(),
],
);
let handle_source_hi: PedersenDecHandle = decryption_handles_hi.source.try_into()?;
let handle_dest_hi: PedersenDecHandle = decryption_handles_hi.dest.try_into()?;
let handle_auditor_hi: PedersenDecHandle = decryption_handles_hi.auditor.try_into()?;
let D_joint_hi = RistrettoPoint::vartime_multiscalar_mul(
vec![Scalar::one(), u, u * u],
vec![
handle_source_hi.get_point(),
handle_dest_hi.get_point(),
handle_auditor_hi.get_point(),
],
);
// TODO: combine Pedersen commitment verification above to here for efficiency
// TODO: might need to add an additional proof-of-knowledge here (additional 64 byte)
let mega_check = RistrettoPoint::optional_multiscalar_mul(
vec![-t_x_blinding, x, x * x, z * z, z * z * z, z * z * z * z],
vec![
Some(P_joint),
T_1.decompress(),
T_2.decompress(),
Some(D_joint),
Some(D_joint_lo),
Some(D_joint_hi),
],
)
.ok_or(ProofError::VerificationError)?;
if mega_check.is_identity() {
Ok(())
} else {
Err(ProofError::VerificationError)
}
}
}
/// The ElGamal public keys needed for a transfer
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct TransferPubKeys {
pub source_pk: PodElGamalPK, // 32 bytes
pub dest_pk: PodElGamalPK, // 32 bytes
pub auditor_pk: PodElGamalPK, // 32 bytes
}
/// The transfer amount commitments needed for a transfer
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct TransferComms {
pub lo: PodPedersenComm, // 32 bytes
pub hi: PodPedersenComm, // 32 bytes
}
/// The decryption handles needed for a transfer
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct TransferHandles {
pub source: PodPedersenDecHandle, // 32 bytes
pub dest: PodPedersenDecHandle, // 32 bytes
pub auditor: PodPedersenDecHandle, // 32 bytes
}
/// Split u64 number into two u32 numbers
#[cfg(not(target_arch = "bpf"))]
pub fn split_u64_into_u32(amt: u64) -> (u32, u32) {
let lo = amt as u32;
let hi = (amt >> 32) as u32;
(lo, hi)
}
/// Constant for 2^32
#[cfg(not(target_arch = "bpf"))]
const TWO_32: u64 = 4294967296;
#[cfg(not(target_arch = "bpf"))]
pub fn combine_u32_comms(comm_lo: PedersenComm, comm_hi: PedersenComm) -> PedersenComm {
comm_lo + comm_hi * Scalar::from(TWO_32)
}
#[cfg(not(target_arch = "bpf"))]
pub fn combine_u32_handles(
handle_lo: PedersenDecHandle,
handle_hi: PedersenDecHandle,
) -> PedersenDecHandle {
handle_lo + handle_hi * Scalar::from(TWO_32)
}
#[cfg(not(target_arch = "bpf"))]
pub fn combine_u32_ciphertexts(ct_lo: ElGamalCT, ct_hi: ElGamalCT) -> ElGamalCT {
ct_lo + ct_hi * Scalar::from(TWO_32)
}
#[cfg(test)]
mod test {
use super::*;
use crate::encryption::elgamal::ElGamal;
#[test]
fn test_transfer_correctness() {
// ElGamal keys for source, destination, and auditor accounts
let (source_pk, source_sk) = ElGamal::keygen();
let (dest_pk, _) = ElGamal::keygen();
let (auditor_pk, _) = ElGamal::keygen();
// create source account spendable ciphertext
let spendable_balance: u64 = 77;
let spendable_ct = source_pk.encrypt(spendable_balance);
// transfer amount
let transfer_amount: u64 = 55;
// create transfer data
let transfer_data = TransferData::new(
transfer_amount,
spendable_balance,
spendable_ct,
source_pk,
&source_sk,
dest_pk,
auditor_pk,
);
// verify range proof
assert!(transfer_data.range_proof.verify().is_ok());
// verify ciphertext validity proof
assert!(transfer_data.validity_proof.verify().is_ok());
}
}

View File

@ -0,0 +1,271 @@
use {
crate::pod::*,
bytemuck::{Pod, Zeroable},
};
#[cfg(not(target_arch = "bpf"))]
use {
crate::{
encryption::{
elgamal::{ElGamalCT, ElGamalPK, ElGamalSK},
pedersen::PedersenBase,
},
errors::ProofError,
instruction::Verifiable,
transcript::TranscriptProtocol,
},
curve25519_dalek::{
ristretto::RistrettoPoint,
scalar::Scalar,
traits::{IsIdentity, MultiscalarMul},
},
merlin::Transcript,
rand::rngs::OsRng,
std::convert::TryInto,
};
/// This struct includes the cryptographic proof *and* the account data information needed to verify
/// the proof
///
/// - The pre-instruction should call UpdateAccountPKData::verify(&self)
/// - The actual program should check that `current_ct` is consistent with what is
/// currently stored in the confidential token account
///
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct UpdateAccountPkData {
/// Current ElGamal encryption key
pub current_pk: PodElGamalPK, // 32 bytes
/// Current encrypted available balance
pub current_ct: PodElGamalCT, // 64 bytes
/// New ElGamal encryption key
pub new_pk: PodElGamalPK, // 32 bytes
/// New encrypted available balance
pub new_ct: PodElGamalCT, // 64 bytes
/// Proof that the current and new ciphertexts are consistent
pub proof: UpdateAccountPkProof, // 160 bytes
}
impl UpdateAccountPkData {
#[cfg(not(target_arch = "bpf"))]
pub fn new(
current_balance: u64,
current_ct: ElGamalCT,
current_pk: ElGamalPK,
current_sk: &ElGamalSK,
new_pk: ElGamalPK,
new_sk: &ElGamalSK,
) -> Self {
let new_ct = new_pk.encrypt(current_balance);
let proof =
UpdateAccountPkProof::new(current_balance, current_sk, new_sk, &current_ct, &new_ct);
Self {
current_pk: current_pk.into(),
current_ct: current_ct.into(),
new_ct: new_ct.into(),
new_pk: new_pk.into(),
proof,
}
}
}
#[cfg(not(target_arch = "bpf"))]
impl Verifiable for UpdateAccountPkData {
fn verify(&self) -> Result<(), ProofError> {
let current_ct = self.current_ct.try_into()?;
let new_ct = self.new_ct.try_into()?;
self.proof.verify(&current_ct, &new_ct)
}
}
/// This struct represents the cryptographic proof component that certifies that the current_ct and
/// new_ct encrypt equal values
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
#[allow(non_snake_case)]
pub struct UpdateAccountPkProof {
pub R_0: PodCompressedRistretto, // 32 bytes
pub R_1: PodCompressedRistretto, // 32 bytes
pub z_sk_0: PodScalar, // 32 bytes
pub z_sk_1: PodScalar, // 32 bytes
pub z_x: PodScalar, // 32 bytes
}
#[allow(non_snake_case)]
#[cfg(not(target_arch = "bpf"))]
impl UpdateAccountPkProof {
fn transcript_new() -> Transcript {
Transcript::new(b"UpdateAccountPkProof")
}
fn new(
current_balance: u64,
current_sk: &ElGamalSK,
new_sk: &ElGamalSK,
current_ct: &ElGamalCT,
new_ct: &ElGamalCT,
) -> Self {
let mut transcript = Self::transcript_new();
// add a domain separator to record the start of the protocol
transcript.update_account_public_key_proof_domain_sep();
// extract the relevant scalar and Ristretto points from the input
let s_0 = current_sk.get_scalar();
let s_1 = new_sk.get_scalar();
let x = Scalar::from(current_balance);
let D_0 = current_ct.decrypt_handle.get_point();
let D_1 = new_ct.decrypt_handle.get_point();
let G = PedersenBase::default().G;
// generate a random masking factor that also serves as a nonce
let r_sk_0 = Scalar::random(&mut OsRng);
let r_sk_1 = Scalar::random(&mut OsRng);
let r_x = Scalar::random(&mut OsRng);
let R_0 = (r_sk_0 * D_0 + r_x * G).compress();
let R_1 = (r_sk_1 * D_1 + r_x * G).compress();
// record R_0, R_1 on transcript and receive a challenge scalar
transcript.append_point(b"R_0", &R_0);
transcript.append_point(b"R_1", &R_1);
let c = transcript.challenge_scalar(b"c");
let _w = transcript.challenge_scalar(b"w"); // for consistency of transcript
// compute the masked secret keys and amount
let z_sk_0 = c * s_0 + r_sk_0;
let z_sk_1 = c * s_1 + r_sk_1;
let z_x = c * x + r_x;
UpdateAccountPkProof {
R_0: R_0.into(),
R_1: R_1.into(),
z_sk_0: z_sk_0.into(),
z_sk_1: z_sk_1.into(),
z_x: z_x.into(),
}
}
fn verify(&self, current_ct: &ElGamalCT, new_ct: &ElGamalCT) -> Result<(), ProofError> {
let mut transcript = Self::transcript_new();
// add a domain separator to record the start of the protocol
transcript.update_account_public_key_proof_domain_sep();
// extract the relevant scalar and Ristretto points from the input
let C_0 = current_ct.message_comm.get_point();
let D_0 = current_ct.decrypt_handle.get_point();
let C_1 = new_ct.message_comm.get_point();
let D_1 = new_ct.decrypt_handle.get_point();
let R_0 = self.R_0.into();
let R_1 = self.R_1.into();
let z_sk_0 = self.z_sk_0.into();
let z_sk_1: Scalar = self.z_sk_1.into();
let z_x = self.z_x.into();
let G = PedersenBase::default().G;
// generate a challenge scalar
transcript.validate_and_append_point(b"R_0", &R_0)?;
transcript.validate_and_append_point(b"R_1", &R_1)?;
let c = transcript.challenge_scalar(b"c");
let w = transcript.challenge_scalar(b"w");
// decompress R_0, R_1 or return verification error
let R_0 = R_0.decompress().ok_or(ProofError::VerificationError)?;
let R_1 = R_1.decompress().ok_or(ProofError::VerificationError)?;
// check the required algebraic relation
let check = RistrettoPoint::multiscalar_mul(
vec![
z_sk_0,
z_x,
-c,
-Scalar::one(),
w * z_sk_1,
w * z_x,
-w * c,
-w * Scalar::one(),
],
vec![D_0, G, C_0, R_0, D_1, G, C_1, R_1],
);
if check.is_identity() {
Ok(())
} else {
Err(ProofError::VerificationError)
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::encryption::elgamal::ElGamal;
#[test]
fn test_update_account_public_key_correctness() {
let (current_pk, current_sk) = ElGamal::keygen();
let (new_pk, new_sk) = ElGamal::keygen();
// If current_ct and new_ct encrypt same values, then the proof verification should succeed
let balance: u64 = 77;
let current_ct = current_pk.encrypt(balance);
let new_ct = new_pk.encrypt(balance);
let proof = UpdateAccountPkProof::new(balance, &current_sk, &new_sk, &current_ct, &new_ct);
assert!(proof.verify(&current_ct, &new_ct).is_ok());
// If current_ct and new_ct encrypt different values, then the proof verification should fail
let new_ct = new_pk.encrypt(55_u64);
let proof = UpdateAccountPkProof::new(balance, &current_sk, &new_sk, &current_ct, &new_ct);
assert!(proof.verify(&current_ct, &new_ct).is_err());
// A zeroed cipehrtext should be considered as an account balance of 0
let balance: u64 = 0;
let zeroed_ct_as_current_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap();
let new_ct: ElGamalCT = new_pk.encrypt(balance);
let proof = UpdateAccountPkProof::new(
balance,
&current_sk,
&new_sk,
&zeroed_ct_as_current_ct,
&new_ct,
);
assert!(proof.verify(&zeroed_ct_as_current_ct, &new_ct).is_ok());
let current_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap();
let zeroed_ct_as_new_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap();
let proof = UpdateAccountPkProof::new(
balance,
&current_sk,
&new_sk,
&current_ct,
&zeroed_ct_as_new_ct,
);
assert!(proof.verify(&current_ct, &zeroed_ct_as_new_ct).is_ok());
let zeroed_ct_as_current_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap();
let zeroed_ct_as_new_ct: ElGamalCT = PodElGamalCT::zeroed().try_into().unwrap();
let proof = UpdateAccountPkProof::new(
balance,
&current_sk,
&new_sk,
&zeroed_ct_as_current_ct,
&zeroed_ct_as_new_ct,
);
assert!(proof
.verify(&zeroed_ct_as_current_ct, &zeroed_ct_as_new_ct)
.is_ok());
}
}

View File

@ -0,0 +1,207 @@
use {
crate::pod::*,
bytemuck::{Pod, Zeroable},
};
#[cfg(not(target_arch = "bpf"))]
use {
crate::{
encryption::{
elgamal::{ElGamalCT, ElGamalPK, ElGamalSK},
pedersen::{PedersenBase, PedersenOpen},
},
errors::ProofError,
instruction::Verifiable,
range_proof::RangeProof,
transcript::TranscriptProtocol,
},
curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::MultiscalarMul},
merlin::Transcript,
rand::rngs::OsRng,
std::convert::TryInto,
};
/// This struct includes the cryptographic proof *and* the account data information needed to verify
/// the proof
///
/// - The pre-instruction should call WithdrawData::verify_proof(&self)
/// - The actual program should check that `current_ct` is consistent with what is
/// currently stored in the confidential token account TODO: update this statement
///
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
pub struct WithdrawData {
/// The source account available balance *after* the withdraw (encrypted by
/// `source_pk`
pub final_balance_ct: PodElGamalCT, // 64 bytes
/// Proof that the account is solvent
pub proof: WithdrawProof, // 736 bytes
}
impl WithdrawData {
#[cfg(not(target_arch = "bpf"))]
pub fn new(
amount: u64,
source_pk: ElGamalPK,
source_sk: &ElGamalSK,
current_balance: u64,
current_balance_ct: ElGamalCT,
) -> Self {
// subtract withdraw amount from current balance
//
// panics if current_balance < amount
let final_balance = current_balance - amount;
// encode withdraw amount as an ElGamal ciphertext and subtract it from
// current source balance
let amount_encoded = source_pk.encrypt_with(amount, &PedersenOpen::default());
let final_balance_ct = current_balance_ct - amount_encoded;
let proof = WithdrawProof::new(source_sk, final_balance, &final_balance_ct);
Self {
final_balance_ct: final_balance_ct.into(),
proof,
}
}
}
#[cfg(not(target_arch = "bpf"))]
impl Verifiable for WithdrawData {
fn verify(&self) -> Result<(), ProofError> {
let final_balance_ct = self.final_balance_ct.try_into()?;
self.proof.verify(&final_balance_ct)
}
}
/// This struct represents the cryptographic proof component that certifies the account's solvency
/// for withdrawal
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
#[allow(non_snake_case)]
pub struct WithdrawProof {
/// Wrapper for range proof: R component
pub R: PodCompressedRistretto, // 32 bytes
/// Wrapper for range proof: z component
pub z: PodScalar, // 32 bytes
/// Associated range proof
pub range_proof: PodRangeProof64, // 672 bytes
}
#[allow(non_snake_case)]
#[cfg(not(target_arch = "bpf"))]
impl WithdrawProof {
fn transcript_new() -> Transcript {
Transcript::new(b"WithdrawProof")
}
pub fn new(source_sk: &ElGamalSK, final_balance: u64, final_balance_ct: &ElGamalCT) -> Self {
let mut transcript = Self::transcript_new();
// add a domain separator to record the start of the protocol
transcript.withdraw_proof_domain_sep();
// extract the relevant scalar and Ristretto points from the input
let H = PedersenBase::default().H;
let D = final_balance_ct.decrypt_handle.get_point();
let s = source_sk.get_scalar();
// new pedersen opening
let r_new = Scalar::random(&mut OsRng);
// generate a random masking factor that also serves as a nonce
let y = Scalar::random(&mut OsRng);
let R = RistrettoPoint::multiscalar_mul(vec![y, r_new], vec![D, H]).compress();
// record R on transcript and receive a challenge scalar
transcript.append_point(b"R", &R);
let c = transcript.challenge_scalar(b"c");
// compute the masked secret key
let z = s + c * y;
// compute the new Pedersen commitment and opening
let new_open = PedersenOpen(c * r_new);
let range_proof = RangeProof::create(
vec![final_balance],
vec![64],
vec![&new_open],
&mut transcript,
);
WithdrawProof {
R: R.into(),
z: z.into(),
range_proof: range_proof.try_into().expect("range proof"),
}
}
pub fn verify(&self, final_balance_ct: &ElGamalCT) -> Result<(), ProofError> {
let mut transcript = Self::transcript_new();
// Add a domain separator to record the start of the protocol
transcript.withdraw_proof_domain_sep();
// Extract the relevant scalar and Ristretto points from the input
let C = final_balance_ct.message_comm.get_point();
let D = final_balance_ct.decrypt_handle.get_point();
let R = self.R.into();
let z: Scalar = self.z.into();
// generate a challenge scalar
transcript.validate_and_append_point(b"R", &R)?;
let c = transcript.challenge_scalar(b"c");
// decompress R or return verification error
let R = R.decompress().ok_or(ProofError::VerificationError)?;
// compute new Pedersen commitment to verify range proof with
let new_comm = RistrettoPoint::multiscalar_mul(vec![Scalar::one(), -z, c], vec![C, D, R]);
let range_proof: RangeProof = self.range_proof.try_into()?;
range_proof.verify(vec![&new_comm.compress()], vec![64_usize], &mut transcript)
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::encryption::elgamal::ElGamal;
#[test]
#[ignore]
fn test_withdraw_correctness() {
// generate and verify proof for the proper setting
let (source_pk, source_sk) = ElGamal::keygen();
let current_balance: u64 = 77;
let current_balance_ct = source_pk.encrypt(current_balance);
let withdraw_amount: u64 = 55;
let data = WithdrawData::new(
withdraw_amount,
source_pk,
&source_sk,
current_balance,
current_balance_ct,
);
assert!(data.verify().is_ok());
// generate and verify proof with wrong balance
let wrong_balance: u64 = 99;
let data = WithdrawData::new(
withdraw_amount,
source_pk,
&source_sk,
wrong_balance,
current_balance_ct,
);
assert!(data.verify().is_err());
// TODO: test for ciphertexts that encrypt numbers outside the 0, 2^64 range
}
}

19
zk-token-sdk/src/lib.rs Normal file
View File

@ -0,0 +1,19 @@
#[cfg(not(target_arch = "bpf"))]
#[macro_use]
pub(crate) mod macros;
#[cfg(not(target_arch = "bpf"))]
pub mod encryption;
#[cfg(not(target_arch = "bpf"))]
mod errors;
#[cfg(not(target_arch = "bpf"))]
mod range_proof;
#[cfg(not(target_arch = "bpf"))]
mod transcript;
mod instruction;
pub mod pod;
pub mod zk_token_proof_instruction;
pub mod zk_token_proof_program;

101
zk-token-sdk/src/macros.rs Normal file
View File

@ -0,0 +1,101 @@
// Internal macros, used to defined the 'main' impls
#[macro_export]
macro_rules! define_add_variants {
(LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => {
impl<'b> Add<&'b $rhs> for $lhs {
type Output = $out;
fn add(self, rhs: &'b $rhs) -> $out {
&self + rhs
}
}
impl<'a> Add<$rhs> for &'a $lhs {
type Output = $out;
fn add(self, rhs: $rhs) -> $out {
self + &rhs
}
}
impl Add<$rhs> for $lhs {
type Output = $out;
fn add(self, rhs: $rhs) -> $out {
&self + &rhs
}
}
};
}
macro_rules! define_sub_variants {
(LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => {
impl<'b> Sub<&'b $rhs> for $lhs {
type Output = $out;
fn sub(self, rhs: &'b $rhs) -> $out {
&self - rhs
}
}
impl<'a> Sub<$rhs> for &'a $lhs {
type Output = $out;
fn sub(self, rhs: $rhs) -> $out {
self - &rhs
}
}
impl Sub<$rhs> for $lhs {
type Output = $out;
fn sub(self, rhs: $rhs) -> $out {
&self - &rhs
}
}
};
}
macro_rules! define_mul_variants {
(LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => {
impl<'b> Mul<&'b $rhs> for $lhs {
type Output = $out;
fn mul(self, rhs: &'b $rhs) -> $out {
&self * rhs
}
}
impl<'a> Mul<$rhs> for &'a $lhs {
type Output = $out;
fn mul(self, rhs: $rhs) -> $out {
self * &rhs
}
}
impl Mul<$rhs> for $lhs {
type Output = $out;
fn mul(self, rhs: $rhs) -> $out {
&self * &rhs
}
}
};
}
macro_rules! define_div_variants {
(LHS = $lhs:ty, RHS = $rhs:ty, Output = $out:ty) => {
impl<'b> Div<&'b $rhs> for $lhs {
type Output = $out;
fn div(self, rhs: &'b $rhs) -> $out {
&self / rhs
}
}
impl<'a> Div<$rhs> for &'a $lhs {
type Output = $out;
fn div(self, rhs: $rhs) -> $out {
self / &rhs
}
}
impl Div<$rhs> for $lhs {
type Output = $out;
fn div(self, rhs: $rhs) -> $out {
&self / &rhs
}
}
};
}

596
zk-token-sdk/src/pod.rs Normal file
View File

@ -0,0 +1,596 @@
//! Plain Old Data wrappers for types that need to be sent over the wire
use bytemuck::{Pod, Zeroable};
#[cfg(not(target_arch = "bpf"))]
use {
crate::{
encryption::elgamal::{ElGamalCT, ElGamalPK},
encryption::pedersen::{PedersenComm, PedersenDecHandle},
errors::ProofError,
range_proof::RangeProof,
},
curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_COMPRESSED, ristretto::CompressedRistretto, scalar::Scalar,
},
std::{
convert::{TryFrom, TryInto},
fmt,
},
};
#[derive(Clone, Copy, Pod, Zeroable, PartialEq)]
#[repr(transparent)]
pub struct PodScalar([u8; 32]);
#[cfg(not(target_arch = "bpf"))]
impl From<Scalar> for PodScalar {
fn from(scalar: Scalar) -> Self {
Self(scalar.to_bytes())
}
}
#[cfg(not(target_arch = "bpf"))]
impl From<PodScalar> for Scalar {
fn from(pod: PodScalar) -> Self {
Scalar::from_bits(pod.0)
}
}
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(transparent)]
pub struct PodCompressedRistretto([u8; 32]);
#[cfg(not(target_arch = "bpf"))]
impl From<CompressedRistretto> for PodCompressedRistretto {
fn from(cr: CompressedRistretto) -> Self {
Self(cr.to_bytes())
}
}
#[cfg(not(target_arch = "bpf"))]
impl From<PodCompressedRistretto> for CompressedRistretto {
fn from(pod: PodCompressedRistretto) -> Self {
Self(pod.0)
}
}
#[derive(Clone, Copy, Pod, Zeroable, PartialEq)]
#[repr(transparent)]
pub struct PodElGamalCT([u8; 64]);
#[cfg(not(target_arch = "bpf"))]
impl From<ElGamalCT> for PodElGamalCT {
fn from(ct: ElGamalCT) -> Self {
Self(ct.to_bytes())
}
}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<PodElGamalCT> for ElGamalCT {
type Error = ProofError;
fn try_from(pod: PodElGamalCT) -> Result<Self, Self::Error> {
Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData)
}
}
impl From<(PodPedersenComm, PodPedersenDecHandle)> for PodElGamalCT {
fn from((pod_comm, pod_decrypt_handle): (PodPedersenComm, PodPedersenDecHandle)) -> Self {
let mut buf = [0_u8; 64];
buf[..32].copy_from_slice(&pod_comm.0);
buf[32..].copy_from_slice(&pod_decrypt_handle.0);
PodElGamalCT(buf)
}
}
#[cfg(not(target_arch = "bpf"))]
impl fmt::Debug for PodElGamalCT {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.0)
}
}
#[derive(Clone, Copy, Pod, Zeroable, PartialEq)]
#[repr(transparent)]
pub struct PodElGamalPK([u8; 32]);
#[cfg(not(target_arch = "bpf"))]
impl From<ElGamalPK> for PodElGamalPK {
fn from(pk: ElGamalPK) -> Self {
Self(pk.to_bytes())
}
}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<PodElGamalPK> for ElGamalPK {
type Error = ProofError;
fn try_from(pod: PodElGamalPK) -> Result<Self, Self::Error> {
Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData)
}
}
#[cfg(not(target_arch = "bpf"))]
impl fmt::Debug for PodElGamalPK {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.0)
}
}
#[derive(Clone, Copy, Pod, Zeroable, PartialEq)]
#[repr(transparent)]
pub struct PodPedersenComm([u8; 32]);
#[cfg(not(target_arch = "bpf"))]
impl From<PedersenComm> for PodPedersenComm {
fn from(comm: PedersenComm) -> Self {
Self(comm.to_bytes())
}
}
// For proof verification, interpret PodPedersenComm directly as CompressedRistretto
#[cfg(not(target_arch = "bpf"))]
impl From<PodPedersenComm> for CompressedRistretto {
fn from(pod: PodPedersenComm) -> Self {
Self(pod.0)
}
}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<PodPedersenComm> for PedersenComm {
type Error = ProofError;
fn try_from(pod: PodPedersenComm) -> Result<Self, Self::Error> {
Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData)
}
}
#[cfg(not(target_arch = "bpf"))]
impl fmt::Debug for PodPedersenComm {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.0)
}
}
#[derive(Clone, Copy, Pod, Zeroable, PartialEq)]
#[repr(transparent)]
pub struct PodPedersenDecHandle([u8; 32]);
#[cfg(not(target_arch = "bpf"))]
impl From<PedersenDecHandle> for PodPedersenDecHandle {
fn from(handle: PedersenDecHandle) -> Self {
Self(handle.to_bytes())
}
}
// For proof verification, interpret PodPedersenDecHandle as CompressedRistretto
#[cfg(not(target_arch = "bpf"))]
impl From<PodPedersenDecHandle> for CompressedRistretto {
fn from(pod: PodPedersenDecHandle) -> Self {
Self(pod.0)
}
}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<PodPedersenDecHandle> for PedersenDecHandle {
type Error = ProofError;
fn try_from(pod: PodPedersenDecHandle) -> Result<Self, Self::Error> {
Self::from_bytes(&pod.0).ok_or(ProofError::InconsistentCTData)
}
}
#[cfg(not(target_arch = "bpf"))]
impl fmt::Debug for PodPedersenDecHandle {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.0)
}
}
/// Serialization of range proofs for 64-bit numbers (for `Withdraw` instruction)
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct PodRangeProof64([u8; 672]);
// `PodRangeProof64` is a Pod and Zeroable.
// Add the marker traits manually because `bytemuck` only adds them for some `u8` arrays
unsafe impl Zeroable for PodRangeProof64 {}
unsafe impl Pod for PodRangeProof64 {}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<RangeProof> for PodRangeProof64 {
type Error = ProofError;
fn try_from(proof: RangeProof) -> Result<Self, Self::Error> {
if proof.ipp_proof.serialized_size() != 448 {
return Err(ProofError::VerificationError);
}
let mut buf = [0_u8; 672];
buf[..32].copy_from_slice(proof.A.as_bytes());
buf[32..64].copy_from_slice(proof.S.as_bytes());
buf[64..96].copy_from_slice(proof.T_1.as_bytes());
buf[96..128].copy_from_slice(proof.T_2.as_bytes());
buf[128..160].copy_from_slice(proof.t_x.as_bytes());
buf[160..192].copy_from_slice(proof.t_x_blinding.as_bytes());
buf[192..224].copy_from_slice(proof.e_blinding.as_bytes());
buf[224..672].copy_from_slice(&proof.ipp_proof.to_bytes());
Ok(PodRangeProof64(buf))
}
}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<PodRangeProof64> for RangeProof {
type Error = ProofError;
fn try_from(pod: PodRangeProof64) -> Result<Self, Self::Error> {
Self::from_bytes(&pod.0)
}
}
/// Serialization of range proofs for 128-bit numbers (for `TransferRangeProof` instruction)
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct PodRangeProof128([u8; 736]);
// `PodRangeProof128` is a Pod and Zeroable.
// Add the marker traits manually because `bytemuck` only adds them for some `u8` arrays
unsafe impl Zeroable for PodRangeProof128 {}
unsafe impl Pod for PodRangeProof128 {}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<RangeProof> for PodRangeProof128 {
type Error = ProofError;
fn try_from(proof: RangeProof) -> Result<Self, Self::Error> {
if proof.ipp_proof.serialized_size() != 512 {
return Err(ProofError::VerificationError);
}
let mut buf = [0_u8; 736];
buf[..32].copy_from_slice(proof.A.as_bytes());
buf[32..64].copy_from_slice(proof.S.as_bytes());
buf[64..96].copy_from_slice(proof.T_1.as_bytes());
buf[96..128].copy_from_slice(proof.T_2.as_bytes());
buf[128..160].copy_from_slice(proof.t_x.as_bytes());
buf[160..192].copy_from_slice(proof.t_x_blinding.as_bytes());
buf[192..224].copy_from_slice(proof.e_blinding.as_bytes());
buf[224..736].copy_from_slice(&proof.ipp_proof.to_bytes());
Ok(PodRangeProof128(buf))
}
}
#[cfg(not(target_arch = "bpf"))]
impl TryFrom<PodRangeProof128> for RangeProof {
type Error = ProofError;
fn try_from(pod: PodRangeProof128) -> Result<Self, Self::Error> {
Self::from_bytes(&pod.0)
}
}
pub struct PodElGamalArithmetic;
#[cfg(not(target_arch = "bpf"))]
impl PodElGamalArithmetic {
const TWO_32: u64 = 4294967296;
// On input two scalars x0, x1 and two ciphertexts ct0, ct1,
// returns `Some(x0*ct0 + x1*ct1)` or `None` if the input was invalid
fn add_pod_ciphertexts(
scalar_0: Scalar,
pod_ct_0: PodElGamalCT,
scalar_1: Scalar,
pod_ct_1: PodElGamalCT,
) -> Option<PodElGamalCT> {
let ct_0: ElGamalCT = pod_ct_0.try_into().ok()?;
let ct_1: ElGamalCT = pod_ct_1.try_into().ok()?;
let ct_sum = ct_0 * scalar_0 + ct_1 * scalar_1;
Some(PodElGamalCT::from(ct_sum))
}
fn combine_lo_hi(pod_ct_lo: PodElGamalCT, pod_ct_hi: PodElGamalCT) -> Option<PodElGamalCT> {
Self::add_pod_ciphertexts(
Scalar::one(),
pod_ct_lo,
Scalar::from(Self::TWO_32),
pod_ct_hi,
)
}
pub fn add(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option<PodElGamalCT> {
Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, Scalar::one(), pod_ct_1)
}
pub fn add_with_lo_hi(
pod_ct_0: PodElGamalCT,
pod_ct_1_lo: PodElGamalCT,
pod_ct_1_hi: PodElGamalCT,
) -> Option<PodElGamalCT> {
let pod_ct_1 = Self::combine_lo_hi(pod_ct_1_lo, pod_ct_1_hi)?;
Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, Scalar::one(), pod_ct_1)
}
pub fn subtract(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option<PodElGamalCT> {
Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, -Scalar::one(), pod_ct_1)
}
pub fn subtract_with_lo_hi(
pod_ct_0: PodElGamalCT,
pod_ct_1_lo: PodElGamalCT,
pod_ct_1_hi: PodElGamalCT,
) -> Option<PodElGamalCT> {
let pod_ct_1 = Self::combine_lo_hi(pod_ct_1_lo, pod_ct_1_hi)?;
Self::add_pod_ciphertexts(Scalar::one(), pod_ct_0, -Scalar::one(), pod_ct_1)
}
pub fn add_to(pod_ct: PodElGamalCT, amount: u64) -> Option<PodElGamalCT> {
let mut amount_as_pod_ct = [0_u8; 64];
amount_as_pod_ct[..32].copy_from_slice(RISTRETTO_BASEPOINT_COMPRESSED.as_bytes());
Self::add_pod_ciphertexts(
Scalar::one(),
pod_ct,
Scalar::from(amount),
PodElGamalCT(amount_as_pod_ct),
)
}
pub fn subtract_from(pod_ct: PodElGamalCT, amount: u64) -> Option<PodElGamalCT> {
let mut amount_as_pod_ct = [0_u8; 64];
amount_as_pod_ct[..32].copy_from_slice(RISTRETTO_BASEPOINT_COMPRESSED.as_bytes());
Self::add_pod_ciphertexts(
Scalar::one(),
pod_ct,
-Scalar::from(amount),
PodElGamalCT(amount_as_pod_ct),
)
}
}
#[cfg(target_arch = "bpf")]
#[allow(unused_variables)]
impl PodElGamalArithmetic {
pub fn add(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option<PodElGamalCT> {
None
}
pub fn add_with_lo_hi(
pod_ct_0: PodElGamalCT,
pod_ct_1_lo: PodElGamalCT,
pod_ct_1_hi: PodElGamalCT,
) -> Option<PodElGamalCT> {
None
}
pub fn subtract(pod_ct_0: PodElGamalCT, pod_ct_1: PodElGamalCT) -> Option<PodElGamalCT> {
None
}
pub fn subtract_with_lo_hi(
pod_ct_0: PodElGamalCT,
pod_ct_1_lo: PodElGamalCT,
pod_ct_1_hi: PodElGamalCT,
) -> Option<PodElGamalCT> {
None
}
pub fn add_to(pod_ct: PodElGamalCT, amount: u64) -> Option<PodElGamalCT> {
None
}
pub fn subtract_from(pod_ct: PodElGamalCT, amount: u64) -> Option<PodElGamalCT> {
None
}
}
#[cfg(test)]
mod tests {
use {
super::*,
crate::encryption::{
elgamal::{ElGamal, ElGamalCT},
pedersen::{Pedersen, PedersenOpen},
},
crate::instruction::transfer::split_u64_into_u32,
merlin::Transcript,
rand::rngs::OsRng,
std::convert::TryInto,
};
#[test]
fn test_zero_ct() {
let spendable_balance = PodElGamalCT::zeroed();
let spendable_ct: ElGamalCT = spendable_balance.try_into().unwrap();
// spendable_ct should be an encryption of 0 for any public key when
// `PedersenOpen::default()` is used
let (pk, _) = ElGamal::keygen();
let balance: u64 = 0;
assert_eq!(
spendable_ct,
pk.encrypt_with(balance, &PedersenOpen::default())
);
// homomorphism should work like any other ciphertext
let open = PedersenOpen::random(&mut OsRng);
let transfer_amount_ct = pk.encrypt_with(55_u64, &open);
let transfer_amount_pod: PodElGamalCT = transfer_amount_ct.into();
let sum = PodElGamalArithmetic::add(spendable_balance, transfer_amount_pod).unwrap();
let expected: PodElGamalCT = pk.encrypt_with(55_u64, &open).into();
assert_eq!(expected, sum);
}
#[test]
fn test_add_to() {
let spendable_balance = PodElGamalCT::zeroed();
let added_ct = PodElGamalArithmetic::add_to(spendable_balance, 55).unwrap();
let (pk, _) = ElGamal::keygen();
let expected: PodElGamalCT = pk.encrypt_with(55_u64, &PedersenOpen::default()).into();
assert_eq!(expected, added_ct);
}
#[test]
fn test_subtract_from() {
let amount = 77_u64;
let (pk, _) = ElGamal::keygen();
let open = PedersenOpen::random(&mut OsRng);
let encrypted_amount: PodElGamalCT = pk.encrypt_with(amount, &open).into();
let subtracted_ct = PodElGamalArithmetic::subtract_from(encrypted_amount, 55).unwrap();
let expected: PodElGamalCT = pk.encrypt_with(22_u64, &open).into();
assert_eq!(expected, subtracted_ct);
}
#[test]
fn test_pod_range_proof_64() {
let (comm, open) = Pedersen::commit(55_u64);
let mut transcript_create = Transcript::new(b"Test");
let mut transcript_verify = Transcript::new(b"Test");
let proof = RangeProof::create(vec![55], vec![64], vec![&open], &mut transcript_create);
let proof_serialized: PodRangeProof64 = proof.try_into().unwrap();
let proof_deserialized: RangeProof = proof_serialized.try_into().unwrap();
assert!(proof_deserialized
.verify(
vec![&comm.get_point().compress()],
vec![64],
&mut transcript_verify
)
.is_ok());
// should fail to serialize to PodRangeProof128
let proof = RangeProof::create(vec![55], vec![64], vec![&open], &mut transcript_create);
assert!(TryInto::<PodRangeProof128>::try_into(proof).is_err());
}
#[test]
fn test_pod_range_proof_128() {
let (comm_1, open_1) = Pedersen::commit(55_u64);
let (comm_2, open_2) = Pedersen::commit(77_u64);
let (comm_3, open_3) = Pedersen::commit(99_u64);
let mut transcript_create = Transcript::new(b"Test");
let mut transcript_verify = Transcript::new(b"Test");
let proof = RangeProof::create(
vec![55, 77, 99],
vec![64, 32, 32],
vec![&open_1, &open_2, &open_3],
&mut transcript_create,
);
let comm_1_point = comm_1.get_point().compress();
let comm_2_point = comm_2.get_point().compress();
let comm_3_point = comm_3.get_point().compress();
let proof_serialized: PodRangeProof128 = proof.try_into().unwrap();
let proof_deserialized: RangeProof = proof_serialized.try_into().unwrap();
assert!(proof_deserialized
.verify(
vec![&comm_1_point, &comm_2_point, &comm_3_point],
vec![64, 32, 32],
&mut transcript_verify,
)
.is_ok());
// should fail to serialize to PodRangeProof64
let proof = RangeProof::create(
vec![55, 77, 99],
vec![64, 32, 32],
vec![&open_1, &open_2, &open_3],
&mut transcript_create,
);
assert!(TryInto::<PodRangeProof64>::try_into(proof).is_err());
}
#[test]
fn test_transfer_arithmetic() {
// setup
// transfer amount
let transfer_amount: u64 = 55;
let (amount_lo, amount_hi) = split_u64_into_u32(transfer_amount);
// generate public keys
let (source_pk, _) = ElGamal::keygen();
let (dest_pk, _) = ElGamal::keygen();
let (auditor_pk, _) = ElGamal::keygen();
// commitments associated with TransferRangeProof
let (comm_lo, open_lo) = Pedersen::commit(amount_lo);
let (comm_hi, open_hi) = Pedersen::commit(amount_hi);
let comm_lo: PodPedersenComm = comm_lo.into();
let comm_hi: PodPedersenComm = comm_hi.into();
// decryption handles associated with TransferValidityProof
let handle_source_lo: PodPedersenDecHandle = source_pk.gen_decrypt_handle(&open_lo).into();
let handle_dest_lo: PodPedersenDecHandle = dest_pk.gen_decrypt_handle(&open_lo).into();
let _handle_auditor_lo: PodPedersenDecHandle =
auditor_pk.gen_decrypt_handle(&open_lo).into();
let handle_source_hi: PodPedersenDecHandle = source_pk.gen_decrypt_handle(&open_hi).into();
let handle_dest_hi: PodPedersenDecHandle = dest_pk.gen_decrypt_handle(&open_hi).into();
let _handle_auditor_hi: PodPedersenDecHandle =
auditor_pk.gen_decrypt_handle(&open_hi).into();
// source spendable and recipient pending
let source_open = PedersenOpen::random(&mut OsRng);
let dest_open = PedersenOpen::random(&mut OsRng);
let source_spendable_ct: PodElGamalCT = source_pk.encrypt_with(77_u64, &source_open).into();
let dest_pending_ct: PodElGamalCT = dest_pk.encrypt_with(77_u64, &dest_open).into();
// program arithmetic for the source account
// 1. Combine commitments and handles
let source_lo_ct: PodElGamalCT = (comm_lo, handle_source_lo).into();
let source_hi_ct: PodElGamalCT = (comm_hi, handle_source_hi).into();
// 2. Combine lo and hi ciphertexts
let source_combined_ct =
PodElGamalArithmetic::combine_lo_hi(source_lo_ct, source_hi_ct).unwrap();
// 3. Subtract from available balance
let final_source_spendable =
PodElGamalArithmetic::subtract(source_spendable_ct, source_combined_ct).unwrap();
// test
let final_source_open = source_open
- (open_lo.clone() + open_hi.clone() * Scalar::from(PodElGamalArithmetic::TWO_32));
let expected_source: PodElGamalCT =
source_pk.encrypt_with(22_u64, &final_source_open).into();
assert_eq!(expected_source, final_source_spendable);
// same for the destination account
// 1. Combine commitments and handles
let dest_lo_ct: PodElGamalCT = (comm_lo, handle_dest_lo).into();
let dest_hi_ct: PodElGamalCT = (comm_hi, handle_dest_hi).into();
// 2. Combine lo and hi ciphertexts
let dest_combined_ct = PodElGamalArithmetic::combine_lo_hi(dest_lo_ct, dest_hi_ct).unwrap();
// 3. Add to pending balance
let final_dest_pending =
PodElGamalArithmetic::add(dest_pending_ct, dest_combined_ct).unwrap();
let final_dest_open =
dest_open + (open_lo + open_hi * Scalar::from(PodElGamalArithmetic::TWO_32));
let expected_dest_ct: PodElGamalCT = dest_pk.encrypt_with(132_u64, &final_dest_open).into();
assert_eq!(expected_dest_ct, final_dest_pending);
}
}

View File

@ -0,0 +1,156 @@
use {
curve25519_dalek::{
digest::{ExtendableOutput, Update, XofReader},
ristretto::RistrettoPoint,
},
sha3::{Sha3XofReader, Shake256},
};
/// Generators for Pedersen vector commitments.
///
/// The code is copied from https://github.com/dalek-cryptography/bulletproofs for now...
struct GeneratorsChain {
reader: Sha3XofReader,
}
impl GeneratorsChain {
/// Creates a chain of generators, determined by the hash of `label`.
fn new(label: &[u8]) -> Self {
let mut shake = Shake256::default();
shake.update(b"GeneratorsChain");
shake.update(label);
GeneratorsChain {
reader: shake.finalize_xof(),
}
}
/// Advances the reader n times, squeezing and discarding
/// the result.
fn fast_forward(mut self, n: usize) -> Self {
for _ in 0..n {
let mut buf = [0u8; 64];
self.reader.read(&mut buf);
}
self
}
}
impl Default for GeneratorsChain {
fn default() -> Self {
Self::new(&[])
}
}
impl Iterator for GeneratorsChain {
type Item = RistrettoPoint;
fn next(&mut self) -> Option<Self::Item> {
let mut uniform_bytes = [0u8; 64];
self.reader.read(&mut uniform_bytes);
Some(RistrettoPoint::from_uniform_bytes(&uniform_bytes))
}
fn size_hint(&self) -> (usize, Option<usize>) {
(usize::max_value(), None)
}
}
#[allow(non_snake_case)]
#[derive(Clone)]
pub struct BulletproofGens {
/// The maximum number of usable generators.
pub gens_capacity: usize,
/// Precomputed \\(\mathbf G\\) generators.
G_vec: Vec<RistrettoPoint>,
/// Precomputed \\(\mathbf H\\) generators.
H_vec: Vec<RistrettoPoint>,
}
impl BulletproofGens {
pub fn new(gens_capacity: usize) -> Self {
let mut gens = BulletproofGens {
gens_capacity: 0,
G_vec: Vec::new(),
H_vec: Vec::new(),
};
gens.increase_capacity(gens_capacity);
gens
}
// pub fn new_aggregate(gens_capacities: Vec<usize>) -> Vec<BulletproofGens> {
// let mut gens_vector = Vec::new();
// for (capacity, i) in gens_capacities.iter().enumerate() {
// gens_vector.push(BulletproofGens::new(capacity, &i.to_le_bytes()));
// }
// gens_vector
// }
/// Increases the generators' capacity to the amount specified.
/// If less than or equal to the current capacity, does nothing.
pub fn increase_capacity(&mut self, new_capacity: usize) {
if self.gens_capacity >= new_capacity {
return;
}
let label = [b'G'];
self.G_vec.extend(
&mut GeneratorsChain::new(&[label, [b'G']].concat())
.fast_forward(self.gens_capacity)
.take(new_capacity - self.gens_capacity),
);
self.H_vec.extend(
&mut GeneratorsChain::new(&[label, [b'H']].concat())
.fast_forward(self.gens_capacity)
.take(new_capacity - self.gens_capacity),
);
self.gens_capacity = new_capacity;
}
#[allow(non_snake_case)]
pub(crate) fn G(&self, n: usize) -> impl Iterator<Item = &RistrettoPoint> {
GensIter {
array: &self.G_vec,
n,
gen_idx: 0,
}
}
#[allow(non_snake_case)]
pub(crate) fn H(&self, n: usize) -> impl Iterator<Item = &RistrettoPoint> {
GensIter {
array: &self.H_vec,
n,
gen_idx: 0,
}
}
}
struct GensIter<'a> {
array: &'a Vec<RistrettoPoint>,
n: usize,
gen_idx: usize,
}
impl<'a> Iterator for GensIter<'a> {
type Item = &'a RistrettoPoint;
fn next(&mut self) -> Option<Self::Item> {
if self.gen_idx >= self.n {
None
} else {
let cur_gen = self.gen_idx;
self.gen_idx += 1;
Some(&self.array[cur_gen])
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let size = self.n - self.gen_idx;
(size, Some(size))
}
}

View File

@ -0,0 +1,505 @@
use core::iter;
use std::borrow::Borrow;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::VartimeMultiscalarMul;
use crate::errors::ProofError;
use crate::range_proof::util;
use crate::transcript::TranscriptProtocol;
use merlin::Transcript;
#[allow(non_snake_case)]
#[derive(Clone)]
pub struct InnerProductProof {
pub L_vec: Vec<CompressedRistretto>, // 32 * log(bit_length)
pub R_vec: Vec<CompressedRistretto>, // 32 * log(bit_length)
pub a: Scalar, // 32 bytes
pub b: Scalar, // 32 bytes
}
#[allow(non_snake_case)]
impl InnerProductProof {
/// Create an inner-product proof.
///
/// The proof is created with respect to the bases \\(G\\), \\(H'\\),
/// where \\(H'\_i = H\_i \cdot \texttt{Hprime\\_factors}\_i\\).
///
/// The `verifier` is passed in as a parameter so that the
/// challenges depend on the *entire* transcript (including parent
/// protocols).
///
/// The lengths of the vectors must all be the same, and must all be
/// either 0 or a power of 2.
#[allow(clippy::too_many_arguments)]
pub fn create(
Q: &RistrettoPoint,
G_factors: &[Scalar],
H_factors: &[Scalar],
mut G_vec: Vec<RistrettoPoint>,
mut H_vec: Vec<RistrettoPoint>,
mut a_vec: Vec<Scalar>,
mut b_vec: Vec<Scalar>,
transcript: &mut Transcript,
) -> InnerProductProof {
// Create slices G, H, a, b backed by their respective
// vectors. This lets us reslice as we compress the lengths
// of the vectors in the main loop below.
let mut G = &mut G_vec[..];
let mut H = &mut H_vec[..];
let mut a = &mut a_vec[..];
let mut b = &mut b_vec[..];
let mut n = G.len();
// All of the input vectors must have the same length.
assert_eq!(G.len(), n);
assert_eq!(H.len(), n);
assert_eq!(a.len(), n);
assert_eq!(b.len(), n);
assert_eq!(G_factors.len(), n);
assert_eq!(H_factors.len(), n);
// All of the input vectors must have a length that is a power of two.
assert!(n.is_power_of_two());
transcript.innerproduct_domain_sep(n as u64);
let lg_n = n.next_power_of_two().trailing_zeros() as usize;
let mut L_vec = Vec::with_capacity(lg_n);
let mut R_vec = Vec::with_capacity(lg_n);
// If it's the first iteration, unroll the Hprime = H*y_inv scalar mults
// into multiscalar muls, for performance.
if n != 1 {
n /= 2;
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);
let c_L = util::inner_product(a_L, b_R);
let c_R = util::inner_product(a_R, b_L);
let L = RistrettoPoint::vartime_multiscalar_mul(
a_L.iter()
.zip(G_factors[n..2 * n].iter())
.map(|(a_L_i, g)| a_L_i * g)
.chain(
b_R.iter()
.zip(H_factors[0..n].iter())
.map(|(b_R_i, h)| b_R_i * h),
)
.chain(iter::once(c_L)),
G_R.iter().chain(H_L.iter()).chain(iter::once(Q)),
)
.compress();
let R = RistrettoPoint::vartime_multiscalar_mul(
a_R.iter()
.zip(G_factors[0..n].iter())
.map(|(a_R_i, g)| a_R_i * g)
.chain(
b_L.iter()
.zip(H_factors[n..2 * n].iter())
.map(|(b_L_i, h)| b_L_i * h),
)
.chain(iter::once(c_R)),
G_L.iter().chain(H_R.iter()).chain(iter::once(Q)),
)
.compress();
L_vec.push(L);
R_vec.push(R);
transcript.append_point(b"L", &L);
transcript.append_point(b"R", &R);
let u = transcript.challenge_scalar(b"u");
let u_inv = u.invert();
for i in 0..n {
a_L[i] = a_L[i] * u + u_inv * a_R[i];
b_L[i] = b_L[i] * u_inv + u * b_R[i];
G_L[i] = RistrettoPoint::vartime_multiscalar_mul(
&[u_inv * G_factors[i], u * G_factors[n + i]],
&[G_L[i], G_R[i]],
);
H_L[i] = RistrettoPoint::vartime_multiscalar_mul(
&[u * H_factors[i], u_inv * H_factors[n + i]],
&[H_L[i], H_R[i]],
)
}
a = a_L;
b = b_L;
G = G_L;
H = H_L;
}
while n != 1 {
n /= 2;
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);
let c_L = util::inner_product(a_L, b_R);
let c_R = util::inner_product(a_R, b_L);
let L = RistrettoPoint::vartime_multiscalar_mul(
a_L.iter().chain(b_R.iter()).chain(iter::once(&c_L)),
G_R.iter().chain(H_L.iter()).chain(iter::once(Q)),
)
.compress();
let R = RistrettoPoint::vartime_multiscalar_mul(
a_R.iter().chain(b_L.iter()).chain(iter::once(&c_R)),
G_L.iter().chain(H_R.iter()).chain(iter::once(Q)),
)
.compress();
L_vec.push(L);
R_vec.push(R);
transcript.append_point(b"L", &L);
transcript.append_point(b"R", &R);
let u = transcript.challenge_scalar(b"u");
let u_inv = u.invert();
for i in 0..n {
a_L[i] = a_L[i] * u + u_inv * a_R[i];
b_L[i] = b_L[i] * u_inv + u * b_R[i];
G_L[i] = RistrettoPoint::vartime_multiscalar_mul(&[u_inv, u], &[G_L[i], G_R[i]]);
H_L[i] = RistrettoPoint::vartime_multiscalar_mul(&[u, u_inv], &[H_L[i], H_R[i]]);
}
a = a_L;
b = b_L;
G = G_L;
H = H_L;
}
InnerProductProof {
L_vec,
R_vec,
a: a[0],
b: b[0],
}
}
/// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and
/// \\([s\_{i}]\\) for combined multiscalar multiplication in a parent protocol. See [inner
/// product protocol notes](index.html#verification-equation) for details. The verifier must
/// provide the input length \\(n\\) explicitly to avoid unbounded allocation within the inner
/// product proof.
#[allow(clippy::type_complexity)]
pub(crate) fn verification_scalars(
&self,
n: usize,
transcript: &mut Transcript,
) -> Result<(Vec<Scalar>, Vec<Scalar>, Vec<Scalar>), ProofError> {
let lg_n = self.L_vec.len();
if lg_n >= 32 {
// 4 billion multiplications should be enough for anyone
// and this check prevents overflow in 1<<lg_n below.
return Err(ProofError::VerificationError);
}
if n != (1 << lg_n) {
return Err(ProofError::VerificationError);
}
transcript.innerproduct_domain_sep(n as u64);
// 1. Recompute x_k,...,x_1 based on the proof transcript
let mut challenges = Vec::with_capacity(lg_n);
for (L, R) in self.L_vec.iter().zip(self.R_vec.iter()) {
transcript.validate_and_append_point(b"L", L)?;
transcript.validate_and_append_point(b"R", R)?;
challenges.push(transcript.challenge_scalar(b"u"));
}
// 2. Compute 1/(u_k...u_1) and 1/u_k, ..., 1/u_1
let mut challenges_inv = challenges.clone();
let allinv = Scalar::batch_invert(&mut challenges_inv);
// 3. Compute u_i^2 and (1/u_i)^2
for i in 0..lg_n {
challenges[i] = challenges[i] * challenges[i];
challenges_inv[i] = challenges_inv[i] * challenges_inv[i];
}
let challenges_sq = challenges;
let challenges_inv_sq = challenges_inv;
// 4. Compute s values inductively.
let mut s = Vec::with_capacity(n);
s.push(allinv);
for i in 1..n {
let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize;
let k = 1 << lg_i;
// The challenges are stored in "creation order" as [u_k,...,u_1],
// so u_{lg(i)+1} = is indexed by (lg_n-1) - lg_i
let u_lg_i_sq = challenges_sq[(lg_n - 1) - lg_i];
s.push(s[i - k] * u_lg_i_sq);
}
Ok((challenges_sq, challenges_inv_sq, s))
}
/// This method is for testing that proof generation work, but for efficiency the actual
/// protocols would use `verification_scalars` method to combine inner product verification
/// with other checks in a single multiscalar multiplication.
#[allow(clippy::too_many_arguments)]
pub fn verify<IG, IH>(
&self,
n: usize,
G_factors: IG,
H_factors: IH,
P: &RistrettoPoint,
Q: &RistrettoPoint,
G: &[RistrettoPoint],
H: &[RistrettoPoint],
transcript: &mut Transcript,
) -> Result<(), ProofError>
where
IG: IntoIterator,
IG::Item: Borrow<Scalar>,
IH: IntoIterator,
IH::Item: Borrow<Scalar>,
{
let (u_sq, u_inv_sq, s) = self.verification_scalars(n, transcript)?;
let g_times_a_times_s = G_factors
.into_iter()
.zip(s.iter())
.map(|(g_i, s_i)| (self.a * s_i) * g_i.borrow())
.take(G.len());
// 1/s[i] is s[!i], and !i runs from n-1 to 0 as i runs from 0 to n-1
let inv_s = s.iter().rev();
let h_times_b_div_s = H_factors
.into_iter()
.zip(inv_s)
.map(|(h_i, s_i_inv)| (self.b * s_i_inv) * h_i.borrow());
let neg_u_sq = u_sq.iter().map(|ui| -ui);
let neg_u_inv_sq = u_inv_sq.iter().map(|ui| -ui);
let Ls = self
.L_vec
.iter()
.map(|p| p.decompress().ok_or(ProofError::VerificationError))
.collect::<Result<Vec<_>, _>>()?;
let Rs = self
.R_vec
.iter()
.map(|p| p.decompress().ok_or(ProofError::VerificationError))
.collect::<Result<Vec<_>, _>>()?;
let expect_P = RistrettoPoint::vartime_multiscalar_mul(
iter::once(self.a * self.b)
.chain(g_times_a_times_s)
.chain(h_times_b_div_s)
.chain(neg_u_sq)
.chain(neg_u_inv_sq),
iter::once(Q)
.chain(G.iter())
.chain(H.iter())
.chain(Ls.iter())
.chain(Rs.iter()),
);
if expect_P == *P {
Ok(())
} else {
Err(ProofError::VerificationError)
}
}
/// Returns the size in bytes required to serialize the inner
/// product proof.
///
/// For vectors of length `n` the proof size is
/// \\(32 \cdot (2\lg n+2)\\) bytes.
pub fn serialized_size(&self) -> usize {
(self.L_vec.len() * 2 + 2) * 32
}
/// Serializes the proof into a byte array of \\(2n+2\\) 32-byte elements.
/// The layout of the inner product proof is:
/// * \\(n\\) pairs of compressed Ristretto points \\(L_0, R_0 \dots, L_{n-1}, R_{n-1}\\),
/// * two scalars \\(a, b\\).
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(self.serialized_size());
for (l, r) in self.L_vec.iter().zip(self.R_vec.iter()) {
buf.extend_from_slice(l.as_bytes());
buf.extend_from_slice(r.as_bytes());
}
buf.extend_from_slice(self.a.as_bytes());
buf.extend_from_slice(self.b.as_bytes());
buf
}
// pub fn to_bytes_64(&self) -> Result<InnerProductProof64, ProofError> {
// let mut bytes = [0u8; 448];
// self.L_vec.iter().chain(self.R_vec.iter()).enumerate().for_each(
// |(i, x)| bytes[i*32..(i+1)*32].copy_from_slice(x.as_bytes())
// );
// bytes[384..416].copy_from_slice(self.a.as_bytes());
// bytes[416..448].copy_from_slice(self.a.as_bytes());
// Ok(InnerProductProof64(bytes))
// }
/*
/// Converts the proof into a byte iterator over serialized view of the proof.
/// The layout of the inner product proof is:
/// * \\(n\\) pairs of compressed Ristretto points \\(L_0, R_0 \dots, L_{n-1}, R_{n-1}\\),
/// * two scalars \\(a, b\\).
#[inline]
pub(crate) fn to_bytes_iter(&self) -> impl Iterator<Item = u8> + '_ {
self.L_vec
.iter()
.zip(self.R_vec.iter())
.flat_map(|(l, r)| l.as_bytes().iter().chain(r.as_bytes()))
.chain(self.a.as_bytes())
.chain(self.b.as_bytes())
.copied()
}
*/
/// Deserializes the proof from a byte slice.
/// Returns an error in the following cases:
/// * the slice does not have \\(2n+2\\) 32-byte elements,
/// * \\(n\\) is larger or equal to 32 (proof is too big),
/// * any of \\(2n\\) points are not valid compressed Ristretto points,
/// * any of 2 scalars are not canonical scalars modulo Ristretto group order.
pub fn from_bytes(slice: &[u8]) -> Result<InnerProductProof, ProofError> {
let b = slice.len();
if b % 32 != 0 {
return Err(ProofError::FormatError);
}
let num_elements = b / 32;
if num_elements < 2 {
return Err(ProofError::FormatError);
}
if (num_elements - 2) % 2 != 0 {
return Err(ProofError::FormatError);
}
let lg_n = (num_elements - 2) / 2;
if lg_n >= 32 {
return Err(ProofError::FormatError);
}
let mut L_vec: Vec<CompressedRistretto> = Vec::with_capacity(lg_n);
let mut R_vec: Vec<CompressedRistretto> = Vec::with_capacity(lg_n);
for i in 0..lg_n {
let pos = 2 * i * 32;
L_vec.push(CompressedRistretto(util::read32(&slice[pos..])));
R_vec.push(CompressedRistretto(util::read32(&slice[pos + 32..])));
}
let pos = 2 * lg_n * 32;
let a = Scalar::from_canonical_bytes(util::read32(&slice[pos..]))
.ok_or(ProofError::FormatError)?;
let b = Scalar::from_canonical_bytes(util::read32(&slice[pos + 32..]))
.ok_or(ProofError::FormatError)?;
Ok(InnerProductProof { L_vec, R_vec, a, b })
}
}
#[cfg(test)]
mod tests {
use {
super::*, crate::range_proof::generators::BulletproofGens, rand::rngs::OsRng,
sha3::Sha3_512,
};
#[test]
#[ignore]
#[allow(non_snake_case)]
fn test_basic_correctness() {
let n = 32;
let bp_gens = BulletproofGens::new(n);
let G: Vec<RistrettoPoint> = bp_gens.G(n).cloned().collect();
let H: Vec<RistrettoPoint> = bp_gens.H(n).cloned().collect();
let Q = RistrettoPoint::hash_from_bytes::<Sha3_512>(b"test point");
let a: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect();
let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect();
let c = util::inner_product(&a, &b);
let G_factors: Vec<Scalar> = iter::repeat(Scalar::one()).take(n).collect();
let y_inv = Scalar::random(&mut OsRng);
let H_factors: Vec<Scalar> = util::exp_iter(y_inv).take(n).collect();
// P would be determined upstream, but we need a correct P to check the proof.
//
// To generate P = <a,G> + <b,H'> + <a,b> Q, compute
// P = <a,G> + <b',H> + <a,b> Q,
// where b' = b \circ y^(-n)
let b_prime = b.iter().zip(util::exp_iter(y_inv)).map(|(bi, yi)| bi * yi);
// a.iter() has Item=&Scalar, need Item=Scalar to chain with b_prime
let a_prime = a.iter().cloned();
let P = RistrettoPoint::vartime_multiscalar_mul(
a_prime.chain(b_prime).chain(iter::once(c)),
G.iter().chain(H.iter()).chain(iter::once(&Q)),
);
let mut prover_transcript = Transcript::new(b"innerproducttest");
let mut verifier_transcript = Transcript::new(b"innerproducttest");
let proof = InnerProductProof::create(
&Q,
&G_factors,
&H_factors,
G.clone(),
H.clone(),
a.clone(),
b.clone(),
&mut prover_transcript,
);
assert!(proof
.verify(
n,
iter::repeat(Scalar::one()).take(n),
util::exp_iter(y_inv).take(n),
&P,
&Q,
&G,
&H,
&mut verifier_transcript,
)
.is_ok());
let proof = InnerProductProof::from_bytes(proof.to_bytes().as_slice()).unwrap();
let mut verifier_transcript = Transcript::new(b"innerproducttest");
assert!(proof
.verify(
n,
iter::repeat(Scalar::one()).take(n),
util::exp_iter(y_inv).take(n),
&P,
&Q,
&G,
&H,
&mut verifier_transcript,
)
.is_ok());
}
}

View File

@ -0,0 +1,456 @@
#[cfg(not(target_arch = "bpf"))]
use {
crate::encryption::pedersen::{Pedersen, PedersenOpen},
curve25519_dalek::traits::MultiscalarMul,
rand::rngs::OsRng,
subtle::{Choice, ConditionallySelectable},
};
use {
crate::{
encryption::pedersen::PedersenBase, errors::ProofError,
range_proof::generators::BulletproofGens, range_proof::inner_product::InnerProductProof,
transcript::TranscriptProtocol,
},
core::iter,
curve25519_dalek::{
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
traits::{IsIdentity, VartimeMultiscalarMul},
},
merlin::Transcript,
};
pub mod generators;
pub mod inner_product;
pub mod util;
#[allow(non_snake_case)]
#[derive(Clone)]
pub struct RangeProof {
pub A: CompressedRistretto, // 32 bytes
pub S: CompressedRistretto, // 32 bytes
pub T_1: CompressedRistretto, // 32 bytes
pub T_2: CompressedRistretto, // 32 bytes
pub t_x: Scalar, // 32 bytes
pub t_x_blinding: Scalar, // 32 bytes
pub e_blinding: Scalar, // 32 bytes
pub ipp_proof: InnerProductProof, // 448 bytes for withdraw; 512 for transfer
}
#[allow(non_snake_case)]
impl RangeProof {
#[allow(clippy::many_single_char_names)]
#[cfg(not(target_arch = "bpf"))]
pub fn create(
amounts: Vec<u64>,
bit_lengths: Vec<usize>,
opens: Vec<&PedersenOpen>,
transcript: &mut Transcript,
) -> Self {
let t_1_blinding = PedersenOpen::random(&mut OsRng);
let t_2_blinding = PedersenOpen::random(&mut OsRng);
let (range_proof, _, _) = Self::create_with(
amounts,
bit_lengths,
opens,
&t_1_blinding,
&t_2_blinding,
transcript,
);
range_proof
}
#[allow(clippy::many_single_char_names)]
#[cfg(not(target_arch = "bpf"))]
pub fn create_with(
amounts: Vec<u64>,
bit_lengths: Vec<usize>,
opens: Vec<&PedersenOpen>,
t_1_blinding: &PedersenOpen,
t_2_blinding: &PedersenOpen,
transcript: &mut Transcript,
) -> (Self, Scalar, Scalar) {
let nm = bit_lengths.iter().sum();
// Computing the generators online for now. It should ultimately be precomputed.
let bp_gens = BulletproofGens::new(nm);
let G = PedersenBase::default().G;
let H = PedersenBase::default().H;
// bit-decompose values and commit to the bits
let a_blinding = Scalar::random(&mut OsRng);
let mut A = a_blinding * H;
let mut gens_iter = bp_gens.G(nm).zip(bp_gens.H(nm));
for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) {
for j in 0..(*n_i) {
let (G_ij, H_ij) = gens_iter.next().unwrap();
let v_ij = Choice::from(((amount_i >> j) & 1) as u8);
let mut point = -H_ij;
point.conditional_assign(G_ij, v_ij);
A += point;
}
}
// generate blinding factors and commit as vectors
let s_blinding = Scalar::random(&mut OsRng);
let s_L: Vec<Scalar> = (0..nm).map(|_| Scalar::random(&mut OsRng)).collect();
let s_R: Vec<Scalar> = (0..nm).map(|_| Scalar::random(&mut OsRng)).collect();
let S = RistrettoPoint::multiscalar_mul(
iter::once(&s_blinding).chain(s_L.iter()).chain(s_R.iter()),
iter::once(&H).chain(bp_gens.G(nm)).chain(bp_gens.H(nm)),
);
transcript.append_point(b"A", &A.compress());
transcript.append_point(b"S", &S.compress());
// commit to T1 and T2
let y = transcript.challenge_scalar(b"y");
let z = transcript.challenge_scalar(b"z");
let mut l_poly = util::VecPoly1::zero(nm);
let mut r_poly = util::VecPoly1::zero(nm);
let mut i = 0;
let mut exp_z = z * z;
let mut exp_y = Scalar::one();
for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) {
let mut exp_2 = Scalar::one();
for j in 0..(*n_i) {
let a_L_j = Scalar::from((amount_i >> j) & 1);
let a_R_j = a_L_j - Scalar::one();
l_poly.0[i] = a_L_j - z;
l_poly.1[i] = s_L[i];
r_poly.0[i] = exp_y * (a_R_j + z) + exp_z * exp_2;
r_poly.1[i] = exp_y * s_R[i];
exp_y *= y;
exp_2 = exp_2 + exp_2;
i += 1;
}
exp_z *= z;
}
let t_poly = l_poly.inner_product(&r_poly);
let T_1 = Pedersen::commit_with(t_poly.1, t_1_blinding)
.get_point()
.compress();
let T_2 = Pedersen::commit_with(t_poly.2, t_2_blinding)
.get_point()
.compress();
transcript.append_point(b"T_1", &T_1);
transcript.append_point(b"T_2", &T_2);
let x = transcript.challenge_scalar(b"x");
let mut agg_open = Scalar::zero();
let mut exp_z = z * z;
for open in opens {
agg_open += exp_z * open.get_scalar();
exp_z *= z;
}
let t_blinding_poly = util::Poly2(
agg_open,
t_1_blinding.get_scalar(),
t_2_blinding.get_scalar(),
);
// compute t_x
let t_x = t_poly.eval(x);
let t_x_blinding = t_blinding_poly.eval(x);
let e_blinding = a_blinding + s_blinding * x;
let l_vec = l_poly.eval(x);
let r_vec = r_poly.eval(x);
transcript.append_scalar(b"t_x", &t_x);
transcript.append_scalar(b"t_x_blinding", &t_x_blinding);
transcript.append_scalar(b"e_blinding", &e_blinding);
let w = transcript.challenge_scalar(b"w");
let Q = w * G;
transcript.challenge_scalar(b"c");
let G_factors: Vec<Scalar> = iter::repeat(Scalar::one()).take(nm).collect();
let H_factors: Vec<Scalar> = util::exp_iter(y.invert()).take(nm).collect();
let ipp_proof = InnerProductProof::create(
&Q,
&G_factors,
&H_factors,
bp_gens.G(nm).cloned().collect(),
bp_gens.H(nm).cloned().collect(),
l_vec,
r_vec,
transcript,
);
let range_proof = RangeProof {
A: A.compress(),
S: S.compress(),
T_1,
T_2,
t_x,
t_x_blinding,
e_blinding,
ipp_proof,
};
(range_proof, x, z)
}
#[allow(clippy::many_single_char_names)]
pub fn verify(
&self,
comms: Vec<&CompressedRistretto>,
bit_lengths: Vec<usize>,
transcript: &mut Transcript,
) -> Result<(), ProofError> {
self.verify_with(comms, bit_lengths, None, None, transcript)
}
#[allow(clippy::many_single_char_names)]
pub fn verify_with(
&self,
comms: Vec<&CompressedRistretto>,
bit_lengths: Vec<usize>,
x_ver: Option<Scalar>,
z_ver: Option<Scalar>,
transcript: &mut Transcript,
) -> Result<(), ProofError> {
let G = PedersenBase::default().G;
let H = PedersenBase::default().H;
let m = bit_lengths.len();
let nm: usize = bit_lengths.iter().sum();
let bp_gens = BulletproofGens::new(nm);
if !(nm == 8 || nm == 16 || nm == 32 || nm == 64 || nm == 128) {
return Err(ProofError::InvalidBitsize);
}
transcript.validate_and_append_point(b"A", &self.A)?;
transcript.validate_and_append_point(b"S", &self.S)?;
let y = transcript.challenge_scalar(b"y");
let z = transcript.challenge_scalar(b"z");
if z_ver.is_some() && z_ver.unwrap() != z {
return Err(ProofError::VerificationError);
}
let zz = z * z;
let minus_z = -z;
transcript.validate_and_append_point(b"T_1", &self.T_1)?;
transcript.validate_and_append_point(b"T_2", &self.T_2)?;
let x = transcript.challenge_scalar(b"x");
if x_ver.is_some() && x_ver.unwrap() != x {
return Err(ProofError::VerificationError);
}
transcript.append_scalar(b"t_x", &self.t_x);
transcript.append_scalar(b"t_x_blinding", &self.t_x_blinding);
transcript.append_scalar(b"e_blinding", &self.e_blinding);
let w = transcript.challenge_scalar(b"w");
// Challenge value for batching statements to be verified
let c = transcript.challenge_scalar(b"c");
let (x_sq, x_inv_sq, s) = self.ipp_proof.verification_scalars(nm, transcript)?;
let s_inv = s.iter().rev();
let a = self.ipp_proof.a;
let b = self.ipp_proof.b;
// Construct concat_z_and_2, an iterator of the values of
// z^0 * \vec(2)^n || z^1 * \vec(2)^n || ... || z^(m-1) * \vec(2)^n
let concat_z_and_2: Vec<Scalar> = util::exp_iter(z)
.zip(bit_lengths.iter())
.flat_map(|(exp_z, n_i)| {
util::exp_iter(Scalar::from(2u64))
.take(*n_i)
.map(move |exp_2| exp_2 * exp_z)
})
.collect();
let gs = s.iter().map(|s_i| minus_z - a * s_i);
let hs = s_inv
.clone()
.zip(util::exp_iter(y.invert()))
.zip(concat_z_and_2.iter())
.map(|((s_i_inv, exp_y_inv), z_and_2)| z + exp_y_inv * (zz * z_and_2 - b * s_i_inv));
let basepoint_scalar =
w * (self.t_x - a * b) + c * (delta(&bit_lengths, &y, &z) - self.t_x);
let value_commitment_scalars = util::exp_iter(z).take(m).map(|z_exp| c * zz * z_exp);
let mega_check = RistrettoPoint::optional_multiscalar_mul(
iter::once(Scalar::one())
.chain(iter::once(x))
.chain(iter::once(c * x))
.chain(iter::once(c * x * x))
.chain(iter::once(-self.e_blinding - c * self.t_x_blinding))
.chain(iter::once(basepoint_scalar))
.chain(x_sq.iter().cloned())
.chain(x_inv_sq.iter().cloned())
.chain(gs)
.chain(hs)
.chain(value_commitment_scalars),
iter::once(self.A.decompress())
.chain(iter::once(self.S.decompress()))
.chain(iter::once(self.T_1.decompress()))
.chain(iter::once(self.T_2.decompress()))
.chain(iter::once(Some(H)))
.chain(iter::once(Some(G)))
.chain(self.ipp_proof.L_vec.iter().map(|L| L.decompress()))
.chain(self.ipp_proof.R_vec.iter().map(|R| R.decompress()))
.chain(bp_gens.G(nm).map(|&x| Some(x)))
.chain(bp_gens.H(nm).map(|&x| Some(x)))
.chain(comms.iter().map(|V| V.decompress())),
)
.ok_or(ProofError::VerificationError)?;
if mega_check.is_identity() {
Ok(())
} else {
Err(ProofError::VerificationError)
}
}
// Following the dalek rangeproof library signature for now. The exact method signature can be
// changed.
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(7 * 32 + self.ipp_proof.serialized_size());
buf.extend_from_slice(self.A.as_bytes());
buf.extend_from_slice(self.S.as_bytes());
buf.extend_from_slice(self.T_1.as_bytes());
buf.extend_from_slice(self.T_2.as_bytes());
buf.extend_from_slice(self.t_x.as_bytes());
buf.extend_from_slice(self.t_x_blinding.as_bytes());
buf.extend_from_slice(self.e_blinding.as_bytes());
buf.extend_from_slice(&self.ipp_proof.to_bytes());
buf
}
// Following the dalek rangeproof library signature for now. The exact method signature can be
// changed.
pub fn from_bytes(slice: &[u8]) -> Result<RangeProof, ProofError> {
if slice.len() % 32 != 0 {
return Err(ProofError::FormatError);
}
if slice.len() < 7 * 32 {
return Err(ProofError::FormatError);
}
let A = CompressedRistretto(util::read32(&slice[0..]));
let S = CompressedRistretto(util::read32(&slice[32..]));
let T_1 = CompressedRistretto(util::read32(&slice[2 * 32..]));
let T_2 = CompressedRistretto(util::read32(&slice[3 * 32..]));
let t_x = Scalar::from_canonical_bytes(util::read32(&slice[4 * 32..]))
.ok_or(ProofError::FormatError)?;
let t_x_blinding = Scalar::from_canonical_bytes(util::read32(&slice[5 * 32..]))
.ok_or(ProofError::FormatError)?;
let e_blinding = Scalar::from_canonical_bytes(util::read32(&slice[6 * 32..]))
.ok_or(ProofError::FormatError)?;
let ipp_proof = InnerProductProof::from_bytes(&slice[7 * 32..])?;
Ok(RangeProof {
A,
S,
T_1,
T_2,
t_x,
t_x_blinding,
e_blinding,
ipp_proof,
})
}
}
/// Compute
/// \\[
/// \delta(y,z) = (z - z^{2}) \langle \mathbf{1}, {\mathbf{y}}^{n \cdot m} \rangle - \sum_{j=0}^{m-1} z^{j+3} \cdot \langle \mathbf{1}, {\mathbf{2}}^{n \cdot m} \rangle
/// \\]
fn delta(bit_lengths: &[usize], y: &Scalar, z: &Scalar) -> Scalar {
let nm: usize = bit_lengths.iter().sum();
let sum_y = util::sum_of_powers(y, nm);
let mut agg_delta = (z - z * z) * sum_y;
let mut exp_z = z * z * z;
for n_i in bit_lengths.iter() {
let sum_2 = util::sum_of_powers(&Scalar::from(2u64), *n_i);
agg_delta -= exp_z * sum_2;
exp_z *= z;
}
agg_delta
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_single_rangeproof() {
let (comm, open) = Pedersen::commit(55_u64);
let mut transcript_create = Transcript::new(b"Test");
let mut transcript_verify = Transcript::new(b"Test");
let proof = RangeProof::create(vec![55], vec![32], vec![&open], &mut transcript_create);
assert!(proof
.verify(
vec![&comm.get_point().compress()],
vec![32],
&mut transcript_verify
)
.is_ok());
}
#[test]
fn test_aggregated_rangeproof() {
let (comm_1, open_1) = Pedersen::commit(55_u64);
let (comm_2, open_2) = Pedersen::commit(77_u64);
let (comm_3, open_3) = Pedersen::commit(99_u64);
let mut transcript_create = Transcript::new(b"Test");
let mut transcript_verify = Transcript::new(b"Test");
let proof = RangeProof::create(
vec![55, 77, 99],
vec![64, 32, 32],
vec![&open_1, &open_2, &open_3],
&mut transcript_create,
);
let comm_1_point = comm_1.get_point().compress();
let comm_2_point = comm_2.get_point().compress();
let comm_3_point = comm_3.get_point().compress();
assert!(proof
.verify(
vec![&comm_1_point, &comm_2_point, &comm_3_point],
vec![64, 32, 32],
&mut transcript_verify,
)
.is_ok());
}
// TODO: write test for serialization/deserialization
}

View File

@ -0,0 +1,138 @@
/// Utility functions for Bulletproofs.
///
/// The code is copied from https://github.com/dalek-cryptography/bulletproofs for now...
use curve25519_dalek::scalar::Scalar;
/// Represents a degree-1 vector polynomial \\(\mathbf{a} + \mathbf{b} \cdot x\\).
pub struct VecPoly1(pub Vec<Scalar>, pub Vec<Scalar>);
impl VecPoly1 {
pub fn zero(n: usize) -> Self {
VecPoly1(vec![Scalar::zero(); n], vec![Scalar::zero(); n])
}
pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 {
// Uses Karatsuba's method
let l = self;
let r = rhs;
let t0 = inner_product(&l.0, &r.0);
let t2 = inner_product(&l.1, &r.1);
let l0_plus_l1 = add_vec(&l.0, &l.1);
let r0_plus_r1 = add_vec(&r.0, &r.1);
let t1 = inner_product(&l0_plus_l1, &r0_plus_r1) - t0 - t2;
Poly2(t0, t1, t2)
}
pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
let n = self.0.len();
let mut out = vec![Scalar::zero(); n];
#[allow(clippy::needless_range_loop)]
for i in 0..n {
out[i] = self.0[i] + self.1[i] * x;
}
out
}
}
/// Represents a degree-2 scalar polynomial \\(a + b \cdot x + c \cdot x^2\\)
pub struct Poly2(pub Scalar, pub Scalar, pub Scalar);
impl Poly2 {
pub fn eval(&self, x: Scalar) -> Scalar {
self.0 + x * (self.1 + x * self.2)
}
}
/// Provides an iterator over the powers of a `Scalar`.
///
/// This struct is created by the `exp_iter` function.
pub struct ScalarExp {
x: Scalar,
next_exp_x: Scalar,
}
impl Iterator for ScalarExp {
type Item = Scalar;
fn next(&mut self) -> Option<Scalar> {
let exp_x = self.next_exp_x;
self.next_exp_x *= self.x;
Some(exp_x)
}
fn size_hint(&self) -> (usize, Option<usize>) {
(usize::max_value(), None)
}
}
/// Return an iterator of the powers of `x`.
pub fn exp_iter(x: Scalar) -> ScalarExp {
let next_exp_x = Scalar::one();
ScalarExp { x, next_exp_x }
}
pub fn add_vec(a: &[Scalar], b: &[Scalar]) -> Vec<Scalar> {
if a.len() != b.len() {
// throw some error
//println!("lengths of vectors don't match for vector addition");
}
let mut out = vec![Scalar::zero(); b.len()];
for i in 0..a.len() {
out[i] = a[i] + b[i];
}
out
}
/// Given `data` with `len >= 32`, return the first 32 bytes.
pub fn read32(data: &[u8]) -> [u8; 32] {
let mut buf32 = [0u8; 32];
buf32[..].copy_from_slice(&data[..32]);
buf32
}
/// Computes an inner product of two vectors
/// \\[
/// {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i.
/// \\]
/// Panics if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal.
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar {
let mut out = Scalar::zero();
if a.len() != b.len() {
panic!("inner_product(a,b): lengths of vectors do not match");
}
for i in 0..a.len() {
out += a[i] * b[i];
}
out
}
/// Takes the sum of all the powers of `x`, up to `n`
/// If `n` is a power of 2, it uses the efficient algorithm with `2*lg n` multiplications and additions.
/// If `n` is not a power of 2, it uses the slow algorithm with `n` multiplications and additions.
/// In the Bulletproofs case, all calls to `sum_of_powers` should have `n` as a power of 2.
pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar {
if !n.is_power_of_two() {
return sum_of_powers_slow(x, n);
}
if n == 0 || n == 1 {
return Scalar::from(n as u64);
}
let mut m = n;
let mut result = Scalar::one() + x;
let mut factor = *x;
while m > 2 {
factor = factor * factor;
result = result + factor * result;
m /= 2;
}
result
}
// takes the sum of all of the powers of x, up to n
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).sum()
}

View File

@ -0,0 +1,117 @@
use curve25519_dalek::ristretto::CompressedRistretto;
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::IsIdentity;
use merlin::Transcript;
use crate::errors::ProofError;
pub trait TranscriptProtocol {
/// Append a domain separator for an `n`-bit rangeproof for ElGamal
/// ciphertext using a decryption key
fn rangeproof_from_key_domain_sep(&mut self, n: u64);
/// Append a domain separator for an `n`-bit rangeproof for ElGamal
/// ciphertext using an opening
fn rangeproof_from_opening_domain_sep(&mut self, n: u64);
/// Append a domain separator for a length-`n` inner product proof.
fn innerproduct_domain_sep(&mut self, n: u64);
/// Append a domain separator for close account proof.
fn close_account_proof_domain_sep(&mut self);
/// Append a domain separator for update account public key proof.
fn update_account_public_key_proof_domain_sep(&mut self);
/// Append a domain separator for withdraw proof.
fn withdraw_proof_domain_sep(&mut self);
/// Append a domain separator for transfer with range proof.
fn transfer_range_proof_sep(&mut self);
/// Append a domain separator for transfer with validity proof.
fn transfer_validity_proof_sep(&mut self);
/// Append a `scalar` with the given `label`.
fn append_scalar(&mut self, label: &'static [u8], scalar: &Scalar);
/// Append a `point` with the given `label`.
fn append_point(&mut self, label: &'static [u8], point: &CompressedRistretto);
/// Check that a point is not the identity, then append it to the
/// transcript. Otherwise, return an error.
fn validate_and_append_point(
&mut self,
label: &'static [u8],
point: &CompressedRistretto,
) -> Result<(), ProofError>;
/// Compute a `label`ed challenge variable.
fn challenge_scalar(&mut self, label: &'static [u8]) -> Scalar;
}
impl TranscriptProtocol for Transcript {
fn rangeproof_from_key_domain_sep(&mut self, n: u64) {
self.append_message(b"dom-sep", b"rangeproof from opening v1");
self.append_u64(b"n", n);
}
fn rangeproof_from_opening_domain_sep(&mut self, n: u64) {
self.append_message(b"dom-sep", b"rangeproof from opening v1");
self.append_u64(b"n", n);
}
fn innerproduct_domain_sep(&mut self, n: u64) {
self.append_message(b"dom-sep", b"ipp v1");
self.append_u64(b"n", n);
}
fn close_account_proof_domain_sep(&mut self) {
self.append_message(b"dom_sep", b"CloseAccountProof");
}
fn update_account_public_key_proof_domain_sep(&mut self) {
self.append_message(b"dom_sep", b"UpdateAccountPublicKeyProof");
}
fn withdraw_proof_domain_sep(&mut self) {
self.append_message(b"dom_sep", b"WithdrawProof");
}
fn transfer_range_proof_sep(&mut self) {
self.append_message(b"dom_sep", b"TransferRangeProof");
}
fn transfer_validity_proof_sep(&mut self) {
self.append_message(b"dom_sep", b"TransferValidityProof");
}
fn append_scalar(&mut self, label: &'static [u8], scalar: &Scalar) {
self.append_message(label, scalar.as_bytes());
}
fn append_point(&mut self, label: &'static [u8], point: &CompressedRistretto) {
self.append_message(label, point.as_bytes());
}
fn validate_and_append_point(
&mut self,
label: &'static [u8],
point: &CompressedRistretto,
) -> Result<(), ProofError> {
if point.is_identity() {
Err(ProofError::VerificationError)
} else {
self.append_message(label, point.as_bytes());
Ok(())
}
}
fn challenge_scalar(&mut self, label: &'static [u8]) -> Scalar {
let mut buf = [0u8; 64];
self.challenge_bytes(label, &mut buf);
Scalar::from_bytes_mod_order_wide(&buf)
}
}

View File

@ -0,0 +1,91 @@
///! Instructions provided by the ZkToken Proof program
pub use crate::instruction::*;
use {
bytemuck::{bytes_of, Pod},
num_derive::{FromPrimitive, ToPrimitive},
num_traits::{FromPrimitive, ToPrimitive},
solana_program::{instruction::Instruction, pubkey::Pubkey},
};
#[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive, PartialEq)]
#[repr(u8)]
pub enum ProofInstruction {
/// Verify a `UpdateAccountPkData` struct
///
/// Accounts expected by this instruction:
/// None
///
/// Data expected by this instruction:
/// `UpdateAccountPkData`
///
VerifyUpdateAccountPk,
/// Verify a `CloseAccountData` struct
///
/// Accounts expected by this instruction:
/// None
///
/// Data expected by this instruction:
/// `CloseAccountData`
///
VerifyCloseAccount,
/// Verify a `WithdrawData` struct
///
/// Accounts expected by this instruction:
/// None
///
/// Data expected by this instruction:
/// `WithdrawData`
///
VerifyWithdraw,
/// Verify a `TransferRangeProofData` struct
///
/// Accounts expected by this instruction:
/// None
///
/// Data expected by this instruction:
/// `TransferRangeProofData`
///
VerifyTransferRangeProofData,
/// Verify a `TransferValidityProofData` struct
///
/// Accounts expected by this instruction:
/// None
///
/// Data expected by this instruction:
/// `TransferValidityProofData`
///
VerifyTransferValidityProofData,
}
impl ProofInstruction {
pub fn encode<T: Pod>(&self, proof: &T) -> Instruction {
let mut data = vec![ToPrimitive::to_u8(self).unwrap()];
data.extend_from_slice(bytes_of(proof));
Instruction {
program_id: crate::zk_token_proof_program::id(),
accounts: vec![],
data,
}
}
pub fn decode_type(program_id: &Pubkey, input: &[u8]) -> Option<Self> {
if *program_id != crate::zk_token_proof_program::id() || input.is_empty() {
None
} else {
FromPrimitive::from_u8(input[0])
}
}
pub fn decode_data<T: Pod>(input: &[u8]) -> Option<&T> {
if input.is_empty() {
None
} else {
bytemuck::try_from_bytes(&input[1..]).ok()
}
}
}

View File

@ -0,0 +1,2 @@
// Program Id of the ZkToken Proof program
solana_program::declare_id!("ZkTokenProof1111111111111111111111111111111");