//! Simple constant 1:1 swap curve use solana_program::{ program_error::ProgramError, program_pack::{IsInitialized, Pack, Sealed}, }; use crate::curve::calculator::{calculate_fee, CurveCalculator, DynPack, SwapResult}; use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; use std::convert::TryFrom; /// FlatCurve struct implementing CurveCalculator #[derive(Clone, Debug, Default, PartialEq)] pub struct FlatCurve { /// Fee numerator pub trade_fee_numerator: u64, /// Fee denominator pub trade_fee_denominator: u64, /// Owner trade fee numerator pub owner_trade_fee_numerator: u64, /// Owner trade fee denominator pub owner_trade_fee_denominator: u64, /// Owner withdraw fee numerator pub owner_withdraw_fee_numerator: u64, /// Owner withdraw fee denominator pub owner_withdraw_fee_denominator: u64, /// Host trading fee numerator pub host_fee_numerator: u64, /// Host trading fee denominator pub host_fee_denominator: u64, } impl CurveCalculator for FlatCurve { /// Flat curve swap always returns 1:1 (minus fee) fn swap( &self, source_amount: u128, swap_source_amount: u128, swap_destination_amount: u128, ) -> Option { // debit the fee to calculate the amount swapped let trade_fee = calculate_fee( source_amount, u128::try_from(self.trade_fee_numerator).ok()?, u128::try_from(self.trade_fee_denominator).ok()?, )?; let owner_fee = calculate_fee( source_amount, u128::try_from(self.owner_trade_fee_numerator).ok()?, u128::try_from(self.owner_trade_fee_denominator).ok()?, )?; let amount_swapped = source_amount .checked_sub(trade_fee)? .checked_sub(owner_fee)?; let new_destination_amount = swap_destination_amount.checked_sub(amount_swapped)?; // actually add the whole amount coming in let new_source_amount = swap_source_amount.checked_add(source_amount)?; Some(SwapResult { new_source_amount, new_destination_amount, amount_swapped, trade_fee, owner_fee, }) } /// Calculate the withdraw fee in pool tokens fn owner_withdraw_fee(&self, pool_tokens: u128) -> Option { calculate_fee( pool_tokens, u128::try_from(self.owner_withdraw_fee_numerator).ok()?, u128::try_from(self.owner_withdraw_fee_denominator).ok()?, ) } /// Calculate the host fee based on the owner fee, only used in production /// situations where a program is hosted by multiple frontends fn host_fee(&self, owner_fee: u128) -> Option { calculate_fee( owner_fee, u128::try_from(self.host_fee_numerator).ok()?, u128::try_from(self.host_fee_denominator).ok()?, ) } } /// IsInitialized is required to use `Pack::pack` and `Pack::unpack` impl IsInitialized for FlatCurve { fn is_initialized(&self) -> bool { true } } impl Sealed for FlatCurve {} impl Pack for FlatCurve { const LEN: usize = 64; fn unpack_from_slice(input: &[u8]) -> Result { let input = array_ref![input, 0, 64]; #[allow(clippy::ptr_offset_with_cast)] let ( trade_fee_numerator, trade_fee_denominator, owner_trade_fee_numerator, owner_trade_fee_denominator, owner_withdraw_fee_numerator, owner_withdraw_fee_denominator, host_fee_numerator, host_fee_denominator, ) = array_refs![input, 8, 8, 8, 8, 8, 8, 8, 8]; Ok(Self { trade_fee_numerator: u64::from_le_bytes(*trade_fee_numerator), trade_fee_denominator: u64::from_le_bytes(*trade_fee_denominator), owner_trade_fee_numerator: u64::from_le_bytes(*owner_trade_fee_numerator), owner_trade_fee_denominator: u64::from_le_bytes(*owner_trade_fee_denominator), owner_withdraw_fee_numerator: u64::from_le_bytes(*owner_withdraw_fee_numerator), owner_withdraw_fee_denominator: u64::from_le_bytes(*owner_withdraw_fee_denominator), host_fee_numerator: u64::from_le_bytes(*host_fee_numerator), host_fee_denominator: u64::from_le_bytes(*host_fee_denominator), }) } fn pack_into_slice(&self, output: &mut [u8]) { (self as &dyn DynPack).pack_into_slice(output); } } impl DynPack for FlatCurve { fn pack_into_slice(&self, output: &mut [u8]) { let output = array_mut_ref![output, 0, 64]; let ( trade_fee_numerator, trade_fee_denominator, owner_trade_fee_numerator, owner_trade_fee_denominator, owner_withdraw_fee_numerator, owner_withdraw_fee_denominator, host_fee_numerator, host_fee_denominator, ) = mut_array_refs![output, 8, 8, 8, 8, 8, 8, 8, 8]; *trade_fee_numerator = self.trade_fee_numerator.to_le_bytes(); *trade_fee_denominator = self.trade_fee_denominator.to_le_bytes(); *owner_trade_fee_numerator = self.owner_trade_fee_numerator.to_le_bytes(); *owner_trade_fee_denominator = self.owner_trade_fee_denominator.to_le_bytes(); *owner_withdraw_fee_numerator = self.owner_withdraw_fee_numerator.to_le_bytes(); *owner_withdraw_fee_denominator = self.owner_withdraw_fee_denominator.to_le_bytes(); *host_fee_numerator = self.host_fee_numerator.to_le_bytes(); *host_fee_denominator = self.host_fee_denominator.to_le_bytes(); } } #[cfg(test)] mod tests { use super::*; #[test] fn flat_swap_calculation() { let swap_source_amount = 1000; let swap_destination_amount = 50000; let trade_fee_numerator = 1; let trade_fee_denominator = 100; let owner_trade_fee_numerator = 2; let owner_trade_fee_denominator = 100; let owner_withdraw_fee_numerator = 2; let owner_withdraw_fee_denominator = 100; let host_fee_numerator = 2; let host_fee_denominator = 100; let source_amount: u128 = 100; let curve = FlatCurve { trade_fee_numerator, trade_fee_denominator, owner_trade_fee_numerator, owner_trade_fee_denominator, owner_withdraw_fee_numerator, owner_withdraw_fee_denominator, host_fee_numerator, host_fee_denominator, }; let result = curve .swap(source_amount, swap_source_amount, swap_destination_amount) .unwrap(); let amount_swapped = 97; assert_eq!(result.new_source_amount, 1100); assert_eq!(result.amount_swapped, amount_swapped); assert_eq!(result.trade_fee, 1); assert_eq!(result.owner_fee, 2); assert_eq!( result.new_destination_amount, swap_destination_amount - amount_swapped ); } #[test] fn pack_flat_curve() { let trade_fee_numerator = 1; let trade_fee_denominator = 4; let owner_trade_fee_numerator = 2; let owner_trade_fee_denominator = 5; let owner_withdraw_fee_numerator = 4; let owner_withdraw_fee_denominator = 10; let host_fee_numerator = 4; let host_fee_denominator = 10; let curve = FlatCurve { trade_fee_numerator, trade_fee_denominator, owner_trade_fee_numerator, owner_trade_fee_denominator, owner_withdraw_fee_numerator, owner_withdraw_fee_denominator, host_fee_numerator, host_fee_denominator, }; let mut packed = [0u8; FlatCurve::LEN]; Pack::pack_into_slice(&curve, &mut packed[..]); let unpacked = FlatCurve::unpack(&packed).unwrap(); assert_eq!(curve, unpacked); let mut packed = vec![]; packed.extend_from_slice(&trade_fee_numerator.to_le_bytes()); packed.extend_from_slice(&trade_fee_denominator.to_le_bytes()); packed.extend_from_slice(&owner_trade_fee_numerator.to_le_bytes()); packed.extend_from_slice(&owner_trade_fee_denominator.to_le_bytes()); packed.extend_from_slice(&owner_withdraw_fee_numerator.to_le_bytes()); packed.extend_from_slice(&owner_withdraw_fee_denominator.to_le_bytes()); packed.extend_from_slice(&host_fee_numerator.to_le_bytes()); packed.extend_from_slice(&host_fee_denominator.to_le_bytes()); let unpacked = FlatCurve::unpack(&packed).unwrap(); assert_eq!(curve, unpacked); } }