[zk-token-sdk] Use checked arithmetic when processing transfer amount (#34130)

* add `try_split_u64`

* add `try_combine_lo_hi_u64`

* add `try` variants of ciphertext arithmetic functions

* use try functions in proof generaiton and verification logic

* deprecate non-`try` functions

* use try functions in proof generaiton and verification logic

* Apply suggestions from code review

Co-authored-by: Jon C <me@jonc.dev>

* cargo fmt

---------

Co-authored-by: Jon C <me@jonc.dev>
This commit is contained in:
samkim-crypto 2024-01-24 22:27:03 +09:00 committed by GitHub
parent b9947bd327
commit b11d41a3f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 239 additions and 33 deletions

View File

@ -38,6 +38,8 @@ pub enum ProofVerificationError {
ProofContext,
#[error("illegal commitment length")]
IllegalCommitmentLength,
#[error("illegal amount bit length")]
IllegalAmountBitLength,
}
#[derive(Clone, Debug, Eq, PartialEq)]

View File

@ -8,4 +8,8 @@ pub enum InstructionError {
Decryption,
#[error("missing ciphertext")]
MissingCiphertext,
#[error("illegal amount bit length")]
IllegalAmountBitLength,
#[error("arithmetic overflow")]
Overflow,
}

View File

@ -4,9 +4,12 @@ mod without_fee;
#[cfg(not(target_os = "solana"))]
use {
crate::encryption::{
elgamal::ElGamalCiphertext,
pedersen::{PedersenCommitment, PedersenOpening},
crate::{
encryption::{
elgamal::ElGamalCiphertext,
pedersen::{PedersenCommitment, PedersenOpening},
},
instruction::errors::InstructionError,
},
curve25519_dalek::scalar::Scalar,
};
@ -33,6 +36,7 @@ pub enum Role {
/// Takes in a 64-bit number `amount` and a bit length `bit_length`. It returns:
/// - the `bit_length` low bits of `amount` interpreted as u64
/// - the (64 - `bit_length`) high bits of `amount` interpreted as u64
#[deprecated(since = "1.18.0", note = "please use `try_split_u64` instead")]
#[cfg(not(target_os = "solana"))]
pub fn split_u64(amount: u64, bit_length: usize) -> (u64, u64) {
if bit_length == 64 {
@ -44,6 +48,30 @@ pub fn split_u64(amount: u64, bit_length: usize) -> (u64, u64) {
}
}
/// Takes in a 64-bit number `amount` and a bit length `bit_length`. It returns:
/// - the `bit_length` low bits of `amount` interpretted as u64
/// - the `(64 - bit_length)` high bits of `amount` interpretted as u64
#[cfg(not(target_os = "solana"))]
pub fn try_split_u64(amount: u64, bit_length: usize) -> Result<(u64, u64), InstructionError> {
match bit_length {
0 => Ok((0, amount)),
1..=63 => {
let bit_length_complement = u64::BITS.checked_sub(bit_length as u32).unwrap();
// shifts are safe as long as `bit_length` and `bit_length_complement` < 64
let lo = amount
.checked_shl(bit_length_complement) // clear out the high bits
.and_then(|amount| amount.checked_shr(bit_length_complement))
.unwrap(); // shift back
let hi = amount.checked_shr(bit_length as u32).unwrap();
Ok((lo, hi))
}
64 => Ok((amount, 0)),
_ => Err(InstructionError::IllegalAmountBitLength),
}
}
#[deprecated(since = "1.18.0", note = "please use `try_combine_lo_hi_u64` instead")]
#[cfg(not(target_os = "solana"))]
pub fn combine_lo_hi_u64(amount_lo: u64, amount_hi: u64, bit_length: usize) -> u64 {
if bit_length == 64 {
@ -53,16 +81,47 @@ pub fn combine_lo_hi_u64(amount_lo: u64, amount_hi: u64, bit_length: usize) -> u
}
}
/// Combine two numbers that are interpretted as the low and high bits of a target number. The
/// `bit_length` parameter specifies the number of bits that `amount_hi` is to be shifted by.
#[cfg(not(target_os = "solana"))]
fn combine_lo_hi_ciphertexts(
pub fn try_combine_lo_hi_u64(
amount_lo: u64,
amount_hi: u64,
bit_length: usize,
) -> Result<u64, InstructionError> {
match bit_length {
0 => Ok(amount_hi),
1..=63 => {
// shifts are safe as long as `bit_length` < 64
let amount_hi = amount_hi.checked_shl(bit_length as u32).unwrap();
let combined = amount_lo
.checked_add(amount_hi)
.ok_or(InstructionError::IllegalAmountBitLength)?;
Ok(combined)
}
64 => Ok(amount_lo),
_ => Err(InstructionError::IllegalAmountBitLength),
}
}
#[cfg(not(target_os = "solana"))]
fn try_combine_lo_hi_ciphertexts(
ciphertext_lo: &ElGamalCiphertext,
ciphertext_hi: &ElGamalCiphertext,
bit_length: usize,
) -> ElGamalCiphertext {
let two_power = (1_u64) << bit_length;
ciphertext_lo + &(ciphertext_hi * &Scalar::from(two_power))
) -> Result<ElGamalCiphertext, InstructionError> {
let two_power = if bit_length < u64::BITS as usize {
1_u64.checked_shl(bit_length as u32).unwrap()
} else {
return Err(InstructionError::IllegalAmountBitLength);
};
Ok(ciphertext_lo + &(ciphertext_hi * &Scalar::from(two_power)))
}
#[deprecated(
since = "1.18.0",
note = "please use `try_combine_lo_hi_commitments` instead"
)]
#[cfg(not(target_os = "solana"))]
pub fn combine_lo_hi_commitments(
comm_lo: &PedersenCommitment,
@ -73,6 +132,24 @@ pub fn combine_lo_hi_commitments(
comm_lo + comm_hi * &Scalar::from(two_power)
}
#[cfg(not(target_os = "solana"))]
pub fn try_combine_lo_hi_commitments(
comm_lo: &PedersenCommitment,
comm_hi: &PedersenCommitment,
bit_length: usize,
) -> Result<PedersenCommitment, InstructionError> {
let two_power = if bit_length < u64::BITS as usize {
1_u64.checked_shl(bit_length as u32).unwrap()
} else {
return Err(InstructionError::IllegalAmountBitLength);
};
Ok(comm_lo + comm_hi * &Scalar::from(two_power))
}
#[deprecated(
since = "1.18.0",
note = "please use `try_combine_lo_hi_openings` instead"
)]
#[cfg(not(target_os = "solana"))]
pub fn combine_lo_hi_openings(
opening_lo: &PedersenOpening,
@ -83,6 +160,20 @@ pub fn combine_lo_hi_openings(
opening_lo + opening_hi * &Scalar::from(two_power)
}
#[cfg(not(target_os = "solana"))]
pub fn try_combine_lo_hi_openings(
opening_lo: &PedersenOpening,
opening_hi: &PedersenOpening,
bit_length: usize,
) -> Result<PedersenOpening, InstructionError> {
let two_power = if bit_length < u64::BITS as usize {
1_u64.checked_shl(bit_length as u32).unwrap()
} else {
return Err(InstructionError::IllegalAmountBitLength);
};
Ok(opening_lo + opening_hi * &Scalar::from(two_power))
}
#[derive(Clone, Copy)]
#[repr(C)]
pub struct FeeParameters {
@ -91,3 +182,97 @@ pub struct FeeParameters {
/// Maximum fee assessed on transfers, expressed as an amount of tokens
pub maximum_fee: u64,
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_split_u64() {
assert_eq!((0, 0), try_split_u64(0, 0).unwrap());
assert_eq!((0, 0), try_split_u64(0, 1).unwrap());
assert_eq!((0, 0), try_split_u64(0, 5).unwrap());
assert_eq!((0, 0), try_split_u64(0, 63).unwrap());
assert_eq!((0, 0), try_split_u64(0, 64).unwrap());
assert_eq!(
InstructionError::IllegalAmountBitLength,
try_split_u64(0, 65).unwrap_err()
);
assert_eq!((0, 1), try_split_u64(1, 0).unwrap());
assert_eq!((1, 0), try_split_u64(1, 1).unwrap());
assert_eq!((1, 0), try_split_u64(1, 5).unwrap());
assert_eq!((1, 0), try_split_u64(1, 63).unwrap());
assert_eq!((1, 0), try_split_u64(1, 64).unwrap());
assert_eq!(
InstructionError::IllegalAmountBitLength,
try_split_u64(1, 65).unwrap_err()
);
assert_eq!((0, 33), try_split_u64(33, 0).unwrap());
assert_eq!((1, 16), try_split_u64(33, 1).unwrap());
assert_eq!((1, 1), try_split_u64(33, 5).unwrap());
assert_eq!((33, 0), try_split_u64(33, 63).unwrap());
assert_eq!((33, 0), try_split_u64(33, 64).unwrap());
assert_eq!(
InstructionError::IllegalAmountBitLength,
try_split_u64(33, 65).unwrap_err()
);
let amount = u64::MAX;
assert_eq!((0, amount), try_split_u64(amount, 0).unwrap());
assert_eq!((1, (1 << 63) - 1), try_split_u64(amount, 1).unwrap());
assert_eq!((31, (1 << 59) - 1), try_split_u64(amount, 5).unwrap());
assert_eq!(((1 << 63) - 1, 1), try_split_u64(amount, 63).unwrap());
assert_eq!((amount, 0), try_split_u64(amount, 64).unwrap());
assert_eq!(
InstructionError::IllegalAmountBitLength,
try_split_u64(amount, 65).unwrap_err()
);
}
fn test_split_and_combine(amount: u64, bit_length: usize) {
let (amount_lo, amount_hi) = try_split_u64(amount, bit_length).unwrap();
assert_eq!(
try_combine_lo_hi_u64(amount_lo, amount_hi, bit_length).unwrap(),
amount
);
}
#[test]
fn test_combine_lo_hi_u64() {
test_split_and_combine(0, 0);
test_split_and_combine(0, 1);
test_split_and_combine(0, 5);
test_split_and_combine(0, 63);
test_split_and_combine(0, 64);
test_split_and_combine(1, 0);
test_split_and_combine(1, 1);
test_split_and_combine(1, 5);
test_split_and_combine(1, 63);
test_split_and_combine(1, 64);
test_split_and_combine(33, 0);
test_split_and_combine(33, 1);
test_split_and_combine(33, 5);
test_split_and_combine(33, 63);
test_split_and_combine(33, 64);
test_split_and_combine(u64::MAX, 0);
test_split_and_combine(u64::MAX, 1);
test_split_and_combine(u64::MAX, 5);
test_split_and_combine(u64::MAX, 63);
test_split_and_combine(u64::MAX, 64);
// illegal amount bit
let err = try_combine_lo_hi_u64(0, 0, 65).unwrap_err();
assert_eq!(err, InstructionError::IllegalAmountBitLength);
// overflow
let amount_lo = u64::MAX;
let amount_hi = u64::MAX;
let err = try_combine_lo_hi_u64(amount_lo, amount_hi, 1).unwrap_err();
assert_eq!(err, InstructionError::IllegalAmountBitLength);
}
}

View File

@ -9,10 +9,10 @@ use {
instruction::{
errors::InstructionError,
transfer::{
combine_lo_hi_ciphertexts, combine_lo_hi_commitments, combine_lo_hi_openings,
combine_lo_hi_u64,
encryption::{FeeEncryption, TransferAmountCiphertext},
split_u64, FeeParameters, Role,
try_combine_lo_hi_ciphertexts, try_combine_lo_hi_commitments,
try_combine_lo_hi_openings, try_combine_lo_hi_u64, try_split_u64, FeeParameters,
Role,
},
},
range_proof::RangeProof,
@ -128,7 +128,8 @@ impl TransferWithFeeData {
withdraw_withheld_authority_pubkey: &ElGamalPubkey,
) -> Result<Self, ProofGenerationError> {
// split and encrypt transfer amount
let (amount_lo, amount_hi) = split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS);
let (amount_lo, amount_hi) = try_split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let (ciphertext_lo, opening_lo) = TransferAmountCiphertext::new(
amount_lo,
@ -159,11 +160,12 @@ impl TransferWithFeeData {
};
let new_source_ciphertext = old_source_ciphertext
- combine_lo_hi_ciphertexts(
- try_combine_lo_hi_ciphertexts(
&transfer_amount_lo_source,
&transfer_amount_hi_source,
TRANSFER_AMOUNT_LO_BITS,
);
)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
// calculate fee
//
@ -177,7 +179,9 @@ impl TransferWithFeeData {
u64::conditional_select(&fee_parameters.maximum_fee, &fee_amount, below_max);
// split and encrypt fee
let (fee_to_encrypt_lo, fee_to_encrypt_hi) = split_u64(fee_to_encrypt, FEE_AMOUNT_LO_BITS);
let (fee_to_encrypt_lo, fee_to_encrypt_hi) =
try_split_u64(fee_to_encrypt, FEE_AMOUNT_LO_BITS)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let (fee_ciphertext_lo, opening_fee_lo) = FeeEncryption::new(
fee_to_encrypt_lo,
@ -510,23 +514,28 @@ impl TransferWithFeeProof {
let pod_claimed_commitment: pod::PedersenCommitment = claimed_commitment.into();
transcript.append_commitment(b"commitment-claimed", &pod_claimed_commitment);
let combined_commitment = combine_lo_hi_commitments(
let combined_commitment = try_combine_lo_hi_commitments(
ciphertext_lo.get_commitment(),
ciphertext_hi.get_commitment(),
TRANSFER_AMOUNT_LO_BITS,
);
)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let combined_opening =
combine_lo_hi_openings(opening_lo, opening_hi, TRANSFER_AMOUNT_LO_BITS);
try_combine_lo_hi_openings(opening_lo, opening_hi, TRANSFER_AMOUNT_LO_BITS)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let combined_fee_amount =
combine_lo_hi_u64(fee_amount_lo, fee_amount_hi, TRANSFER_AMOUNT_LO_BITS);
let combined_fee_commitment = combine_lo_hi_commitments(
try_combine_lo_hi_u64(fee_amount_lo, fee_amount_hi, TRANSFER_AMOUNT_LO_BITS)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let combined_fee_commitment = try_combine_lo_hi_commitments(
fee_ciphertext_lo.get_commitment(),
fee_ciphertext_hi.get_commitment(),
TRANSFER_AMOUNT_LO_BITS,
);
)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let combined_fee_opening =
combine_lo_hi_openings(opening_fee_lo, opening_fee_hi, TRANSFER_AMOUNT_LO_BITS);
try_combine_lo_hi_openings(opening_fee_lo, opening_fee_hi, TRANSFER_AMOUNT_LO_BITS)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
// compute real delta commitment
let (delta_commitment, opening_delta) = compute_delta_commitment_and_opening(
@ -561,11 +570,12 @@ impl TransferWithFeeProof {
// generate the range proof
let opening_claimed_negated = &PedersenOpening::default() - &opening_claimed;
let combined_amount = combine_lo_hi_u64(
let combined_amount = try_combine_lo_hi_u64(
transfer_amount_lo,
transfer_amount_hi,
TRANSFER_AMOUNT_LO_BITS,
);
)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let amount_sub_fee = combined_amount
.checked_sub(combined_fee_amount)
.ok_or(ProofGenerationError::FeeCalculation)?;
@ -680,16 +690,18 @@ impl TransferWithFeeProof {
// verify fee sigma proof
transcript.append_commitment(b"commitment-claimed", &self.claimed_commitment);
let combined_commitment = combine_lo_hi_commitments(
let combined_commitment = try_combine_lo_hi_commitments(
ciphertext_lo.get_commitment(),
ciphertext_hi.get_commitment(),
TRANSFER_AMOUNT_LO_BITS,
);
let combined_fee_commitment = combine_lo_hi_commitments(
)
.map_err(|_| ProofVerificationError::IllegalAmountBitLength)?;
let combined_fee_commitment = try_combine_lo_hi_commitments(
fee_ciphertext_lo.get_commitment(),
fee_ciphertext_hi.get_commitment(),
TRANSFER_AMOUNT_LO_BITS,
);
)
.map_err(|_| ProofVerificationError::IllegalAmountBitLength)?;
let delta_commitment = compute_delta_commitment(
&combined_commitment,

View File

@ -9,7 +9,8 @@ use {
instruction::{
errors::InstructionError,
transfer::{
combine_lo_hi_ciphertexts, encryption::TransferAmountCiphertext, split_u64, Role,
encryption::TransferAmountCiphertext, try_combine_lo_hi_ciphertexts, try_split_u64,
Role,
},
},
range_proof::RangeProof,
@ -96,7 +97,8 @@ impl TransferData {
(destination_pubkey, auditor_pubkey): (&ElGamalPubkey, &ElGamalPubkey),
) -> Result<Self, ProofGenerationError> {
// split and encrypt transfer amount
let (amount_lo, amount_hi) = split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS);
let (amount_lo, amount_hi) = try_split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
let (ciphertext_lo, opening_lo) = TransferAmountCiphertext::new(
amount_lo,
@ -128,11 +130,12 @@ impl TransferData {
};
let new_source_ciphertext = ciphertext_old_source
- combine_lo_hi_ciphertexts(
- try_combine_lo_hi_ciphertexts(
&transfer_amount_lo_source,
&transfer_amount_hi_source,
TRANSFER_AMOUNT_LO_BITS,
);
)
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
// generate transcript and append all public inputs
let pod_transfer_pubkeys = TransferPubkeys {

View File

@ -134,7 +134,7 @@ mod tests {
elgamal::{ElGamalCiphertext, ElGamalKeypair},
pedersen::{Pedersen, PedersenOpening},
},
instruction::transfer::split_u64,
instruction::transfer::try_split_u64,
zk_token_elgamal::{ops, pod},
},
bytemuck::Zeroable,
@ -204,7 +204,7 @@ mod tests {
fn test_transfer_arithmetic() {
// transfer amount
let transfer_amount: u64 = 55;
let (amount_lo, amount_hi) = split_u64(transfer_amount, 16);
let (amount_lo, amount_hi) = try_split_u64(transfer_amount, 16).unwrap();
// generate public keys
let source_keypair = ElGamalKeypair::new_rand();