Step 2: Calculate the expected transfer fee

This commit is contained in:
Jon Cinque 2022-08-24 19:16:29 +02:00
parent 9abaf9c90d
commit bb468d2ce7
1 changed files with 173 additions and 47 deletions

View File

@ -17,6 +17,7 @@ use crate::{
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use solana_program::{ use solana_program::{
account_info::{next_account_info, AccountInfo}, account_info::{next_account_info, AccountInfo},
clock::Clock,
decode_error::DecodeError, decode_error::DecodeError,
entrypoint::ProgramResult, entrypoint::ProgramResult,
instruction::Instruction, instruction::Instruction,
@ -25,11 +26,12 @@ use solana_program::{
program_error::{PrintProgramError, ProgramError}, program_error::{PrintProgramError, ProgramError},
program_option::COption, program_option::COption,
pubkey::Pubkey, pubkey::Pubkey,
sysvar::Sysvar,
}; };
use spl_token_2022::{ use spl_token_2022::{
check_spl_token_program_account, check_spl_token_program_account,
error::TokenError, error::TokenError,
extension::StateWithExtensions, extension::{transfer_fee::TransferFeeConfig, StateWithExtensions},
state::{Account, Mint}, state::{Account, Mint},
}; };
use std::{convert::TryInto, error::Error}; use std::{convert::TryInto, error::Error};
@ -69,6 +71,19 @@ impl Processor {
} }
} }
/// Unpacks a spl_token `Mint` with extension data
pub fn unpack_mint_with_extensions<'a>(
account_data: &'a [u8],
owner: &Pubkey,
token_program_id: &Pubkey,
) -> Result<StateWithExtensions<'a, Mint>, SwapError> {
if owner != token_program_id && check_spl_token_program_account(owner).is_err() {
Err(SwapError::IncorrectTokenProgramId)
} else {
StateWithExtensions::<Mint>::unpack(account_data).map_err(|_| SwapError::ExpectedMint)
}
}
/// Calculates the authority id by generating a program address. /// Calculates the authority id by generating a program address.
pub fn authority_id( pub fn authority_id(
program_id: &Pubkey, program_id: &Pubkey,
@ -412,6 +427,27 @@ impl Processor {
Self::unpack_token_account(swap_destination_info, token_swap.token_program_id())?; Self::unpack_token_account(swap_destination_info, token_swap.token_program_id())?;
let pool_mint = Self::unpack_mint(pool_mint_info, token_swap.token_program_id())?; let pool_mint = Self::unpack_mint(pool_mint_info, token_swap.token_program_id())?;
// Take transfer fees into account for actual amount transferred in
let actual_amount_in = {
let source_mint_data = source_token_mint_info.data.borrow();
let source_mint = Self::unpack_mint_with_extensions(
&source_mint_data,
source_token_mint_info.owner,
token_swap.token_program_id(),
)?;
if let Ok(transfer_fee_config) = source_mint.get_extension::<TransferFeeConfig>() {
amount_in.saturating_sub(
transfer_fee_config
.calculate_epoch_fee(Clock::get()?.epoch, amount_in)
.ok_or(SwapError::FeeCalculationFailure)?,
)
} else {
amount_in
}
};
// Calculate the trade amounts
let trade_direction = if *swap_source_info.key == *token_swap.token_a_account() { let trade_direction = if *swap_source_info.key == *token_swap.token_a_account() {
TradeDirection::AtoB TradeDirection::AtoB
} else { } else {
@ -420,16 +456,61 @@ impl Processor {
let result = token_swap let result = token_swap
.swap_curve() .swap_curve()
.swap( .swap(
to_u128(amount_in)?, to_u128(actual_amount_in)?,
to_u128(source_account.amount)?, to_u128(source_account.amount)?,
to_u128(dest_account.amount)?, to_u128(dest_account.amount)?,
trade_direction, trade_direction,
token_swap.fees(), token_swap.fees(),
) )
.ok_or(SwapError::ZeroTradingTokens)?; .ok_or(SwapError::ZeroTradingTokens)?;
if result.destination_amount_swapped < to_u128(minimum_amount_out)? {
return Err(SwapError::ExceededSlippage.into()); // Re-calculate the source amount swapped based on what the curve says
} let (source_transfer_amount, source_mint_decimals) = {
let source_amount_swapped = to_u64(result.source_amount_swapped)?;
let source_mint_data = source_token_mint_info.data.borrow();
let source_mint = Self::unpack_mint_with_extensions(
&source_mint_data,
source_token_mint_info.owner,
token_swap.token_program_id(),
)?;
let amount =
if let Ok(transfer_fee_config) = source_mint.get_extension::<TransferFeeConfig>() {
source_amount_swapped.saturating_add(
transfer_fee_config
.calculate_inverse_epoch_fee(Clock::get()?.epoch, source_amount_swapped)
.ok_or(SwapError::FeeCalculationFailure)?,
)
} else {
source_amount_swapped
};
(amount, source_mint.base.decimals)
};
let (destination_transfer_amount, destination_mint_decimals) = {
let destination_mint_data = destination_token_mint_info.data.borrow();
let destination_mint = Self::unpack_mint_with_extensions(
&destination_mint_data,
source_token_mint_info.owner,
token_swap.token_program_id(),
)?;
let amount_out = to_u64(result.destination_amount_swapped)?;
let amount_received = if let Ok(transfer_fee_config) =
destination_mint.get_extension::<TransferFeeConfig>()
{
amount_out.saturating_sub(
transfer_fee_config
.calculate_epoch_fee(Clock::get()?.epoch, amount_out)
.ok_or(SwapError::FeeCalculationFailure)?,
)
} else {
amount_out
};
if amount_received < minimum_amount_out {
return Err(SwapError::ExceededSlippage.into());
}
(amount_out, destination_mint.base.decimals)
};
let (swap_token_a_amount, swap_token_b_amount) = match trade_direction { let (swap_token_a_amount, swap_token_b_amount) = match trade_direction {
TradeDirection::AtoB => ( TradeDirection::AtoB => (
@ -450,8 +531,8 @@ impl Processor {
swap_source_info.clone(), swap_source_info.clone(),
user_transfer_authority_info.clone(), user_transfer_authority_info.clone(),
token_swap.bump_seed(), token_swap.bump_seed(),
to_u64(result.source_amount_swapped)?, source_transfer_amount,
Self::unpack_mint(source_token_mint_info, token_swap.token_program_id())?.decimals, source_mint_decimals,
)?; )?;
let mut pool_token_amount = token_swap let mut pool_token_amount = token_swap
@ -519,8 +600,8 @@ impl Processor {
destination_info.clone(), destination_info.clone(),
authority_info.clone(), authority_info.clone(),
token_swap.bump_seed(), token_swap.bump_seed(),
to_u64(result.destination_amount_swapped)?, destination_transfer_amount,
Self::unpack_mint(destination_token_mint_info, token_swap.token_program_id())?.decimals, destination_mint_decimals,
)?; )?;
Ok(()) Ok(())
@ -1194,7 +1275,10 @@ mod tests {
}; };
use spl_token_2022::{ use spl_token_2022::{
error::TokenError, error::TokenError,
extension::{transfer_fee::instruction::initialize_transfer_fee_config, ExtensionType}, extension::{
transfer_fee::{instruction::initialize_transfer_fee_config, TransferFee},
ExtensionType,
},
instruction::{ instruction::{
approve, close_account, freeze_account, initialize_account, initialize_immutable_owner, approve, close_account, freeze_account, initialize_account, initialize_immutable_owner,
initialize_mint, initialize_mint_close_authority, mint_to, revoke, set_authority, initialize_mint, initialize_mint_close_authority, mint_to, revoke, set_authority,
@ -1277,17 +1361,11 @@ mod tests {
}); });
} }
#[derive(Default)]
struct TransferFees {
transfer_fee_basis_points: u16,
maximum_transfer_fee: u64,
}
#[derive(Default)] #[derive(Default)]
struct SwapTransferFees { struct SwapTransferFees {
pool_token: TransferFees, pool_token: TransferFee,
token_a: TransferFees, token_a: TransferFee,
token_b: TransferFees, token_b: TransferFee,
} }
struct SwapAccountInfo { struct SwapAccountInfo {
@ -1335,8 +1413,12 @@ mod tests {
let (authority_key, bump_seed) = let (authority_key, bump_seed) =
Pubkey::find_program_address(&[&swap_key.to_bytes()[..]], &SWAP_PROGRAM_ID); Pubkey::find_program_address(&[&swap_key.to_bytes()[..]], &SWAP_PROGRAM_ID);
let (pool_mint_key, mut pool_mint_account) = let (pool_mint_key, mut pool_mint_account) = create_mint(
create_mint(pool_token_program_id, &authority_key, None, &transfer_fees.pool_token); pool_token_program_id,
&authority_key,
None,
&transfer_fees.pool_token,
);
let (pool_token_key, pool_token_account) = mint_token( let (pool_token_key, pool_token_account) = mint_token(
pool_token_program_id, pool_token_program_id,
&pool_mint_key, &pool_mint_key,
@ -2101,7 +2183,7 @@ mod tests {
program_id: &Pubkey, program_id: &Pubkey,
authority_key: &Pubkey, authority_key: &Pubkey,
freeze_authority: Option<&Pubkey>, freeze_authority: Option<&Pubkey>,
fees: &TransferFees, fees: &TransferFee,
) -> (Pubkey, SolanaAccount) { ) -> (Pubkey, SolanaAccount) {
let mint_key = Pubkey::new_unique(); let mint_key = Pubkey::new_unique();
let space = if *program_id == spl_token_2022::id() { let space = if *program_id == spl_token_2022::id() {
@ -2128,8 +2210,8 @@ mod tests {
&mint_key, &mint_key,
freeze_authority, freeze_authority,
freeze_authority, freeze_authority,
fees.transfer_fee_basis_points, fees.transfer_fee_basis_points.into(),
fees.maximum_transfer_fee, fees.maximum_fee.into(),
) )
.unwrap(), .unwrap(),
vec![&mut mint_account], vec![&mut mint_account],
@ -2430,8 +2512,12 @@ mod tests {
// pool mint authority is not swap authority // pool mint authority is not swap authority
{ {
let (_pool_mint_key, pool_mint_account) = let (_pool_mint_key, pool_mint_account) = create_mint(
create_mint(&pool_token_program_id, &user_key, None, &TransferFees::default()); &pool_token_program_id,
&user_key,
None,
&TransferFee::default(),
);
let old_mint = accounts.pool_mint_account; let old_mint = accounts.pool_mint_account;
accounts.pool_mint_account = pool_mint_account; accounts.pool_mint_account = pool_mint_account;
assert_eq!( assert_eq!(
@ -2446,7 +2532,8 @@ mod tests {
let (_pool_mint_key, pool_mint_account) = create_mint( let (_pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id, &pool_token_program_id,
&accounts.authority_key, &accounts.authority_key,
Some(&user_key), &TransferFees::default(), Some(&user_key),
&TransferFee::default(),
); );
let old_mint = accounts.pool_mint_account; let old_mint = accounts.pool_mint_account;
accounts.pool_mint_account = pool_mint_account; accounts.pool_mint_account = pool_mint_account;
@ -2540,8 +2627,12 @@ mod tests {
let old_mint = accounts.pool_mint_account; let old_mint = accounts.pool_mint_account;
let old_pool_account = accounts.pool_token_account; let old_pool_account = accounts.pool_token_account;
let (_pool_mint_key, pool_mint_account) = let (_pool_mint_key, pool_mint_account) = create_mint(
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); &pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
accounts.pool_mint_account = pool_mint_account; accounts.pool_mint_account = pool_mint_account;
let (_empty_pool_token_key, empty_pool_token_account) = mint_token( let (_empty_pool_token_key, empty_pool_token_account) = mint_token(
@ -3660,8 +3751,12 @@ mod tests {
pool_key, pool_key,
mut pool_account, mut pool_account,
) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0); ) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0);
let (pool_mint_key, pool_mint_account) = let (pool_mint_key, pool_mint_account) = create_mint(
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); &pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key; let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account; let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key; accounts.pool_mint_key = pool_mint_key;
@ -4382,8 +4477,12 @@ mod tests {
initial_b, initial_b,
initial_pool.try_into().unwrap(), initial_pool.try_into().unwrap(),
); );
let (pool_mint_key, pool_mint_account) = let (pool_mint_key, pool_mint_account) = create_mint(
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); &pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key; let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account; let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key; accounts.pool_mint_key = pool_mint_key;
@ -5067,8 +5166,12 @@ mod tests {
pool_key, pool_key,
mut pool_account, mut pool_account,
) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0); ) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0);
let (pool_mint_key, pool_mint_account) = let (pool_mint_key, pool_mint_account) = create_mint(
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); &pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key; let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account; let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key; accounts.pool_mint_key = pool_mint_key;
@ -5684,8 +5787,12 @@ mod tests {
initial_b, initial_b,
initial_pool.try_into().unwrap(), initial_pool.try_into().unwrap(),
); );
let (pool_mint_key, pool_mint_account) = let (pool_mint_key, pool_mint_account) = create_mint(
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); &pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key; let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account; let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key; accounts.pool_mint_key = pool_mint_key;
@ -5976,9 +6083,16 @@ mod tests {
) )
.unwrap(); .unwrap();
// tweak values based on transfer fees assessed
let token_a_fee = accounts
.transfer_fees
.token_a
.calculate_fee(a_to_b_amount)
.unwrap();
let actual_a_to_b_amount = a_to_b_amount - token_a_fee;
let results = swap_curve let results = swap_curve
.swap( .swap(
a_to_b_amount.try_into().unwrap(), actual_a_to_b_amount.try_into().unwrap(),
token_a_amount.try_into().unwrap(), token_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(), token_b_amount.try_into().unwrap(),
TradeDirection::AtoB, TradeDirection::AtoB,
@ -6049,7 +6163,7 @@ mod tests {
) )
.unwrap(); .unwrap();
let results = swap_curve let mut results = swap_curve
.swap( .swap(
b_to_a_amount.try_into().unwrap(), b_to_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(), token_b_amount.try_into().unwrap(),
@ -6058,6 +6172,13 @@ mod tests {
&fees, &fees,
) )
.unwrap(); .unwrap();
// tweak values based on transfer fees assessed
let token_a_fee = accounts
.transfer_fees
.token_a
.calculate_fee(results.destination_amount_swapped.try_into().unwrap())
.unwrap();
results.destination_amount_swapped -= token_a_fee as u128;
let swap_token_a = let swap_token_a =
StateWithExtensions::<Account>::unpack(&accounts.token_a_account.data).unwrap(); StateWithExtensions::<Account>::unpack(&accounts.token_a_account.data).unwrap();
@ -6758,8 +6879,12 @@ mod tests {
_pool_key, _pool_key,
_pool_account, _pool_account,
) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0);
let (pool_mint_key, pool_mint_account) = let (pool_mint_key, pool_mint_account) = create_mint(
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); &pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key; let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account; let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key; accounts.pool_mint_key = pool_mint_key;
@ -8149,14 +8274,15 @@ mod tests {
let token_b_amount = 50_000_000_000; let token_b_amount = 50_000_000_000;
check_valid_swap_curve( check_valid_swap_curve(
fees.clone(), fees,
SwapTransferFees { SwapTransferFees {
pool_token: TransferFees::default(), pool_token: TransferFee::default(),
token_a: TransferFees { token_a: TransferFee {
transfer_fee_basis_points: 100, epoch: 0.into(),
maximum_transfer_fee: 1_000_000_000, transfer_fee_basis_points: 100.into(),
maximum_fee: 1_000_000_000.into(),
}, },
token_b: TransferFees::default(), token_b: TransferFee::default(),
}, },
CurveType::ConstantProduct, CurveType::ConstantProduct,
Arc::new(ConstantProductCurve {}), Arc::new(ConstantProductCurve {}),