From bb468d2ce7a80b48949f92e4fe1e8a811e6459a3 Mon Sep 17 00:00:00 2001 From: Jon Cinque Date: Wed, 24 Aug 2022 19:16:29 +0200 Subject: [PATCH] Step 2: Calculate the expected transfer fee --- token-swap/program/src/processor.rs | 220 ++++++++++++++++++++++------ 1 file changed, 173 insertions(+), 47 deletions(-) diff --git a/token-swap/program/src/processor.rs b/token-swap/program/src/processor.rs index b7c2c9d4..60ba99d2 100644 --- a/token-swap/program/src/processor.rs +++ b/token-swap/program/src/processor.rs @@ -17,6 +17,7 @@ use crate::{ use num_traits::FromPrimitive; use solana_program::{ account_info::{next_account_info, AccountInfo}, + clock::Clock, decode_error::DecodeError, entrypoint::ProgramResult, instruction::Instruction, @@ -25,11 +26,12 @@ use solana_program::{ program_error::{PrintProgramError, ProgramError}, program_option::COption, pubkey::Pubkey, + sysvar::Sysvar, }; use spl_token_2022::{ check_spl_token_program_account, error::TokenError, - extension::StateWithExtensions, + extension::{transfer_fee::TransferFeeConfig, StateWithExtensions}, state::{Account, Mint}, }; use std::{convert::TryInto, error::Error}; @@ -69,6 +71,19 @@ impl Processor { } } + /// Unpacks a spl_token `Mint` with extension data + pub fn unpack_mint_with_extensions<'a>( + account_data: &'a [u8], + owner: &Pubkey, + token_program_id: &Pubkey, + ) -> Result, SwapError> { + if owner != token_program_id && check_spl_token_program_account(owner).is_err() { + Err(SwapError::IncorrectTokenProgramId) + } else { + StateWithExtensions::::unpack(account_data).map_err(|_| SwapError::ExpectedMint) + } + } + /// Calculates the authority id by generating a program address. pub fn authority_id( program_id: &Pubkey, @@ -412,6 +427,27 @@ impl Processor { Self::unpack_token_account(swap_destination_info, token_swap.token_program_id())?; let pool_mint = Self::unpack_mint(pool_mint_info, token_swap.token_program_id())?; + // Take transfer fees into account for actual amount transferred in + let actual_amount_in = { + let source_mint_data = source_token_mint_info.data.borrow(); + let source_mint = Self::unpack_mint_with_extensions( + &source_mint_data, + source_token_mint_info.owner, + token_swap.token_program_id(), + )?; + + if let Ok(transfer_fee_config) = source_mint.get_extension::() { + amount_in.saturating_sub( + transfer_fee_config + .calculate_epoch_fee(Clock::get()?.epoch, amount_in) + .ok_or(SwapError::FeeCalculationFailure)?, + ) + } else { + amount_in + } + }; + + // Calculate the trade amounts let trade_direction = if *swap_source_info.key == *token_swap.token_a_account() { TradeDirection::AtoB } else { @@ -420,16 +456,61 @@ impl Processor { let result = token_swap .swap_curve() .swap( - to_u128(amount_in)?, + to_u128(actual_amount_in)?, to_u128(source_account.amount)?, to_u128(dest_account.amount)?, trade_direction, token_swap.fees(), ) .ok_or(SwapError::ZeroTradingTokens)?; - if result.destination_amount_swapped < to_u128(minimum_amount_out)? { - return Err(SwapError::ExceededSlippage.into()); - } + + // Re-calculate the source amount swapped based on what the curve says + let (source_transfer_amount, source_mint_decimals) = { + let source_amount_swapped = to_u64(result.source_amount_swapped)?; + + let source_mint_data = source_token_mint_info.data.borrow(); + let source_mint = Self::unpack_mint_with_extensions( + &source_mint_data, + source_token_mint_info.owner, + token_swap.token_program_id(), + )?; + let amount = + if let Ok(transfer_fee_config) = source_mint.get_extension::() { + source_amount_swapped.saturating_add( + transfer_fee_config + .calculate_inverse_epoch_fee(Clock::get()?.epoch, source_amount_swapped) + .ok_or(SwapError::FeeCalculationFailure)?, + ) + } else { + source_amount_swapped + }; + (amount, source_mint.base.decimals) + }; + + let (destination_transfer_amount, destination_mint_decimals) = { + let destination_mint_data = destination_token_mint_info.data.borrow(); + let destination_mint = Self::unpack_mint_with_extensions( + &destination_mint_data, + source_token_mint_info.owner, + token_swap.token_program_id(), + )?; + let amount_out = to_u64(result.destination_amount_swapped)?; + let amount_received = if let Ok(transfer_fee_config) = + destination_mint.get_extension::() + { + amount_out.saturating_sub( + transfer_fee_config + .calculate_epoch_fee(Clock::get()?.epoch, amount_out) + .ok_or(SwapError::FeeCalculationFailure)?, + ) + } else { + amount_out + }; + if amount_received < minimum_amount_out { + return Err(SwapError::ExceededSlippage.into()); + } + (amount_out, destination_mint.base.decimals) + }; let (swap_token_a_amount, swap_token_b_amount) = match trade_direction { TradeDirection::AtoB => ( @@ -450,8 +531,8 @@ impl Processor { swap_source_info.clone(), user_transfer_authority_info.clone(), token_swap.bump_seed(), - to_u64(result.source_amount_swapped)?, - Self::unpack_mint(source_token_mint_info, token_swap.token_program_id())?.decimals, + source_transfer_amount, + source_mint_decimals, )?; let mut pool_token_amount = token_swap @@ -519,8 +600,8 @@ impl Processor { destination_info.clone(), authority_info.clone(), token_swap.bump_seed(), - to_u64(result.destination_amount_swapped)?, - Self::unpack_mint(destination_token_mint_info, token_swap.token_program_id())?.decimals, + destination_transfer_amount, + destination_mint_decimals, )?; Ok(()) @@ -1194,7 +1275,10 @@ mod tests { }; use spl_token_2022::{ error::TokenError, - extension::{transfer_fee::instruction::initialize_transfer_fee_config, ExtensionType}, + extension::{ + transfer_fee::{instruction::initialize_transfer_fee_config, TransferFee}, + ExtensionType, + }, instruction::{ approve, close_account, freeze_account, initialize_account, initialize_immutable_owner, initialize_mint, initialize_mint_close_authority, mint_to, revoke, set_authority, @@ -1277,17 +1361,11 @@ mod tests { }); } - #[derive(Default)] - struct TransferFees { - transfer_fee_basis_points: u16, - maximum_transfer_fee: u64, - } - #[derive(Default)] struct SwapTransferFees { - pool_token: TransferFees, - token_a: TransferFees, - token_b: TransferFees, + pool_token: TransferFee, + token_a: TransferFee, + token_b: TransferFee, } struct SwapAccountInfo { @@ -1335,8 +1413,12 @@ mod tests { let (authority_key, bump_seed) = Pubkey::find_program_address(&[&swap_key.to_bytes()[..]], &SWAP_PROGRAM_ID); - let (pool_mint_key, mut pool_mint_account) = - create_mint(pool_token_program_id, &authority_key, None, &transfer_fees.pool_token); + let (pool_mint_key, mut pool_mint_account) = create_mint( + pool_token_program_id, + &authority_key, + None, + &transfer_fees.pool_token, + ); let (pool_token_key, pool_token_account) = mint_token( pool_token_program_id, &pool_mint_key, @@ -2101,7 +2183,7 @@ mod tests { program_id: &Pubkey, authority_key: &Pubkey, freeze_authority: Option<&Pubkey>, - fees: &TransferFees, + fees: &TransferFee, ) -> (Pubkey, SolanaAccount) { let mint_key = Pubkey::new_unique(); let space = if *program_id == spl_token_2022::id() { @@ -2128,8 +2210,8 @@ mod tests { &mint_key, freeze_authority, freeze_authority, - fees.transfer_fee_basis_points, - fees.maximum_transfer_fee, + fees.transfer_fee_basis_points.into(), + fees.maximum_fee.into(), ) .unwrap(), vec![&mut mint_account], @@ -2430,8 +2512,12 @@ mod tests { // pool mint authority is not swap authority { - let (_pool_mint_key, pool_mint_account) = - create_mint(&pool_token_program_id, &user_key, None, &TransferFees::default()); + let (_pool_mint_key, pool_mint_account) = create_mint( + &pool_token_program_id, + &user_key, + None, + &TransferFee::default(), + ); let old_mint = accounts.pool_mint_account; accounts.pool_mint_account = pool_mint_account; assert_eq!( @@ -2446,7 +2532,8 @@ mod tests { let (_pool_mint_key, pool_mint_account) = create_mint( &pool_token_program_id, &accounts.authority_key, - Some(&user_key), &TransferFees::default(), + Some(&user_key), + &TransferFee::default(), ); let old_mint = accounts.pool_mint_account; accounts.pool_mint_account = pool_mint_account; @@ -2540,8 +2627,12 @@ mod tests { let old_mint = accounts.pool_mint_account; let old_pool_account = accounts.pool_token_account; - let (_pool_mint_key, pool_mint_account) = - create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); + let (_pool_mint_key, pool_mint_account) = create_mint( + &pool_token_program_id, + &accounts.authority_key, + None, + &TransferFee::default(), + ); accounts.pool_mint_account = pool_mint_account; let (_empty_pool_token_key, empty_pool_token_account) = mint_token( @@ -3660,8 +3751,12 @@ mod tests { pool_key, mut pool_account, ) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0); - let (pool_mint_key, pool_mint_account) = - create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); + let (pool_mint_key, pool_mint_account) = create_mint( + &pool_token_program_id, + &accounts.authority_key, + None, + &TransferFee::default(), + ); let old_pool_key = accounts.pool_mint_key; let old_pool_account = accounts.pool_mint_account; accounts.pool_mint_key = pool_mint_key; @@ -4382,8 +4477,12 @@ mod tests { initial_b, initial_pool.try_into().unwrap(), ); - let (pool_mint_key, pool_mint_account) = - create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); + let (pool_mint_key, pool_mint_account) = create_mint( + &pool_token_program_id, + &accounts.authority_key, + None, + &TransferFee::default(), + ); let old_pool_key = accounts.pool_mint_key; let old_pool_account = accounts.pool_mint_account; accounts.pool_mint_key = pool_mint_key; @@ -5067,8 +5166,12 @@ mod tests { pool_key, mut pool_account, ) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0); - let (pool_mint_key, pool_mint_account) = - create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); + let (pool_mint_key, pool_mint_account) = create_mint( + &pool_token_program_id, + &accounts.authority_key, + None, + &TransferFee::default(), + ); let old_pool_key = accounts.pool_mint_key; let old_pool_account = accounts.pool_mint_account; accounts.pool_mint_key = pool_mint_key; @@ -5684,8 +5787,12 @@ mod tests { initial_b, initial_pool.try_into().unwrap(), ); - let (pool_mint_key, pool_mint_account) = - create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); + let (pool_mint_key, pool_mint_account) = create_mint( + &pool_token_program_id, + &accounts.authority_key, + None, + &TransferFee::default(), + ); let old_pool_key = accounts.pool_mint_key; let old_pool_account = accounts.pool_mint_account; accounts.pool_mint_key = pool_mint_key; @@ -5976,9 +6083,16 @@ mod tests { ) .unwrap(); + // tweak values based on transfer fees assessed + let token_a_fee = accounts + .transfer_fees + .token_a + .calculate_fee(a_to_b_amount) + .unwrap(); + let actual_a_to_b_amount = a_to_b_amount - token_a_fee; let results = swap_curve .swap( - a_to_b_amount.try_into().unwrap(), + actual_a_to_b_amount.try_into().unwrap(), token_a_amount.try_into().unwrap(), token_b_amount.try_into().unwrap(), TradeDirection::AtoB, @@ -6049,7 +6163,7 @@ mod tests { ) .unwrap(); - let results = swap_curve + let mut results = swap_curve .swap( b_to_a_amount.try_into().unwrap(), token_b_amount.try_into().unwrap(), @@ -6058,6 +6172,13 @@ mod tests { &fees, ) .unwrap(); + // tweak values based on transfer fees assessed + let token_a_fee = accounts + .transfer_fees + .token_a + .calculate_fee(results.destination_amount_swapped.try_into().unwrap()) + .unwrap(); + results.destination_amount_swapped -= token_a_fee as u128; let swap_token_a = StateWithExtensions::::unpack(&accounts.token_a_account.data).unwrap(); @@ -6758,8 +6879,12 @@ mod tests { _pool_key, _pool_account, ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); - let (pool_mint_key, pool_mint_account) = - create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default()); + let (pool_mint_key, pool_mint_account) = create_mint( + &pool_token_program_id, + &accounts.authority_key, + None, + &TransferFee::default(), + ); let old_pool_key = accounts.pool_mint_key; let old_pool_account = accounts.pool_mint_account; accounts.pool_mint_key = pool_mint_key; @@ -8149,14 +8274,15 @@ mod tests { let token_b_amount = 50_000_000_000; check_valid_swap_curve( - fees.clone(), + fees, SwapTransferFees { - pool_token: TransferFees::default(), - token_a: TransferFees { - transfer_fee_basis_points: 100, - maximum_transfer_fee: 1_000_000_000, + pool_token: TransferFee::default(), + token_a: TransferFee { + epoch: 0.into(), + transfer_fee_basis_points: 100.into(), + maximum_fee: 1_000_000_000.into(), }, - token_b: TransferFees::default(), + token_b: TransferFee::default(), }, CurveType::ConstantProduct, Arc::new(ConstantProductCurve {}),