Step 2: Calculate the expected transfer fee

This commit is contained in:
Jon Cinque 2022-08-24 19:16:29 +02:00
parent 9abaf9c90d
commit bb468d2ce7
1 changed files with 173 additions and 47 deletions

View File

@ -17,6 +17,7 @@ use crate::{
use num_traits::FromPrimitive;
use solana_program::{
account_info::{next_account_info, AccountInfo},
clock::Clock,
decode_error::DecodeError,
entrypoint::ProgramResult,
instruction::Instruction,
@ -25,11 +26,12 @@ use solana_program::{
program_error::{PrintProgramError, ProgramError},
program_option::COption,
pubkey::Pubkey,
sysvar::Sysvar,
};
use spl_token_2022::{
check_spl_token_program_account,
error::TokenError,
extension::StateWithExtensions,
extension::{transfer_fee::TransferFeeConfig, StateWithExtensions},
state::{Account, Mint},
};
use std::{convert::TryInto, error::Error};
@ -69,6 +71,19 @@ impl Processor {
}
}
/// Unpacks a spl_token `Mint` with extension data
pub fn unpack_mint_with_extensions<'a>(
account_data: &'a [u8],
owner: &Pubkey,
token_program_id: &Pubkey,
) -> Result<StateWithExtensions<'a, Mint>, SwapError> {
if owner != token_program_id && check_spl_token_program_account(owner).is_err() {
Err(SwapError::IncorrectTokenProgramId)
} else {
StateWithExtensions::<Mint>::unpack(account_data).map_err(|_| SwapError::ExpectedMint)
}
}
/// Calculates the authority id by generating a program address.
pub fn authority_id(
program_id: &Pubkey,
@ -412,6 +427,27 @@ impl Processor {
Self::unpack_token_account(swap_destination_info, token_swap.token_program_id())?;
let pool_mint = Self::unpack_mint(pool_mint_info, token_swap.token_program_id())?;
// Take transfer fees into account for actual amount transferred in
let actual_amount_in = {
let source_mint_data = source_token_mint_info.data.borrow();
let source_mint = Self::unpack_mint_with_extensions(
&source_mint_data,
source_token_mint_info.owner,
token_swap.token_program_id(),
)?;
if let Ok(transfer_fee_config) = source_mint.get_extension::<TransferFeeConfig>() {
amount_in.saturating_sub(
transfer_fee_config
.calculate_epoch_fee(Clock::get()?.epoch, amount_in)
.ok_or(SwapError::FeeCalculationFailure)?,
)
} else {
amount_in
}
};
// Calculate the trade amounts
let trade_direction = if *swap_source_info.key == *token_swap.token_a_account() {
TradeDirection::AtoB
} else {
@ -420,16 +456,61 @@ impl Processor {
let result = token_swap
.swap_curve()
.swap(
to_u128(amount_in)?,
to_u128(actual_amount_in)?,
to_u128(source_account.amount)?,
to_u128(dest_account.amount)?,
trade_direction,
token_swap.fees(),
)
.ok_or(SwapError::ZeroTradingTokens)?;
if result.destination_amount_swapped < to_u128(minimum_amount_out)? {
// Re-calculate the source amount swapped based on what the curve says
let (source_transfer_amount, source_mint_decimals) = {
let source_amount_swapped = to_u64(result.source_amount_swapped)?;
let source_mint_data = source_token_mint_info.data.borrow();
let source_mint = Self::unpack_mint_with_extensions(
&source_mint_data,
source_token_mint_info.owner,
token_swap.token_program_id(),
)?;
let amount =
if let Ok(transfer_fee_config) = source_mint.get_extension::<TransferFeeConfig>() {
source_amount_swapped.saturating_add(
transfer_fee_config
.calculate_inverse_epoch_fee(Clock::get()?.epoch, source_amount_swapped)
.ok_or(SwapError::FeeCalculationFailure)?,
)
} else {
source_amount_swapped
};
(amount, source_mint.base.decimals)
};
let (destination_transfer_amount, destination_mint_decimals) = {
let destination_mint_data = destination_token_mint_info.data.borrow();
let destination_mint = Self::unpack_mint_with_extensions(
&destination_mint_data,
source_token_mint_info.owner,
token_swap.token_program_id(),
)?;
let amount_out = to_u64(result.destination_amount_swapped)?;
let amount_received = if let Ok(transfer_fee_config) =
destination_mint.get_extension::<TransferFeeConfig>()
{
amount_out.saturating_sub(
transfer_fee_config
.calculate_epoch_fee(Clock::get()?.epoch, amount_out)
.ok_or(SwapError::FeeCalculationFailure)?,
)
} else {
amount_out
};
if amount_received < minimum_amount_out {
return Err(SwapError::ExceededSlippage.into());
}
(amount_out, destination_mint.base.decimals)
};
let (swap_token_a_amount, swap_token_b_amount) = match trade_direction {
TradeDirection::AtoB => (
@ -450,8 +531,8 @@ impl Processor {
swap_source_info.clone(),
user_transfer_authority_info.clone(),
token_swap.bump_seed(),
to_u64(result.source_amount_swapped)?,
Self::unpack_mint(source_token_mint_info, token_swap.token_program_id())?.decimals,
source_transfer_amount,
source_mint_decimals,
)?;
let mut pool_token_amount = token_swap
@ -519,8 +600,8 @@ impl Processor {
destination_info.clone(),
authority_info.clone(),
token_swap.bump_seed(),
to_u64(result.destination_amount_swapped)?,
Self::unpack_mint(destination_token_mint_info, token_swap.token_program_id())?.decimals,
destination_transfer_amount,
destination_mint_decimals,
)?;
Ok(())
@ -1194,7 +1275,10 @@ mod tests {
};
use spl_token_2022::{
error::TokenError,
extension::{transfer_fee::instruction::initialize_transfer_fee_config, ExtensionType},
extension::{
transfer_fee::{instruction::initialize_transfer_fee_config, TransferFee},
ExtensionType,
},
instruction::{
approve, close_account, freeze_account, initialize_account, initialize_immutable_owner,
initialize_mint, initialize_mint_close_authority, mint_to, revoke, set_authority,
@ -1277,17 +1361,11 @@ mod tests {
});
}
#[derive(Default)]
struct TransferFees {
transfer_fee_basis_points: u16,
maximum_transfer_fee: u64,
}
#[derive(Default)]
struct SwapTransferFees {
pool_token: TransferFees,
token_a: TransferFees,
token_b: TransferFees,
pool_token: TransferFee,
token_a: TransferFee,
token_b: TransferFee,
}
struct SwapAccountInfo {
@ -1335,8 +1413,12 @@ mod tests {
let (authority_key, bump_seed) =
Pubkey::find_program_address(&[&swap_key.to_bytes()[..]], &SWAP_PROGRAM_ID);
let (pool_mint_key, mut pool_mint_account) =
create_mint(pool_token_program_id, &authority_key, None, &transfer_fees.pool_token);
let (pool_mint_key, mut pool_mint_account) = create_mint(
pool_token_program_id,
&authority_key,
None,
&transfer_fees.pool_token,
);
let (pool_token_key, pool_token_account) = mint_token(
pool_token_program_id,
&pool_mint_key,
@ -2101,7 +2183,7 @@ mod tests {
program_id: &Pubkey,
authority_key: &Pubkey,
freeze_authority: Option<&Pubkey>,
fees: &TransferFees,
fees: &TransferFee,
) -> (Pubkey, SolanaAccount) {
let mint_key = Pubkey::new_unique();
let space = if *program_id == spl_token_2022::id() {
@ -2128,8 +2210,8 @@ mod tests {
&mint_key,
freeze_authority,
freeze_authority,
fees.transfer_fee_basis_points,
fees.maximum_transfer_fee,
fees.transfer_fee_basis_points.into(),
fees.maximum_fee.into(),
)
.unwrap(),
vec![&mut mint_account],
@ -2430,8 +2512,12 @@ mod tests {
// pool mint authority is not swap authority
{
let (_pool_mint_key, pool_mint_account) =
create_mint(&pool_token_program_id, &user_key, None, &TransferFees::default());
let (_pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&user_key,
None,
&TransferFee::default(),
);
let old_mint = accounts.pool_mint_account;
accounts.pool_mint_account = pool_mint_account;
assert_eq!(
@ -2446,7 +2532,8 @@ mod tests {
let (_pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&accounts.authority_key,
Some(&user_key), &TransferFees::default(),
Some(&user_key),
&TransferFee::default(),
);
let old_mint = accounts.pool_mint_account;
accounts.pool_mint_account = pool_mint_account;
@ -2540,8 +2627,12 @@ mod tests {
let old_mint = accounts.pool_mint_account;
let old_pool_account = accounts.pool_token_account;
let (_pool_mint_key, pool_mint_account) =
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default());
let (_pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
accounts.pool_mint_account = pool_mint_account;
let (_empty_pool_token_key, empty_pool_token_account) = mint_token(
@ -3660,8 +3751,12 @@ mod tests {
pool_key,
mut pool_account,
) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0);
let (pool_mint_key, pool_mint_account) =
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default());
let (pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key;
@ -4382,8 +4477,12 @@ mod tests {
initial_b,
initial_pool.try_into().unwrap(),
);
let (pool_mint_key, pool_mint_account) =
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default());
let (pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key;
@ -5067,8 +5166,12 @@ mod tests {
pool_key,
mut pool_account,
) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0);
let (pool_mint_key, pool_mint_account) =
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default());
let (pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key;
@ -5684,8 +5787,12 @@ mod tests {
initial_b,
initial_pool.try_into().unwrap(),
);
let (pool_mint_key, pool_mint_account) =
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default());
let (pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key;
@ -5976,9 +6083,16 @@ mod tests {
)
.unwrap();
// tweak values based on transfer fees assessed
let token_a_fee = accounts
.transfer_fees
.token_a
.calculate_fee(a_to_b_amount)
.unwrap();
let actual_a_to_b_amount = a_to_b_amount - token_a_fee;
let results = swap_curve
.swap(
a_to_b_amount.try_into().unwrap(),
actual_a_to_b_amount.try_into().unwrap(),
token_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(),
TradeDirection::AtoB,
@ -6049,7 +6163,7 @@ mod tests {
)
.unwrap();
let results = swap_curve
let mut results = swap_curve
.swap(
b_to_a_amount.try_into().unwrap(),
token_b_amount.try_into().unwrap(),
@ -6058,6 +6172,13 @@ mod tests {
&fees,
)
.unwrap();
// tweak values based on transfer fees assessed
let token_a_fee = accounts
.transfer_fees
.token_a
.calculate_fee(results.destination_amount_swapped.try_into().unwrap())
.unwrap();
results.destination_amount_swapped -= token_a_fee as u128;
let swap_token_a =
StateWithExtensions::<Account>::unpack(&accounts.token_a_account.data).unwrap();
@ -6758,8 +6879,12 @@ mod tests {
_pool_key,
_pool_account,
) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0);
let (pool_mint_key, pool_mint_account) =
create_mint(&pool_token_program_id, &accounts.authority_key, None, &TransferFees::default());
let (pool_mint_key, pool_mint_account) = create_mint(
&pool_token_program_id,
&accounts.authority_key,
None,
&TransferFee::default(),
);
let old_pool_key = accounts.pool_mint_key;
let old_pool_account = accounts.pool_mint_account;
accounts.pool_mint_key = pool_mint_key;
@ -8149,14 +8274,15 @@ mod tests {
let token_b_amount = 50_000_000_000;
check_valid_swap_curve(
fees.clone(),
fees,
SwapTransferFees {
pool_token: TransferFees::default(),
token_a: TransferFees {
transfer_fee_basis_points: 100,
maximum_transfer_fee: 1_000_000_000,
pool_token: TransferFee::default(),
token_a: TransferFee {
epoch: 0.into(),
transfer_fee_basis_points: 100.into(),
maximum_fee: 1_000_000_000.into(),
},
token_b: TransferFees::default(),
token_b: TransferFee::default(),
},
CurveType::ConstantProduct,
Arc::new(ConstantProductCurve {}),