[zk-token-sdk] Use checked arithmetic when processing transfer amount (#34130)
* add `try_split_u64` * add `try_combine_lo_hi_u64` * add `try` variants of ciphertext arithmetic functions * use try functions in proof generaiton and verification logic * deprecate non-`try` functions * use try functions in proof generaiton and verification logic * Apply suggestions from code review Co-authored-by: Jon C <me@jonc.dev> * cargo fmt --------- Co-authored-by: Jon C <me@jonc.dev>
This commit is contained in:
parent
b9947bd327
commit
b11d41a3f7
|
@ -38,6 +38,8 @@ pub enum ProofVerificationError {
|
|||
ProofContext,
|
||||
#[error("illegal commitment length")]
|
||||
IllegalCommitmentLength,
|
||||
#[error("illegal amount bit length")]
|
||||
IllegalAmountBitLength,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
|
|
|
@ -8,4 +8,8 @@ pub enum InstructionError {
|
|||
Decryption,
|
||||
#[error("missing ciphertext")]
|
||||
MissingCiphertext,
|
||||
#[error("illegal amount bit length")]
|
||||
IllegalAmountBitLength,
|
||||
#[error("arithmetic overflow")]
|
||||
Overflow,
|
||||
}
|
||||
|
|
|
@ -4,9 +4,12 @@ mod without_fee;
|
|||
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
use {
|
||||
crate::encryption::{
|
||||
elgamal::ElGamalCiphertext,
|
||||
pedersen::{PedersenCommitment, PedersenOpening},
|
||||
crate::{
|
||||
encryption::{
|
||||
elgamal::ElGamalCiphertext,
|
||||
pedersen::{PedersenCommitment, PedersenOpening},
|
||||
},
|
||||
instruction::errors::InstructionError,
|
||||
},
|
||||
curve25519_dalek::scalar::Scalar,
|
||||
};
|
||||
|
@ -33,6 +36,7 @@ pub enum Role {
|
|||
/// Takes in a 64-bit number `amount` and a bit length `bit_length`. It returns:
|
||||
/// - the `bit_length` low bits of `amount` interpreted as u64
|
||||
/// - the (64 - `bit_length`) high bits of `amount` interpreted as u64
|
||||
#[deprecated(since = "1.18.0", note = "please use `try_split_u64` instead")]
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
pub fn split_u64(amount: u64, bit_length: usize) -> (u64, u64) {
|
||||
if bit_length == 64 {
|
||||
|
@ -44,6 +48,30 @@ pub fn split_u64(amount: u64, bit_length: usize) -> (u64, u64) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Takes in a 64-bit number `amount` and a bit length `bit_length`. It returns:
|
||||
/// - the `bit_length` low bits of `amount` interpretted as u64
|
||||
/// - the `(64 - bit_length)` high bits of `amount` interpretted as u64
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
pub fn try_split_u64(amount: u64, bit_length: usize) -> Result<(u64, u64), InstructionError> {
|
||||
match bit_length {
|
||||
0 => Ok((0, amount)),
|
||||
1..=63 => {
|
||||
let bit_length_complement = u64::BITS.checked_sub(bit_length as u32).unwrap();
|
||||
// shifts are safe as long as `bit_length` and `bit_length_complement` < 64
|
||||
let lo = amount
|
||||
.checked_shl(bit_length_complement) // clear out the high bits
|
||||
.and_then(|amount| amount.checked_shr(bit_length_complement))
|
||||
.unwrap(); // shift back
|
||||
let hi = amount.checked_shr(bit_length as u32).unwrap();
|
||||
|
||||
Ok((lo, hi))
|
||||
}
|
||||
64 => Ok((amount, 0)),
|
||||
_ => Err(InstructionError::IllegalAmountBitLength),
|
||||
}
|
||||
}
|
||||
|
||||
#[deprecated(since = "1.18.0", note = "please use `try_combine_lo_hi_u64` instead")]
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
pub fn combine_lo_hi_u64(amount_lo: u64, amount_hi: u64, bit_length: usize) -> u64 {
|
||||
if bit_length == 64 {
|
||||
|
@ -53,16 +81,47 @@ pub fn combine_lo_hi_u64(amount_lo: u64, amount_hi: u64, bit_length: usize) -> u
|
|||
}
|
||||
}
|
||||
|
||||
/// Combine two numbers that are interpretted as the low and high bits of a target number. The
|
||||
/// `bit_length` parameter specifies the number of bits that `amount_hi` is to be shifted by.
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
fn combine_lo_hi_ciphertexts(
|
||||
pub fn try_combine_lo_hi_u64(
|
||||
amount_lo: u64,
|
||||
amount_hi: u64,
|
||||
bit_length: usize,
|
||||
) -> Result<u64, InstructionError> {
|
||||
match bit_length {
|
||||
0 => Ok(amount_hi),
|
||||
1..=63 => {
|
||||
// shifts are safe as long as `bit_length` < 64
|
||||
let amount_hi = amount_hi.checked_shl(bit_length as u32).unwrap();
|
||||
let combined = amount_lo
|
||||
.checked_add(amount_hi)
|
||||
.ok_or(InstructionError::IllegalAmountBitLength)?;
|
||||
Ok(combined)
|
||||
}
|
||||
64 => Ok(amount_lo),
|
||||
_ => Err(InstructionError::IllegalAmountBitLength),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
fn try_combine_lo_hi_ciphertexts(
|
||||
ciphertext_lo: &ElGamalCiphertext,
|
||||
ciphertext_hi: &ElGamalCiphertext,
|
||||
bit_length: usize,
|
||||
) -> ElGamalCiphertext {
|
||||
let two_power = (1_u64) << bit_length;
|
||||
ciphertext_lo + &(ciphertext_hi * &Scalar::from(two_power))
|
||||
) -> Result<ElGamalCiphertext, InstructionError> {
|
||||
let two_power = if bit_length < u64::BITS as usize {
|
||||
1_u64.checked_shl(bit_length as u32).unwrap()
|
||||
} else {
|
||||
return Err(InstructionError::IllegalAmountBitLength);
|
||||
};
|
||||
Ok(ciphertext_lo + &(ciphertext_hi * &Scalar::from(two_power)))
|
||||
}
|
||||
|
||||
#[deprecated(
|
||||
since = "1.18.0",
|
||||
note = "please use `try_combine_lo_hi_commitments` instead"
|
||||
)]
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
pub fn combine_lo_hi_commitments(
|
||||
comm_lo: &PedersenCommitment,
|
||||
|
@ -73,6 +132,24 @@ pub fn combine_lo_hi_commitments(
|
|||
comm_lo + comm_hi * &Scalar::from(two_power)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
pub fn try_combine_lo_hi_commitments(
|
||||
comm_lo: &PedersenCommitment,
|
||||
comm_hi: &PedersenCommitment,
|
||||
bit_length: usize,
|
||||
) -> Result<PedersenCommitment, InstructionError> {
|
||||
let two_power = if bit_length < u64::BITS as usize {
|
||||
1_u64.checked_shl(bit_length as u32).unwrap()
|
||||
} else {
|
||||
return Err(InstructionError::IllegalAmountBitLength);
|
||||
};
|
||||
Ok(comm_lo + comm_hi * &Scalar::from(two_power))
|
||||
}
|
||||
|
||||
#[deprecated(
|
||||
since = "1.18.0",
|
||||
note = "please use `try_combine_lo_hi_openings` instead"
|
||||
)]
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
pub fn combine_lo_hi_openings(
|
||||
opening_lo: &PedersenOpening,
|
||||
|
@ -83,6 +160,20 @@ pub fn combine_lo_hi_openings(
|
|||
opening_lo + opening_hi * &Scalar::from(two_power)
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "solana"))]
|
||||
pub fn try_combine_lo_hi_openings(
|
||||
opening_lo: &PedersenOpening,
|
||||
opening_hi: &PedersenOpening,
|
||||
bit_length: usize,
|
||||
) -> Result<PedersenOpening, InstructionError> {
|
||||
let two_power = if bit_length < u64::BITS as usize {
|
||||
1_u64.checked_shl(bit_length as u32).unwrap()
|
||||
} else {
|
||||
return Err(InstructionError::IllegalAmountBitLength);
|
||||
};
|
||||
Ok(opening_lo + opening_hi * &Scalar::from(two_power))
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
#[repr(C)]
|
||||
pub struct FeeParameters {
|
||||
|
@ -91,3 +182,97 @@ pub struct FeeParameters {
|
|||
/// Maximum fee assessed on transfers, expressed as an amount of tokens
|
||||
pub maximum_fee: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_split_u64() {
|
||||
assert_eq!((0, 0), try_split_u64(0, 0).unwrap());
|
||||
assert_eq!((0, 0), try_split_u64(0, 1).unwrap());
|
||||
assert_eq!((0, 0), try_split_u64(0, 5).unwrap());
|
||||
assert_eq!((0, 0), try_split_u64(0, 63).unwrap());
|
||||
assert_eq!((0, 0), try_split_u64(0, 64).unwrap());
|
||||
assert_eq!(
|
||||
InstructionError::IllegalAmountBitLength,
|
||||
try_split_u64(0, 65).unwrap_err()
|
||||
);
|
||||
|
||||
assert_eq!((0, 1), try_split_u64(1, 0).unwrap());
|
||||
assert_eq!((1, 0), try_split_u64(1, 1).unwrap());
|
||||
assert_eq!((1, 0), try_split_u64(1, 5).unwrap());
|
||||
assert_eq!((1, 0), try_split_u64(1, 63).unwrap());
|
||||
assert_eq!((1, 0), try_split_u64(1, 64).unwrap());
|
||||
assert_eq!(
|
||||
InstructionError::IllegalAmountBitLength,
|
||||
try_split_u64(1, 65).unwrap_err()
|
||||
);
|
||||
|
||||
assert_eq!((0, 33), try_split_u64(33, 0).unwrap());
|
||||
assert_eq!((1, 16), try_split_u64(33, 1).unwrap());
|
||||
assert_eq!((1, 1), try_split_u64(33, 5).unwrap());
|
||||
assert_eq!((33, 0), try_split_u64(33, 63).unwrap());
|
||||
assert_eq!((33, 0), try_split_u64(33, 64).unwrap());
|
||||
assert_eq!(
|
||||
InstructionError::IllegalAmountBitLength,
|
||||
try_split_u64(33, 65).unwrap_err()
|
||||
);
|
||||
|
||||
let amount = u64::MAX;
|
||||
assert_eq!((0, amount), try_split_u64(amount, 0).unwrap());
|
||||
assert_eq!((1, (1 << 63) - 1), try_split_u64(amount, 1).unwrap());
|
||||
assert_eq!((31, (1 << 59) - 1), try_split_u64(amount, 5).unwrap());
|
||||
assert_eq!(((1 << 63) - 1, 1), try_split_u64(amount, 63).unwrap());
|
||||
assert_eq!((amount, 0), try_split_u64(amount, 64).unwrap());
|
||||
assert_eq!(
|
||||
InstructionError::IllegalAmountBitLength,
|
||||
try_split_u64(amount, 65).unwrap_err()
|
||||
);
|
||||
}
|
||||
|
||||
fn test_split_and_combine(amount: u64, bit_length: usize) {
|
||||
let (amount_lo, amount_hi) = try_split_u64(amount, bit_length).unwrap();
|
||||
assert_eq!(
|
||||
try_combine_lo_hi_u64(amount_lo, amount_hi, bit_length).unwrap(),
|
||||
amount
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_combine_lo_hi_u64() {
|
||||
test_split_and_combine(0, 0);
|
||||
test_split_and_combine(0, 1);
|
||||
test_split_and_combine(0, 5);
|
||||
test_split_and_combine(0, 63);
|
||||
test_split_and_combine(0, 64);
|
||||
|
||||
test_split_and_combine(1, 0);
|
||||
test_split_and_combine(1, 1);
|
||||
test_split_and_combine(1, 5);
|
||||
test_split_and_combine(1, 63);
|
||||
test_split_and_combine(1, 64);
|
||||
|
||||
test_split_and_combine(33, 0);
|
||||
test_split_and_combine(33, 1);
|
||||
test_split_and_combine(33, 5);
|
||||
test_split_and_combine(33, 63);
|
||||
test_split_and_combine(33, 64);
|
||||
|
||||
test_split_and_combine(u64::MAX, 0);
|
||||
test_split_and_combine(u64::MAX, 1);
|
||||
test_split_and_combine(u64::MAX, 5);
|
||||
test_split_and_combine(u64::MAX, 63);
|
||||
test_split_and_combine(u64::MAX, 64);
|
||||
|
||||
// illegal amount bit
|
||||
let err = try_combine_lo_hi_u64(0, 0, 65).unwrap_err();
|
||||
assert_eq!(err, InstructionError::IllegalAmountBitLength);
|
||||
|
||||
// overflow
|
||||
let amount_lo = u64::MAX;
|
||||
let amount_hi = u64::MAX;
|
||||
let err = try_combine_lo_hi_u64(amount_lo, amount_hi, 1).unwrap_err();
|
||||
assert_eq!(err, InstructionError::IllegalAmountBitLength);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,10 +9,10 @@ use {
|
|||
instruction::{
|
||||
errors::InstructionError,
|
||||
transfer::{
|
||||
combine_lo_hi_ciphertexts, combine_lo_hi_commitments, combine_lo_hi_openings,
|
||||
combine_lo_hi_u64,
|
||||
encryption::{FeeEncryption, TransferAmountCiphertext},
|
||||
split_u64, FeeParameters, Role,
|
||||
try_combine_lo_hi_ciphertexts, try_combine_lo_hi_commitments,
|
||||
try_combine_lo_hi_openings, try_combine_lo_hi_u64, try_split_u64, FeeParameters,
|
||||
Role,
|
||||
},
|
||||
},
|
||||
range_proof::RangeProof,
|
||||
|
@ -128,7 +128,8 @@ impl TransferWithFeeData {
|
|||
withdraw_withheld_authority_pubkey: &ElGamalPubkey,
|
||||
) -> Result<Self, ProofGenerationError> {
|
||||
// split and encrypt transfer amount
|
||||
let (amount_lo, amount_hi) = split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS);
|
||||
let (amount_lo, amount_hi) = try_split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
let (ciphertext_lo, opening_lo) = TransferAmountCiphertext::new(
|
||||
amount_lo,
|
||||
|
@ -159,11 +160,12 @@ impl TransferWithFeeData {
|
|||
};
|
||||
|
||||
let new_source_ciphertext = old_source_ciphertext
|
||||
- combine_lo_hi_ciphertexts(
|
||||
- try_combine_lo_hi_ciphertexts(
|
||||
&transfer_amount_lo_source,
|
||||
&transfer_amount_hi_source,
|
||||
TRANSFER_AMOUNT_LO_BITS,
|
||||
);
|
||||
)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
// calculate fee
|
||||
//
|
||||
|
@ -177,7 +179,9 @@ impl TransferWithFeeData {
|
|||
u64::conditional_select(&fee_parameters.maximum_fee, &fee_amount, below_max);
|
||||
|
||||
// split and encrypt fee
|
||||
let (fee_to_encrypt_lo, fee_to_encrypt_hi) = split_u64(fee_to_encrypt, FEE_AMOUNT_LO_BITS);
|
||||
let (fee_to_encrypt_lo, fee_to_encrypt_hi) =
|
||||
try_split_u64(fee_to_encrypt, FEE_AMOUNT_LO_BITS)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
let (fee_ciphertext_lo, opening_fee_lo) = FeeEncryption::new(
|
||||
fee_to_encrypt_lo,
|
||||
|
@ -510,23 +514,28 @@ impl TransferWithFeeProof {
|
|||
let pod_claimed_commitment: pod::PedersenCommitment = claimed_commitment.into();
|
||||
transcript.append_commitment(b"commitment-claimed", &pod_claimed_commitment);
|
||||
|
||||
let combined_commitment = combine_lo_hi_commitments(
|
||||
let combined_commitment = try_combine_lo_hi_commitments(
|
||||
ciphertext_lo.get_commitment(),
|
||||
ciphertext_hi.get_commitment(),
|
||||
TRANSFER_AMOUNT_LO_BITS,
|
||||
);
|
||||
)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
let combined_opening =
|
||||
combine_lo_hi_openings(opening_lo, opening_hi, TRANSFER_AMOUNT_LO_BITS);
|
||||
try_combine_lo_hi_openings(opening_lo, opening_hi, TRANSFER_AMOUNT_LO_BITS)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
let combined_fee_amount =
|
||||
combine_lo_hi_u64(fee_amount_lo, fee_amount_hi, TRANSFER_AMOUNT_LO_BITS);
|
||||
let combined_fee_commitment = combine_lo_hi_commitments(
|
||||
try_combine_lo_hi_u64(fee_amount_lo, fee_amount_hi, TRANSFER_AMOUNT_LO_BITS)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
let combined_fee_commitment = try_combine_lo_hi_commitments(
|
||||
fee_ciphertext_lo.get_commitment(),
|
||||
fee_ciphertext_hi.get_commitment(),
|
||||
TRANSFER_AMOUNT_LO_BITS,
|
||||
);
|
||||
)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
let combined_fee_opening =
|
||||
combine_lo_hi_openings(opening_fee_lo, opening_fee_hi, TRANSFER_AMOUNT_LO_BITS);
|
||||
try_combine_lo_hi_openings(opening_fee_lo, opening_fee_hi, TRANSFER_AMOUNT_LO_BITS)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
// compute real delta commitment
|
||||
let (delta_commitment, opening_delta) = compute_delta_commitment_and_opening(
|
||||
|
@ -561,11 +570,12 @@ impl TransferWithFeeProof {
|
|||
// generate the range proof
|
||||
let opening_claimed_negated = &PedersenOpening::default() - &opening_claimed;
|
||||
|
||||
let combined_amount = combine_lo_hi_u64(
|
||||
let combined_amount = try_combine_lo_hi_u64(
|
||||
transfer_amount_lo,
|
||||
transfer_amount_hi,
|
||||
TRANSFER_AMOUNT_LO_BITS,
|
||||
);
|
||||
)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
let amount_sub_fee = combined_amount
|
||||
.checked_sub(combined_fee_amount)
|
||||
.ok_or(ProofGenerationError::FeeCalculation)?;
|
||||
|
@ -680,16 +690,18 @@ impl TransferWithFeeProof {
|
|||
// verify fee sigma proof
|
||||
transcript.append_commitment(b"commitment-claimed", &self.claimed_commitment);
|
||||
|
||||
let combined_commitment = combine_lo_hi_commitments(
|
||||
let combined_commitment = try_combine_lo_hi_commitments(
|
||||
ciphertext_lo.get_commitment(),
|
||||
ciphertext_hi.get_commitment(),
|
||||
TRANSFER_AMOUNT_LO_BITS,
|
||||
);
|
||||
let combined_fee_commitment = combine_lo_hi_commitments(
|
||||
)
|
||||
.map_err(|_| ProofVerificationError::IllegalAmountBitLength)?;
|
||||
let combined_fee_commitment = try_combine_lo_hi_commitments(
|
||||
fee_ciphertext_lo.get_commitment(),
|
||||
fee_ciphertext_hi.get_commitment(),
|
||||
TRANSFER_AMOUNT_LO_BITS,
|
||||
);
|
||||
)
|
||||
.map_err(|_| ProofVerificationError::IllegalAmountBitLength)?;
|
||||
|
||||
let delta_commitment = compute_delta_commitment(
|
||||
&combined_commitment,
|
||||
|
|
|
@ -9,7 +9,8 @@ use {
|
|||
instruction::{
|
||||
errors::InstructionError,
|
||||
transfer::{
|
||||
combine_lo_hi_ciphertexts, encryption::TransferAmountCiphertext, split_u64, Role,
|
||||
encryption::TransferAmountCiphertext, try_combine_lo_hi_ciphertexts, try_split_u64,
|
||||
Role,
|
||||
},
|
||||
},
|
||||
range_proof::RangeProof,
|
||||
|
@ -96,7 +97,8 @@ impl TransferData {
|
|||
(destination_pubkey, auditor_pubkey): (&ElGamalPubkey, &ElGamalPubkey),
|
||||
) -> Result<Self, ProofGenerationError> {
|
||||
// split and encrypt transfer amount
|
||||
let (amount_lo, amount_hi) = split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS);
|
||||
let (amount_lo, amount_hi) = try_split_u64(transfer_amount, TRANSFER_AMOUNT_LO_BITS)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
let (ciphertext_lo, opening_lo) = TransferAmountCiphertext::new(
|
||||
amount_lo,
|
||||
|
@ -128,11 +130,12 @@ impl TransferData {
|
|||
};
|
||||
|
||||
let new_source_ciphertext = ciphertext_old_source
|
||||
- combine_lo_hi_ciphertexts(
|
||||
- try_combine_lo_hi_ciphertexts(
|
||||
&transfer_amount_lo_source,
|
||||
&transfer_amount_hi_source,
|
||||
TRANSFER_AMOUNT_LO_BITS,
|
||||
);
|
||||
)
|
||||
.map_err(|_| ProofGenerationError::IllegalAmountBitLength)?;
|
||||
|
||||
// generate transcript and append all public inputs
|
||||
let pod_transfer_pubkeys = TransferPubkeys {
|
||||
|
|
|
@ -134,7 +134,7 @@ mod tests {
|
|||
elgamal::{ElGamalCiphertext, ElGamalKeypair},
|
||||
pedersen::{Pedersen, PedersenOpening},
|
||||
},
|
||||
instruction::transfer::split_u64,
|
||||
instruction::transfer::try_split_u64,
|
||||
zk_token_elgamal::{ops, pod},
|
||||
},
|
||||
bytemuck::Zeroable,
|
||||
|
@ -204,7 +204,7 @@ mod tests {
|
|||
fn test_transfer_arithmetic() {
|
||||
// transfer amount
|
||||
let transfer_amount: u64 = 55;
|
||||
let (amount_lo, amount_hi) = split_u64(transfer_amount, 16);
|
||||
let (amount_lo, amount_hi) = try_split_u64(transfer_amount, 16).unwrap();
|
||||
|
||||
// generate public keys
|
||||
let source_keypair = ElGamalKeypair::new_rand();
|
||||
|
|
Loading…
Reference in New Issue