Use u128 for all the math but store in u64 (#711)

* Use u128 for all the math but store in u64

* tests

* missing changes

* bulds

* specific conversion failure

* fix tests

* use large numbers

* Rebase and fix merge issue from new tests

Co-authored-by: Jon Cinque <jon.cinque@gmail.com>
This commit is contained in:
anatoly yakovenko 2020-10-26 10:18:33 -07:00 committed by GitHub
parent 7473fa0035
commit 64a362c059
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 263 additions and 189 deletions

10
Cargo.lock generated
View File

@ -626,13 +626,12 @@ dependencies = [
[[package]]
name = "curve25519-dalek"
version = "2.1.0"
source = "git+https://github.com/garious/curve25519-dalek?rev=60efef3553d6bf3d7f3b09b5f97acd54d72529ff#60efef3553d6bf3d7f3b09b5f97acd54d72529ff"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d85653f070353a16313d0046f173f70d1aadd5b42600a14de626f0dfb3473a5"
dependencies = [
"borsh",
"byteorder",
"digest 0.8.1",
"rand_core",
"serde",
"subtle 2.2.3",
"zeroize",
]
@ -640,12 +639,13 @@ dependencies = [
[[package]]
name = "curve25519-dalek"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d85653f070353a16313d0046f173f70d1aadd5b42600a14de626f0dfb3473a5"
source = "git+https://github.com/garious/curve25519-dalek?rev=60efef3553d6bf3d7f3b09b5f97acd54d72529ff#60efef3553d6bf3d7f3b09b5f97acd54d72529ff"
dependencies = [
"borsh",
"byteorder",
"digest 0.8.1",
"rand_core",
"serde",
"subtle 2.2.3",
"zeroize",
]

View File

@ -10,10 +10,10 @@ use std::convert::{TryFrom, TryInto};
use std::fmt::Debug;
/// Initial amount of pool tokens for swap contract, hard-coded to something
/// "sensible" given a maximum of u64.
/// "sensible" given a maximum of u128.
/// Note that on Ethereum, Uniswap uses the geometric mean of all provided
/// input amounts, and Balancer uses 100 * 10 ^ 18.
pub const INITIAL_SWAP_POOL_AMOUNT: u64 = 1_000_000_000;
pub const INITIAL_SWAP_POOL_AMOUNT: u128 = 1_000_000_000;
/// Curve types supported by the token-swap program.
#[repr(C)]
@ -141,20 +141,20 @@ pub trait CurveCalculator: Debug + DynPack {
/// of source token.
fn swap(
&self,
source_amount: u64,
swap_source_amount: u64,
swap_destination_amount: u64,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
) -> Option<SwapResult>;
/// Calculate the withdraw fee in pool tokens
/// Default implementation assumes no fee
fn owner_withdraw_fee(&self, _pool_tokens: u64) -> Option<u64> {
fn owner_withdraw_fee(&self, _pool_tokens: u128) -> Option<u128> {
Some(0)
}
/// Calculate the trading fee in trading tokens
/// Default implementation assumes no fee
fn trading_fee(&self, _trading_tokens: u64) -> Option<u64> {
fn trading_fee(&self, _trading_tokens: u128) -> Option<u128> {
Some(0)
}
@ -165,11 +165,11 @@ pub trait CurveCalculator: Debug + DynPack {
/// value.
fn owner_fee_to_pool_tokens(
&self,
owner_fee: u64,
trading_token_amount: u64,
pool_supply: u64,
tokens_in_pool: u64,
) -> Option<u64> {
owner_fee: u128,
trading_token_amount: u128,
pool_supply: u128,
tokens_in_pool: u128,
) -> Option<u128> {
// Get the trading fee incurred if the owner fee is swapped for the other side
let trade_fee = self.trading_fee(owner_fee)?;
let owner_fee = owner_fee.checked_sub(trade_fee)?;
@ -181,7 +181,7 @@ pub trait CurveCalculator: Debug + DynPack {
/// Get the supply for a new pool
/// The default implementation is a Balancer-style fixed initial supply
fn new_pool_supply(&self) -> u64 {
fn new_pool_supply(&self) -> u128 {
INITIAL_SWAP_POOL_AMOUNT
}
@ -191,10 +191,10 @@ pub trait CurveCalculator: Debug + DynPack {
/// trading tokens correspond to a certain number of pool tokens
fn pool_tokens_to_trading_tokens(
&self,
pool_tokens: u64,
pool_token_supply: u64,
total_trading_tokens: u64,
) -> Option<u64> {
pool_tokens: u128,
pool_token_supply: u128,
total_trading_tokens: u128,
) -> Option<u128> {
pool_tokens
.checked_mul(total_trading_tokens)?
.checked_div(pool_token_supply)
@ -205,19 +205,19 @@ pub trait CurveCalculator: Debug + DynPack {
/// Encodes all results of swapping from a source token to a destination token
pub struct SwapResult {
/// New amount of source token
pub new_source_amount: u64,
pub new_source_amount: u128,
/// New amount of destination token
pub new_destination_amount: u64,
pub new_destination_amount: u128,
/// Amount of destination token swapped
pub amount_swapped: u64,
pub amount_swapped: u128,
/// Amount of source tokens going to pool holders
pub trade_fee: u64,
pub trade_fee: u128,
/// Amount of source tokens going to owner
pub owner_fee: u64,
pub owner_fee: u128,
}
/// Helper function for mapping to SwapError::CalculationFailure
fn map_zero_to_none(x: u64) -> Option<u64> {
fn map_zero_to_none(x: u128) -> Option<u128> {
if x == 0 {
None
} else {
@ -242,7 +242,7 @@ pub struct FlatCurve {
pub owner_withdraw_fee_denominator: u64,
}
fn calculate_fee(token_amount: u64, fee_numerator: u64, fee_denominator: u64) -> Option<u64> {
fn calculate_fee(token_amount: u128, fee_numerator: u128, fee_denominator: u128) -> Option<u128> {
if fee_numerator == 0 {
Some(0)
} else {
@ -261,20 +261,20 @@ impl CurveCalculator for FlatCurve {
/// Flat curve swap always returns 1:1 (minus fee)
fn swap(
&self,
source_amount: u64,
swap_source_amount: u64,
swap_destination_amount: u64,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
) -> Option<SwapResult> {
// debit the fee to calculate the amount swapped
let trade_fee = calculate_fee(
source_amount,
self.trade_fee_numerator,
self.trade_fee_denominator,
u128::try_from(self.trade_fee_numerator).ok()?,
u128::try_from(self.trade_fee_denominator).ok()?,
)?;
let owner_fee = calculate_fee(
source_amount,
self.owner_trade_fee_numerator,
self.owner_trade_fee_denominator,
u128::try_from(self.owner_trade_fee_numerator).ok()?,
u128::try_from(self.owner_trade_fee_denominator).ok()?,
)?;
let amount_swapped = source_amount
@ -294,11 +294,11 @@ impl CurveCalculator for FlatCurve {
}
/// Calculate the withdraw fee in pool tokens
fn owner_withdraw_fee(&self, pool_tokens: u64) -> Option<u64> {
fn owner_withdraw_fee(&self, pool_tokens: u128) -> Option<u128> {
calculate_fee(
pool_tokens,
self.owner_withdraw_fee_numerator,
self.owner_withdraw_fee_denominator,
u128::try_from(self.owner_withdraw_fee_numerator).ok()?,
u128::try_from(self.owner_withdraw_fee_denominator).ok()?,
)
}
}
@ -379,16 +379,16 @@ impl CurveCalculator for ConstantProductCurve {
/// Constant product swap ensures x * y = constant
fn swap(
&self,
source_amount: u64,
swap_source_amount: u64,
swap_destination_amount: u64,
source_amount: u128,
swap_source_amount: u128,
swap_destination_amount: u128,
) -> Option<SwapResult> {
// debit the fee to calculate the amount swapped
let trade_fee = self.trading_fee(source_amount)?;
let owner_fee = calculate_fee(
source_amount,
self.owner_trade_fee_numerator,
self.owner_trade_fee_denominator,
u128::try_from(self.owner_trade_fee_numerator).ok()?,
u128::try_from(self.owner_trade_fee_denominator).ok()?,
)?;
let invariant = swap_source_amount.checked_mul(swap_destination_amount)?;
@ -412,20 +412,20 @@ impl CurveCalculator for ConstantProductCurve {
}
/// Calculate the withdraw fee in pool tokens
fn owner_withdraw_fee(&self, pool_tokens: u64) -> Option<u64> {
fn owner_withdraw_fee(&self, pool_tokens: u128) -> Option<u128> {
calculate_fee(
pool_tokens,
self.owner_withdraw_fee_numerator,
self.owner_withdraw_fee_denominator,
u128::try_from(self.owner_withdraw_fee_numerator).ok()?,
u128::try_from(self.owner_withdraw_fee_denominator).ok()?,
)
}
/// Calculate the trading fee in trading tokens
fn trading_fee(&self, trading_tokens: u64) -> Option<u64> {
fn trading_fee(&self, trading_tokens: u128) -> Option<u128> {
calculate_fee(
trading_tokens,
self.trade_fee_numerator,
self.trade_fee_denominator,
u128::try_from(self.trade_fee_numerator).ok()?,
u128::try_from(self.trade_fee_denominator).ok()?,
)
}
}
@ -508,7 +508,7 @@ mod tests {
assert_eq!(calculator.new_pool_supply(), INITIAL_SWAP_POOL_AMOUNT);
}
fn check_pool_token_rate(token_a: u64, deposit: u64, supply: u64, expected: Option<u64>) {
fn check_pool_token_rate(token_a: u128, deposit: u128, supply: u128, expected: Option<u128>) {
let trade_fee_numerator = 0;
let trade_fee_denominator = 1;
let owner_trade_fee_numerator = 0;
@ -535,21 +535,21 @@ mod tests {
check_pool_token_rate(10, 5, 10, Some(5));
check_pool_token_rate(5, 5, 10, Some(2));
check_pool_token_rate(5, 5, 10, Some(2));
check_pool_token_rate(u64::MAX, 5, 10, None);
check_pool_token_rate(u128::MAX, 5, 10, None);
}
#[test]
fn constant_product_swap_calculation_trade_fee() {
// calculation on https://github.com/solana-labs/solana-program-library/issues/341
let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000;
let trade_fee_numerator: u64 = 1;
let trade_fee_denominator: u64 = 100;
let owner_trade_fee_numerator: u64 = 0;
let owner_trade_fee_denominator: u64 = 0;
let owner_withdraw_fee_numerator: u64 = 0;
let owner_withdraw_fee_denominator: u64 = 0;
let source_amount: u64 = 100;
let swap_source_amount = 1000;
let swap_destination_amount = 50000;
let trade_fee_numerator = 1;
let trade_fee_denominator = 100;
let owner_trade_fee_numerator = 0;
let owner_trade_fee_denominator = 0;
let owner_withdraw_fee_numerator = 0;
let owner_withdraw_fee_denominator = 0;
let source_amount = 100;
let curve = ConstantProductCurve {
trade_fee_numerator,
trade_fee_denominator,
@ -571,15 +571,15 @@ mod tests {
#[test]
fn constant_product_swap_calculation_owner_fee() {
// calculation on https://github.com/solana-labs/solana-program-library/issues/341
let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000;
let trade_fee_numerator: u64 = 0;
let trade_fee_denominator: u64 = 0;
let owner_trade_fee_numerator: u64 = 1;
let owner_trade_fee_denominator: u64 = 100;
let owner_withdraw_fee_numerator: u64 = 0;
let owner_withdraw_fee_denominator: u64 = 0;
let source_amount: u64 = 100;
let swap_source_amount = 1000;
let swap_destination_amount = 50000;
let trade_fee_numerator = 0;
let trade_fee_denominator = 0;
let owner_trade_fee_numerator = 1;
let owner_trade_fee_denominator = 100;
let owner_withdraw_fee_numerator = 0;
let owner_withdraw_fee_denominator = 0;
let source_amount: u128 = 100;
let curve = ConstantProductCurve {
trade_fee_numerator,
trade_fee_denominator,
@ -600,9 +600,9 @@ mod tests {
#[test]
fn constant_product_swap_no_fee() {
let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000;
let source_amount: u64 = 100;
let swap_source_amount: u128 = 1000;
let swap_destination_amount: u128 = 50000;
let source_amount: u128 = 100;
let curve = ConstantProductCurve::default();
let result = curve
.swap(source_amount, swap_source_amount, swap_destination_amount)
@ -614,15 +614,15 @@ mod tests {
#[test]
fn flat_swap_calculation() {
let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000;
let trade_fee_numerator: u64 = 1;
let trade_fee_denominator: u64 = 100;
let owner_trade_fee_numerator: u64 = 2;
let owner_trade_fee_denominator: u64 = 100;
let owner_withdraw_fee_numerator: u64 = 2;
let owner_withdraw_fee_denominator: u64 = 100;
let source_amount: u64 = 100;
let swap_source_amount = 1000;
let swap_destination_amount = 50000;
let trade_fee_numerator = 1;
let trade_fee_denominator = 100;
let owner_trade_fee_numerator = 2;
let owner_trade_fee_denominator = 100;
let owner_withdraw_fee_numerator = 2;
let owner_withdraw_fee_denominator = 100;
let source_amount: u128 = 100;
let curve = FlatCurve {
trade_fee_numerator,
trade_fee_denominator,

View File

@ -73,6 +73,9 @@ pub enum SwapError {
/// The fee calculation failed due to overflow, underflow, or unexpected 0
#[error("Fee calculation failed due to overflow, underflow, or unexpected 0")]
FeeCalculationFailure,
/// ConversionFailure
#[error("Conversion to u64 failed with an overflow or underflow")]
ConversionFailure,
}
impl From<SwapError> for ProgramError {
fn from(e: SwapError) -> Self {

View File

@ -13,6 +13,7 @@ use solana_program::{
program_pack::Pack,
pubkey::Pubkey,
};
use std::convert::TryInto;
/// Hardcode the number of token types in a pool, used to calculate the
/// equivalent pool tokens for the owner trading fee.
@ -211,7 +212,7 @@ impl Processor {
destination_info.clone(),
authority_info.clone(),
nonce,
initial_amount,
to_u64(initial_amount)?,
)?;
let obj = SwapInfo {
@ -286,9 +287,13 @@ impl Processor {
let result = token_swap
.swap_curve
.calculator
.swap(amount_in, source_account.amount, dest_account.amount)
.swap(
to_u128(amount_in)?,
to_u128(source_account.amount)?,
to_u128(dest_account.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
if result.amount_swapped < minimum_amount_out {
if result.amount_swapped < to_u128(minimum_amount_out)? {
return Err(SwapError::ExceededSlippage.into());
}
Self::token_transfer(
@ -307,7 +312,7 @@ impl Processor {
destination_info.clone(),
authority_info.clone(),
token_swap.nonce,
result.amount_swapped,
to_u64(result.amount_swapped)?,
)?;
// mint pool tokens equivalent to the owner fee
@ -317,9 +322,9 @@ impl Processor {
.calculator
.owner_fee_to_pool_tokens(
result.owner_fee,
source_account.amount,
pool_mint.supply,
TOKENS_IN_POOL,
to_u128(source_account.amount)?,
to_u128(pool_mint.supply)?,
to_u128(TOKENS_IN_POOL)?,
)
.ok_or(SwapError::FeeCalculationFailure)?;
if pool_token_amount > 0 {
@ -330,7 +335,7 @@ impl Processor {
pool_fee_account_info.clone(),
authority_info.clone(),
token_swap.nonce,
pool_token_amount,
to_u64(pool_token_amount)?,
)?;
}
Ok(())
@ -378,19 +383,29 @@ impl Processor {
let token_a = Self::unpack_token_account(&token_a_info.data.borrow())?;
let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?;
let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?;
let pool_token_amount = to_u128(pool_token_amount)?;
let pool_mint_supply = to_u128(pool_mint.supply)?;
let calculator = token_swap.swap_curve.calculator;
let a_amount = calculator
.pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_a.amount)
.pool_tokens_to_trading_tokens(
pool_token_amount,
pool_mint_supply,
to_u128(token_a.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
if a_amount > maximum_token_a_amount {
if a_amount > to_u128(maximum_token_a_amount)? {
return Err(SwapError::ExceededSlippage.into());
}
let b_amount = calculator
.pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_b.amount)
.pool_tokens_to_trading_tokens(
pool_token_amount,
pool_mint_supply,
to_u128(token_b.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
if b_amount > maximum_token_b_amount {
if b_amount > to_u128(maximum_token_b_amount)? {
return Err(SwapError::ExceededSlippage.into());
}
@ -401,7 +416,7 @@ impl Processor {
token_a_info.clone(),
authority_info.clone(),
token_swap.nonce,
a_amount,
to_u64(a_amount)?,
)?;
Self::token_transfer(
swap_info.key,
@ -410,7 +425,7 @@ impl Processor {
token_b_info.clone(),
authority_info.clone(),
token_swap.nonce,
b_amount,
to_u64(b_amount)?,
)?;
Self::token_mint_to(
swap_info.key,
@ -419,7 +434,7 @@ impl Processor {
dest_info.clone(),
authority_info.clone(),
token_swap.nonce,
pool_token_amount,
to_u64(pool_token_amount)?,
)?;
Ok(())
@ -474,27 +489,36 @@ impl Processor {
let calculator = token_swap.swap_curve.calculator;
let withdraw_fee = if *pool_fee_account_info.key == *source_info.key {
let withdraw_fee: u128 = if *pool_fee_account_info.key == *source_info.key {
// withdrawing from the fee account, don't assess withdraw fee
0
} else {
calculator
.owner_withdraw_fee(pool_token_amount)
.owner_withdraw_fee(to_u128(pool_token_amount)?)
.ok_or(SwapError::FeeCalculationFailure)?
};
let pool_token_amount = pool_token_amount
let pool_token_amount = to_u128(pool_token_amount)?
.checked_sub(withdraw_fee)
.ok_or(SwapError::CalculationFailure)?;
let a_amount = calculator
.pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_a.amount)
.pool_tokens_to_trading_tokens(
pool_token_amount,
to_u128(pool_mint.supply)?,
to_u128(token_a.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
if a_amount < minimum_token_a_amount {
if a_amount < to_u128(minimum_token_a_amount)? {
return Err(SwapError::ExceededSlippage.into());
}
let b_amount = calculator
.pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_b.amount)
.pool_tokens_to_trading_tokens(
pool_token_amount,
to_u128(pool_mint.supply)?,
to_u128(token_b.amount)?,
)
.ok_or(SwapError::ZeroTradingTokens)?;
let b_amount = to_u64(b_amount)?;
if b_amount < minimum_token_b_amount {
return Err(SwapError::ExceededSlippage.into());
}
@ -506,7 +530,7 @@ impl Processor {
dest_token_a_info.clone(),
authority_info.clone(),
token_swap.nonce,
a_amount,
to_u64(a_amount)?,
)?;
Self::token_transfer(
swap_info.key,
@ -525,7 +549,7 @@ impl Processor {
pool_fee_account_info.clone(),
authority_info.clone(),
token_swap.nonce,
withdraw_fee,
to_u64(withdraw_fee)?,
)?;
}
Self::token_burn(
@ -535,7 +559,7 @@ impl Processor {
pool_mint_info.clone(),
authority_info.clone(),
token_swap.nonce,
pool_token_amount,
to_u64(pool_token_amount)?,
)?;
Ok(())
}
@ -637,10 +661,19 @@ impl PrintProgramError for SwapError {
SwapError::FeeCalculationFailure => info!(
"Error: The fee calculation failed due to overflow, underflow, or unexpected 0"
),
SwapError::ConversionFailure => info!("Error: Conversion to or from u64 failed."),
}
}
}
fn to_u128(val: u64) -> Result<u128, SwapError> {
val.try_into().map_err(|_| SwapError::ConversionFailure)
}
fn to_u64(val: u128) -> Result<u64, SwapError> {
val.try_into().map_err(|_| SwapError::ConversionFailure)
}
#[cfg(test)]
mod tests {
use super::*;
@ -1858,7 +1891,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -1893,7 +1926,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -1927,7 +1960,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -1960,7 +1993,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -1987,7 +2020,7 @@ mod tests {
&mut token_a_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2022,7 +2055,7 @@ mod tests {
&mut token_b_account,
&wrong_token_key,
&mut wrong_token_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2053,7 +2086,7 @@ mod tests {
&accounts.token_b_key,
&accounts.pool_mint_key,
&pool_key,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2098,7 +2131,7 @@ mod tests {
&accounts.token_b_key,
&accounts.pool_mint_key,
&pool_key,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2146,7 +2179,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2172,7 +2205,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2209,7 +2242,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2267,7 +2300,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a / 10,
deposit_b,
)
@ -2283,7 +2316,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b / 10,
)
@ -2315,7 +2348,7 @@ mod tests {
&mut swap_token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2340,7 +2373,7 @@ mod tests {
&mut token_b_account,
&pool_key,
&mut pool_account,
pool_amount,
pool_amount.try_into().unwrap(),
deposit_a,
deposit_b,
)
@ -2422,7 +2455,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2457,7 +2490,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2479,7 +2512,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
withdraw_amount / 2,
to_u64(withdraw_amount).unwrap() / 2u64,
);
assert_eq!(
Err(TokenError::InsufficientFunds.into()),
@ -2491,7 +2524,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount / 2,
minimum_b_amount / 2,
)
@ -2512,7 +2545,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
);
assert_eq!(
Err(TokenError::MintMismatch.into()),
@ -2524,7 +2557,7 @@ mod tests {
&mut token_b_account,
&token_a_key,
&mut token_a_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2545,7 +2578,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
);
let (
wrong_token_a_key,
@ -2557,9 +2590,9 @@ mod tests {
) = accounts.setup_token_accounts(
&user_key,
&withdrawer_key,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
initial_b,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
);
assert_eq!(
Err(TokenError::MintMismatch.into()),
@ -2571,7 +2604,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2592,7 +2625,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
);
let (
_token_a_key,
@ -2606,7 +2639,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
);
let old_pool_fee_account = accounts.pool_fee_account;
let old_pool_fee_key = accounts.pool_fee_key;
@ -2622,7 +2655,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
),
@ -2640,7 +2673,13 @@ mod tests {
mut token_b_account,
pool_key,
mut pool_account,
) = accounts.setup_token_accounts(&user_key, &withdrawer_key, 0, 0, withdraw_amount);
) = accounts.setup_token_accounts(
&user_key,
&withdrawer_key,
0,
0,
withdraw_amount.try_into().unwrap(),
);
assert_eq!(
Err(TokenError::OwnerMismatch.into()),
do_process_instruction(
@ -2656,7 +2695,7 @@ mod tests {
&accounts.token_b_key,
&token_a_key,
&token_b_key,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2691,7 +2730,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
);
let wrong_key = Pubkey::new_unique();
assert_eq!(
@ -2709,7 +2748,7 @@ mod tests {
&accounts.token_b_key,
&token_a_key,
&token_b_key,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2744,7 +2783,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
initial_pool,
initial_pool.try_into().unwrap(),
);
let old_a_key = accounts.token_a_key;
@ -2764,7 +2803,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2790,7 +2829,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2814,7 +2853,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
initial_pool,
initial_pool.try_into().unwrap(),
);
let (pool_mint_key, pool_mint_account) =
create_mint(&TOKEN_PROGRAM_ID, &accounts.authority_key, None);
@ -2833,7 +2872,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2857,7 +2896,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
initial_pool,
initial_pool.try_into().unwrap(),
);
assert_eq!(
Err(SwapError::ZeroTradingTokens.into()),
@ -2890,7 +2929,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
initial_pool,
initial_pool.try_into().unwrap(),
);
// minimum A amount out too high
assert_eq!(
@ -2903,7 +2942,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount * 10,
minimum_b_amount,
)
@ -2919,7 +2958,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount * 10,
)
@ -2940,7 +2979,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
initial_pool,
initial_pool.try_into().unwrap(),
);
let swap_token_a_key = accounts.token_a_key;
let mut swap_token_a_account = accounts.get_token_account(&swap_token_a_key).clone();
@ -2954,7 +2993,7 @@ mod tests {
&mut swap_token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2971,7 +3010,7 @@ mod tests {
&mut token_a_account,
&swap_token_b_key,
&mut swap_token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -2992,7 +3031,7 @@ mod tests {
&withdrawer_key,
initial_a,
initial_b,
initial_pool,
initial_pool.try_into().unwrap(),
);
accounts
@ -3004,7 +3043,7 @@ mod tests {
&mut token_a_account,
&token_b_key,
&mut token_b_account,
withdraw_amount,
withdraw_amount.try_into().unwrap(),
minimum_a_amount,
minimum_b_amount,
)
@ -3025,30 +3064,39 @@ mod tests {
.calculator
.pool_tokens_to_trading_tokens(
withdraw_amount - withdraw_fee,
pool_mint.supply,
swap_token_a.amount,
pool_mint.supply.try_into().unwrap(),
swap_token_a.amount.try_into().unwrap(),
)
.unwrap();
assert_eq!(swap_token_a.amount, token_a_amount - withdrawn_a);
assert_eq!(
swap_token_a.amount,
token_a_amount - to_u64(withdrawn_a).unwrap()
);
let withdrawn_b = accounts
.swap_curve
.calculator
.pool_tokens_to_trading_tokens(
withdraw_amount - withdraw_fee,
pool_mint.supply,
swap_token_b.amount,
pool_mint.supply.try_into().unwrap(),
swap_token_b.amount.try_into().unwrap(),
)
.unwrap();
assert_eq!(swap_token_b.amount, token_b_amount - withdrawn_b);
assert_eq!(
swap_token_b.amount,
token_b_amount - to_u64(withdrawn_b).unwrap()
);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a + withdrawn_a);
assert_eq!(token_a.amount, initial_a + to_u64(withdrawn_a).unwrap());
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(token_b.amount, initial_b + withdrawn_b);
assert_eq!(token_b.amount, initial_b + to_u64(withdrawn_b).unwrap());
let pool_account = Processor::unpack_token_account(&pool_account.data).unwrap();
assert_eq!(pool_account.amount, initial_pool - withdraw_amount);
assert_eq!(
pool_account.amount,
to_u64(initial_pool - withdraw_amount).unwrap()
);
let fee_account =
Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap();
assert_eq!(fee_account.amount, withdraw_fee);
assert_eq!(fee_account.amount, withdraw_fee.try_into().unwrap());
}
// correct withdrawal from fee account
@ -3091,32 +3139,32 @@ mod tests {
.swap_curve
.calculator
.pool_tokens_to_trading_tokens(
pool_fee_amount,
pool_mint.supply,
swap_token_a.amount,
pool_fee_amount.try_into().unwrap(),
pool_mint.supply.try_into().unwrap(),
swap_token_a.amount.try_into().unwrap(),
)
.unwrap();
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, withdrawn_a);
assert_eq!(token_a.amount, withdrawn_a.try_into().unwrap());
let withdrawn_b = accounts
.swap_curve
.calculator
.pool_tokens_to_trading_tokens(
pool_fee_amount,
pool_mint.supply,
swap_token_b.amount,
pool_fee_amount.try_into().unwrap(),
pool_mint.supply.try_into().unwrap(),
swap_token_b.amount.try_into().unwrap(),
)
.unwrap();
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(token_b.amount, withdrawn_b);
assert_eq!(token_b.amount, withdrawn_b.try_into().unwrap());
}
}
fn check_valid_swap_curve(curve_type: CurveType, calculator: Box<dyn CurveCalculator>) {
let user_key = Pubkey::new_unique();
let swapper_key = Pubkey::new_unique();
let token_a_amount = 1000;
let token_b_amount = 5000;
let token_a_amount = 10_000_000_000u64;
let token_b_amount = 50_000_000_000u64;
let swap_curve = SwapCurve {
curve_type,
@ -3165,32 +3213,45 @@ mod tests {
let results = swap_curve
.calculator
.swap(a_to_b_amount, token_a_amount, token_b_amount)
.swap(
a_to_b_amount.try_into().unwrap(),
token_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(),
)
.unwrap();
let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
let token_a_amount = swap_token_a.amount;
assert_eq!(token_a_amount, results.new_source_amount);
assert_eq!(
token_a_amount,
results.new_source_amount.try_into().unwrap()
);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a - a_to_b_amount);
let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let token_b_amount = swap_token_b.amount;
assert_eq!(token_b_amount, results.new_destination_amount);
assert_eq!(
token_b_amount,
results.new_destination_amount.try_into().unwrap()
);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(token_b.amount, initial_b + results.amount_swapped);
assert_eq!(
token_b.amount,
initial_b + to_u64(results.amount_swapped).unwrap()
);
let first_fee = swap_curve
.calculator
.owner_fee_to_pool_tokens(
results.owner_fee,
token_a_amount,
initial_supply,
TOKENS_IN_POOL,
token_a_amount.try_into().unwrap(),
initial_supply.try_into().unwrap(),
TOKENS_IN_POOL.try_into().unwrap(),
)
.unwrap();
let fee_account = Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap();
assert_eq!(fee_account.amount, first_fee);
assert_eq!(fee_account.amount, first_fee.try_into().unwrap());
let first_swap_amount = results.amount_swapped;
@ -3216,38 +3277,48 @@ mod tests {
let results = swap_curve
.calculator
.swap(b_to_a_amount, token_b_amount, token_a_amount)
.swap(
b_to_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(),
token_a_amount.try_into().unwrap(),
)
.unwrap();
let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
let token_a_amount = swap_token_a.amount;
assert_eq!(token_a_amount, results.new_destination_amount);
assert_eq!(
token_a_amount,
results.new_destination_amount.try_into().unwrap()
);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(
token_a.amount,
initial_a - a_to_b_amount + results.amount_swapped
initial_a - a_to_b_amount + to_u64(results.amount_swapped).unwrap()
);
let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let token_b_amount = swap_token_b.amount;
assert_eq!(token_b_amount, results.new_source_amount);
assert_eq!(
token_b_amount,
results.new_source_amount.try_into().unwrap()
);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(
token_b.amount,
initial_b + first_swap_amount - b_to_a_amount
initial_b + to_u64(first_swap_amount).unwrap() - b_to_a_amount
);
let second_fee = swap_curve
.calculator
.owner_fee_to_pool_tokens(
results.owner_fee,
token_b_amount,
initial_supply,
TOKENS_IN_POOL,
token_b_amount.try_into().unwrap(),
initial_supply.try_into().unwrap(),
TOKENS_IN_POOL.try_into().unwrap(),
)
.unwrap();
let fee_account = Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap();
assert_eq!(fee_account.amount, first_fee + second_fee);
assert_eq!(fee_account.amount, to_u64(first_fee + second_fee).unwrap());
}
#[test]