token-swap: Add curve and pool token trait and integrate in processor (#624)

* Add SwapCurve trait and integrate in processor

* Add PoolTokenConverter trait to correspond

* Add curve type parameter to initialization and JS

* Refactor for flat curve test, fmt

* Update token-swap/program/src/curve.rs

Co-authored-by: Tyera Eulberg <teulberg@gmail.com>

* Refactor swap curve to allow for any implementation

* Rename SwapCurveWrapper -> SwapCurve

* Run cargo fmt

* Update CurveType to enum in JS

Co-authored-by: Tyera Eulberg <teulberg@gmail.com>
This commit is contained in:
Jon Cinque 2020-10-21 20:46:50 +02:00 committed by GitHub
parent ed11438ec5
commit db42f7abbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 739 additions and 356 deletions

View File

@ -10,7 +10,7 @@ import {
} from '@solana/web3.js';
import {Token} from '../../../token/js/client/token';
import {TokenSwap} from '../client/token-swap';
import {TokenSwap, CurveType} from '../client/token-swap';
import {Store} from '../client/util/store';
import {newAccountWithLamports} from '../client/util/new-account-with-lamports';
import {url} from '../url';
@ -34,6 +34,8 @@ let mintB: Token;
let tokenAccountA: PublicKey;
let tokenAccountB: PublicKey;
// curve type used to calculate swaps and deposits
const CURVE_TYPE = CurveType.ConstantProduct;
// Initial amount in each swap token
const BASE_AMOUNT = 1000;
// Amount passed to swap instruction
@ -194,13 +196,14 @@ export async function createTokenSwap(): Promise<void> {
swapPayer,
tokenSwapAccount,
authority,
nonce,
tokenAccountA,
tokenAccountB,
tokenPool.publicKey,
tokenAccountPool,
tokenSwapProgramId,
tokenProgramId,
nonce,
CURVE_TYPE,
1,
4,
);
@ -217,6 +220,7 @@ export async function createTokenSwap(): Promise<void> {
assert(fetchedTokenSwap.tokenAccountA.equals(tokenAccountA));
assert(fetchedTokenSwap.tokenAccountB.equals(tokenAccountB));
assert(fetchedTokenSwap.poolToken.equals(tokenPool.publicKey));
assert(CURVE_TYPE == fetchedTokenSwap.curveType);
assert(1 == fetchedTokenSwap.feeNumerator.toNumber());
assert(4 == fetchedTokenSwap.feeDenominator.toNumber());
}

View File

@ -64,11 +64,18 @@ export const TokenSwapLayout: typeof BufferLayout.Structure = BufferLayout.struc
Layout.publicKey('tokenAccountA'),
Layout.publicKey('tokenAccountB'),
Layout.publicKey('tokenPool'),
BufferLayout.u8('curveType'),
Layout.uint64('feeNumerator'),
Layout.uint64('feeDenominator'),
BufferLayout.blob(48, 'padding'),
],
);
export const CurveType = Object.freeze({
ConstantProduct: 0, // Constant product curve, Uniswap-style
Flat: 1, // Flat curve, always 1:1 trades
});
/**
* A program to exchange tokens against a pool of liquidity
*/
@ -123,6 +130,11 @@ export class TokenSwap {
*/
feeDenominator: Numberu64;
/**
* CurveType, current options are:
*/
curveType: number;
/**
* Fee payer
*/
@ -150,6 +162,7 @@ export class TokenSwap {
authority: PublicKey,
tokenAccountA: PublicKey,
tokenAccountB: PublicKey,
curveType: number,
feeNumerator: Numberu64,
feeDenominator: Numberu64,
payer: Account,
@ -163,6 +176,7 @@ export class TokenSwap {
authority,
tokenAccountA,
tokenAccountB,
curveType,
feeNumerator,
feeDenominator,
payer,
@ -185,13 +199,14 @@ export class TokenSwap {
static createInitSwapInstruction(
tokenSwapAccount: Account,
authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey,
tokenAccountB: PublicKey,
tokenPool: PublicKey,
tokenAccountPool: PublicKey,
tokenProgramId: PublicKey,
swapProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number,
feeDenominator: number,
): TransactionInstruction {
@ -206,18 +221,21 @@ export class TokenSwap {
];
const commandDataLayout = BufferLayout.struct([
BufferLayout.u8('instruction'),
BufferLayout.u8('nonce'),
BufferLayout.u8('curveType'),
BufferLayout.nu64('feeNumerator'),
BufferLayout.nu64('feeDenominator'),
BufferLayout.u8('nonce'),
BufferLayout.blob(48, 'padding'),
]);
let data = Buffer.alloc(1024);
{
const encodeLength = commandDataLayout.encode(
{
instruction: 0, // InitializeSwap instruction
nonce,
curveType,
feeNumerator,
feeDenominator,
nonce,
},
data,
);
@ -254,6 +272,7 @@ export class TokenSwap {
const feeNumerator = Numberu64.fromBuffer(tokenSwapData.feeNumerator);
const feeDenominator = Numberu64.fromBuffer(tokenSwapData.feeDenominator);
const curveType = tokenSwapData.curveType;
return new TokenSwap(
connection,
@ -264,6 +283,7 @@ export class TokenSwap {
authority,
tokenAccountA,
tokenAccountB,
curveType,
feeNumerator,
feeDenominator,
payer,
@ -293,13 +313,14 @@ export class TokenSwap {
payer: Account,
tokenSwapAccount: Account,
authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey,
tokenAccountB: PublicKey,
poolToken: PublicKey,
tokenAccountPool: PublicKey,
swapProgramId: PublicKey,
tokenProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number,
feeDenominator: number,
): Promise<TokenSwap> {
@ -313,6 +334,7 @@ export class TokenSwap {
authority,
tokenAccountA,
tokenAccountB,
curveType,
new Numberu64(feeNumerator),
new Numberu64(feeDenominator),
payer,
@ -336,13 +358,14 @@ export class TokenSwap {
const instruction = TokenSwap.createInitSwapInstruction(
tokenSwapAccount,
authority,
nonce,
tokenAccountA,
tokenAccountB,
poolToken,
tokenAccountPool,
tokenProgramId,
swapProgramId,
nonce,
curveType,
feeNumerator,
feeDenominator,
);

View File

@ -17,6 +17,7 @@ declare module '@solana/spl-token-swap' {
}
export const TokenSwapLayout: Layout;
export const CurveType: Object;
export class TokenSwap {
constructor(
@ -28,6 +29,7 @@ declare module '@solana/spl-token-swap' {
authority: PublicKey,
tokenAccountA: PublicKey,
tokenAccountB: PublicKey,
curveType: number,
feeNumerator: Numberu64,
feeDenominator: Numberu64,
payer: Account,
@ -40,13 +42,14 @@ declare module '@solana/spl-token-swap' {
static createInitSwapInstruction(
tokenSwapAccount: Account,
authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey,
tokenAccountB: PublicKey,
tokenPool: PublicKey,
tokenAccountPool: PublicKey,
tokenProgramId: PublicKey,
swapProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number,
feeDenominator: number,
): TransactionInstruction;
@ -69,6 +72,7 @@ declare module '@solana/spl-token-swap' {
tokenAccountPool: PublicKey,
tokenProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number,
feeDenominator: number,
swapProgramId: PublicKey,

View File

@ -14,6 +14,8 @@ declare module '@solana/spl-token-swap' {
declare export var TokenSwapLayout: Layout;
declare export var CurveType: Object;
declare export class TokenSwap {
constructor(
connection: Connection,
@ -24,6 +26,7 @@ declare module '@solana/spl-token-swap' {
authority: PublicKey,
tokenAccountA: PublicKey,
tokenAccountB: PublicKey,
curveType: number,
feeNumerator: Numberu64,
feeDenominator: Numberu64,
payer: Account,
@ -37,12 +40,13 @@ declare module '@solana/spl-token-swap' {
programId: PublicKey,
tokenSwapAccount: Account,
authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey,
tokenAccountB: PublicKey,
tokenPool: PublicKey,
tokenAccountPool: PublicKey,
tokenProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number,
feeDenominator: number,
): TransactionInstruction;
@ -65,6 +69,7 @@ declare module '@solana/spl-token-swap' {
tokenAccountPool: PublicKey,
tokenProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number,
feeDenominator: number,
programId: PublicKey,

View File

@ -1,11 +1,153 @@
//! Swap calculations and curve implementations
use solana_sdk::{
program_error::ProgramError,
program_pack::{IsInitialized, Pack, Sealed},
};
use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs};
use std::convert::{TryFrom, TryInto};
use std::fmt::Debug;
/// Initial amount of pool tokens for swap contract, hard-coded to something
/// "sensible" given a maximum of u64.
/// Note that on Ethereum, Uniswap uses the geometric mean of all provided
/// input amounts, and Balancer uses 100 * 10 ^ 18.
pub const INITIAL_SWAP_POOL_AMOUNT: u64 = 1_000_000_000;
/// Curve types supported by the token-swap program.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CurveType {
/// Uniswap-style constant product curve, invariant = token_a_amount * token_b_amount
ConstantProduct,
/// Flat line, always providing 1:1 from one token to another
Flat,
}
/// Concrete struct to wrap around the trait object which performs calculation.
#[repr(C)]
#[derive(Debug)]
pub struct SwapCurve {
/// The type of curve contained in the calculator, helpful for outside
/// queries
pub curve_type: CurveType,
/// The actual calculator, represented as a trait object to allow for many
/// different types of curves
pub calculator: Box<dyn CurveCalculator>,
}
/// Default implementation for SwapCurve cannot be derived because of
/// the contained Box.
impl Default for SwapCurve {
fn default() -> Self {
let curve_type: CurveType = Default::default();
let calculator: ConstantProductCurve = Default::default();
Self {
curve_type,
calculator: Box::new(calculator),
}
}
}
/// Simple implementation for PartialEq which assumes that the output of
/// `Pack` is enough to guarantee equality
impl PartialEq for SwapCurve {
fn eq(&self, other: &Self) -> bool {
let mut packed_self = [0u8; Self::LEN];
Self::pack_into_slice(self, &mut packed_self);
let mut packed_other = [0u8; Self::LEN];
Self::pack_into_slice(other, &mut packed_other);
packed_self[..] == packed_other[..]
}
}
impl Sealed for SwapCurve {}
impl Pack for SwapCurve {
/// Size of encoding of all curve parameters, which include fees and any other
/// constants used to calculate swaps, deposits, and withdrawals.
/// This includes 1 byte for the type, and 64 for the calculator to use as
/// it needs. Some calculators may be smaller than 64 bytes.
const LEN: usize = 65;
/// Unpacks a byte buffer into a SwapCurve
fn unpack_from_slice(input: &[u8]) -> Result<Self, ProgramError> {
let input = array_ref![input, 0, 65];
#[allow(clippy::ptr_offset_with_cast)]
let (curve_type, calculator) = array_refs![input, 1, 64];
let curve_type = curve_type[0].try_into()?;
Ok(Self {
curve_type,
calculator: match curve_type {
CurveType::ConstantProduct => {
Box::new(ConstantProductCurve::unpack_from_slice(calculator)?)
}
CurveType::Flat => Box::new(FlatCurve::unpack_from_slice(calculator)?),
},
})
}
/// Pack SwapCurve into a byte buffer
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 65];
let (curve_type, calculator) = mut_array_refs![output, 1, 64];
curve_type[0] = self.curve_type as u8;
self.calculator.pack_into_slice(&mut calculator[..]);
}
}
/// Sensible default of CurveType to ConstantProduct, the most popular and
/// well-known curve type.
impl Default for CurveType {
fn default() -> Self {
CurveType::ConstantProduct
}
}
impl TryFrom<u8> for CurveType {
type Error = ProgramError;
fn try_from(curve_type: u8) -> Result<Self, Self::Error> {
match curve_type {
0 => Ok(CurveType::ConstantProduct),
1 => Ok(CurveType::Flat),
_ => Err(ProgramError::InvalidAccountData),
}
}
}
/// Trait for packing of trait objects, required because structs that implement
/// `Pack` cannot be used as trait objects (as `dyn Pack`).
pub trait DynPack {
/// Only required function is to pack given a trait object
fn pack_into_slice(&self, dst: &mut [u8]);
}
/// Trait representing operations required on a swap curve
pub trait CurveCalculator: Debug + DynPack {
/// Calculate how much destination token will be provided given an amount
/// of source token.
fn swap(
&self,
source_amount: u64,
swap_source_amount: u64,
swap_destination_amount: u64,
) -> Option<SwapResult>;
/// Get the supply of a new pool (can be a default amount or calculated
/// based on parameters)
fn new_pool_supply(&self) -> u64;
/// Get the amount of liquidity tokens for pool tokens given the total amount
/// of liquidity tokens in the pool
fn liquidity_tokens(
&self,
pool_tokens: u64,
pool_token_supply: u64,
total_liquidity_tokens: u64,
) -> Option<u64>;
}
/// Encodes all results of swapping from a source token to a destination token
pub struct SwapResult {
/// New amount of source token
@ -16,30 +158,42 @@ pub struct SwapResult {
pub amount_swapped: u64,
}
impl SwapResult {
/// SwapResult for swap from one currency into another, given pool information
/// and fee
pub fn swap_to(
/// Helper function for mapping to SwapError::CalculationFailure
fn map_zero_to_none(x: u64) -> Option<u64> {
if x == 0 {
None
} else {
Some(x)
}
}
/// Simple constant 1:1 swap curve, example of different swap curve implementations
#[derive(Clone, Debug, Default, PartialEq)]
pub struct FlatCurve {
/// Fee numerator
pub fee_numerator: u64,
/// Fee denominator
pub fee_denominator: u64,
}
impl CurveCalculator for FlatCurve {
/// Flat curve swap always returns 1:1 (minus fee)
fn swap(
&self,
source_amount: u64,
swap_source_amount: u64,
swap_destination_amount: u64,
fee_numerator: u64,
fee_denominator: u64,
) -> Option<SwapResult> {
let invariant = swap_source_amount.checked_mul(swap_destination_amount)?;
// debit the fee to calculate the amount swapped
let mut fee = source_amount
.checked_mul(fee_numerator)?
.checked_div(fee_denominator)?;
.checked_mul(self.fee_numerator)?
.checked_div(self.fee_denominator)?;
if fee == 0 {
fee = 1; // minimum fee of one token
}
let new_source_amount_less_fee = swap_source_amount
.checked_add(source_amount)?
.checked_sub(fee)?;
let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?;
let amount_swapped = swap_destination_amount.checked_sub(new_destination_amount)?;
let amount_swapped = source_amount.checked_sub(fee)?;
let new_destination_amount = swap_destination_amount.checked_sub(amount_swapped)?;
// actually add the whole amount coming in
let new_source_amount = swap_source_amount.checked_add(source_amount)?;
@ -49,105 +203,154 @@ impl SwapResult {
amount_swapped,
})
}
/// Balancer-style fixed initial supply
fn new_pool_supply(&self) -> u64 {
INITIAL_SWAP_POOL_AMOUNT
}
/// Simple ratio calculation for how many liquidity tokens correspond to
/// a certain number of pool tokens
fn liquidity_tokens(
&self,
pool_tokens: u64,
pool_token_supply: u64,
total_liquidity_tokens: u64,
) -> Option<u64> {
pool_tokens
.checked_mul(total_liquidity_tokens)?
.checked_div(pool_token_supply)
.and_then(map_zero_to_none)
}
}
fn map_zero_to_none(x: u64) -> Option<u64> {
if x == 0 {
None
} else {
Some(x)
/// IsInitialized is required to use `Pack::pack` and `Pack::unpack`
impl IsInitialized for FlatCurve {
fn is_initialized(&self) -> bool {
true
}
}
impl Sealed for FlatCurve {}
impl Pack for FlatCurve {
const LEN: usize = 16;
/// Unpacks a byte buffer into a SwapCurve
fn unpack_from_slice(input: &[u8]) -> Result<FlatCurve, ProgramError> {
let input = array_ref![input, 0, 16];
#[allow(clippy::ptr_offset_with_cast)]
let (fee_numerator, fee_denominator) = array_refs![input, 8, 8];
Ok(Self {
fee_numerator: u64::from_le_bytes(*fee_numerator),
fee_denominator: u64::from_le_bytes(*fee_denominator),
})
}
fn pack_into_slice(&self, output: &mut [u8]) {
(self as &dyn DynPack).pack_into_slice(output);
}
}
impl DynPack for FlatCurve {
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 16];
let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8];
*fee_numerator = self.fee_numerator.to_le_bytes();
*fee_denominator = self.fee_denominator.to_le_bytes();
}
}
/// The Uniswap invariant calculator.
pub struct ConstantProduct {
/// Token A
pub token_a: u64,
/// Token B
pub token_b: u64,
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ConstantProductCurve {
/// Fee numerator
pub fee_numerator: u64,
/// Fee denominator
pub fee_denominator: u64,
}
impl ConstantProduct {
/// Swap token a to b
pub fn swap_a_to_b(&mut self, token_a: u64) -> Option<u64> {
let result = SwapResult::swap_to(
token_a,
self.token_a,
self.token_b,
self.fee_numerator,
self.fee_denominator,
)?;
self.token_a = result.new_source_amount;
self.token_b = result.new_destination_amount;
map_zero_to_none(result.amount_swapped)
impl CurveCalculator for ConstantProductCurve {
/// Constant product swap ensures x * y = constant
fn swap(
&self,
source_amount: u64,
swap_source_amount: u64,
swap_destination_amount: u64,
) -> Option<SwapResult> {
let invariant = swap_source_amount.checked_mul(swap_destination_amount)?;
// debit the fee to calculate the amount swapped
let mut fee = source_amount
.checked_mul(self.fee_numerator)?
.checked_div(self.fee_denominator)?;
if fee == 0 {
fee = 1; // minimum fee of one token
}
let new_source_amount_less_fee = swap_source_amount
.checked_add(source_amount)?
.checked_sub(fee)?;
let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?;
let amount_swapped =
map_zero_to_none(swap_destination_amount.checked_sub(new_destination_amount)?)?;
// actually add the whole amount coming in
let new_source_amount = swap_source_amount.checked_add(source_amount)?;
Some(SwapResult {
new_source_amount,
new_destination_amount,
amount_swapped,
})
}
/// Swap token b to a
pub fn swap_b_to_a(&mut self, token_b: u64) -> Option<u64> {
let result = SwapResult::swap_to(
token_b,
self.token_b,
self.token_a,
self.fee_numerator,
self.fee_denominator,
)?;
self.token_b = result.new_source_amount;
self.token_a = result.new_destination_amount;
map_zero_to_none(result.amount_swapped)
/// Balancer-style supply starts at a constant. This could be modified to
/// follow the geometric mean, as done in Uniswap v2.
fn new_pool_supply(&self) -> u64 {
INITIAL_SWAP_POOL_AMOUNT
}
/// Simple ratio calculation to get the amount of liquidity tokens given
/// pool information
fn liquidity_tokens(
&self,
pool_tokens: u64,
pool_token_supply: u64,
total_liquidity_tokens: u64,
) -> Option<u64> {
pool_tokens
.checked_mul(total_liquidity_tokens)?
.checked_div(pool_token_supply)
.and_then(map_zero_to_none)
}
}
/// Conversions for pool tokens, how much to deposit / withdraw, along with
/// proper initialization
pub struct PoolTokenConverter {
/// Total supply
pub supply: u64,
/// Token A amount
pub token_a: u64,
/// Token B amount
pub token_b: u64,
/// IsInitialized is required to use `Pack::pack` and `Pack::unpack`
impl IsInitialized for ConstantProductCurve {
fn is_initialized(&self) -> bool {
true
}
}
impl Sealed for ConstantProductCurve {}
impl Pack for ConstantProductCurve {
const LEN: usize = 16;
fn unpack_from_slice(input: &[u8]) -> Result<ConstantProductCurve, ProgramError> {
let input = array_ref![input, 0, 16];
#[allow(clippy::ptr_offset_with_cast)]
let (fee_numerator, fee_denominator) = array_refs![input, 8, 8];
Ok(Self {
fee_numerator: u64::from_le_bytes(*fee_numerator),
fee_denominator: u64::from_le_bytes(*fee_denominator),
})
}
fn pack_into_slice(&self, output: &mut [u8]) {
(self as &dyn DynPack).pack_into_slice(output);
}
}
impl PoolTokenConverter {
/// Create a converter based on existing market information
pub fn new_existing(supply: u64, token_a: u64, token_b: u64) -> Self {
Self {
supply,
token_a,
token_b,
}
}
/// Create a converter for a new pool token, no supply present yet.
/// According to Uniswap, the geometric mean protects the pool creator
/// in case the initial ratio is off the market.
pub fn new_pool(token_a: u64, token_b: u64) -> Self {
let supply = INITIAL_SWAP_POOL_AMOUNT;
Self {
supply,
token_a,
token_b,
}
}
/// A tokens for pool tokens, returns None if output is less than 0
pub fn token_a_rate(&self, pool_tokens: u64) -> Option<u64> {
pool_tokens
.checked_mul(self.token_a)?
.checked_div(self.supply)
.and_then(map_zero_to_none)
}
/// B tokens for pool tokens, returns None is output is less than 0
pub fn token_b_rate(&self, pool_tokens: u64) -> Option<u64> {
pool_tokens
.checked_mul(self.token_b)?
.checked_div(self.supply)
.and_then(map_zero_to_none)
impl DynPack for ConstantProductCurve {
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 16];
let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8];
*fee_numerator = self.fee_numerator.to_le_bytes();
*fee_denominator = self.fee_denominator.to_le_bytes();
}
}
@ -157,48 +360,152 @@ mod tests {
#[test]
fn initial_pool_amount() {
let token_converter = PoolTokenConverter::new_pool(1, 5);
assert_eq!(token_converter.supply, INITIAL_SWAP_POOL_AMOUNT);
let fee_numerator = 0;
let fee_denominator = 1;
let calculator = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
assert_eq!(calculator.new_pool_supply(), INITIAL_SWAP_POOL_AMOUNT);
}
fn check_pool_token_a_rate(
fn check_liquidity_pool_token_rate(
token_a: u64,
token_b: u64,
deposit: u64,
supply: u64,
expected: Option<u64>,
) {
let calculator = PoolTokenConverter::new_existing(supply, token_a, token_b);
assert_eq!(calculator.token_a_rate(deposit), expected);
let fee_numerator = 0;
let fee_denominator = 1;
let calculator = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
assert_eq!(
calculator.liquidity_tokens(deposit, supply, token_a),
expected
);
}
#[test]
fn issued_tokens() {
check_pool_token_a_rate(2, 50, 5, 10, Some(1));
check_pool_token_a_rate(10, 10, 5, 10, Some(5));
check_pool_token_a_rate(5, 100, 5, 10, Some(2));
check_pool_token_a_rate(5, u64::MAX, 5, 10, Some(2));
check_pool_token_a_rate(u64::MAX, u64::MAX, 5, 10, None);
check_liquidity_pool_token_rate(2, 5, 10, Some(1));
check_liquidity_pool_token_rate(10, 5, 10, Some(5));
check_liquidity_pool_token_rate(5, 5, 10, Some(2));
check_liquidity_pool_token_rate(5, 5, 10, Some(2));
check_liquidity_pool_token_rate(u64::MAX, 5, 10, None);
}
#[test]
fn swap_calculation() {
fn constant_product_swap_calculation() {
// calculation on https://github.com/solana-labs/solana-program-library/issues/341
let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000;
let fee_numerator: u64 = 1;
let fee_denominator: u64 = 100;
let source_amount: u64 = 100;
let result = SwapResult::swap_to(
source_amount,
swap_source_amount,
swap_destination_amount,
let curve = ConstantProductCurve {
fee_numerator,
fee_denominator,
)
.unwrap();
};
let result = curve
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, 4505);
assert_eq!(result.new_destination_amount, 45495);
}
#[test]
fn flat_swap_calculation() {
let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000;
let fee_numerator: u64 = 1;
let fee_denominator: u64 = 100;
let source_amount: u64 = 100;
let curve = FlatCurve {
fee_numerator,
fee_denominator,
};
let result = curve
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
let amount_swapped = 99;
assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, amount_swapped);
assert_eq!(
result.new_destination_amount,
swap_destination_amount - amount_swapped
);
}
#[test]
fn pack_flat_curve() {
let fee_numerator = 1;
let fee_denominator = 4;
let curve = FlatCurve {
fee_numerator,
fee_denominator,
};
let mut packed = [0u8; FlatCurve::LEN];
Pack::pack_into_slice(&curve, &mut packed[..]);
let unpacked = FlatCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
let mut packed = vec![];
packed.extend_from_slice(&fee_numerator.to_le_bytes());
packed.extend_from_slice(&fee_denominator.to_le_bytes());
let unpacked = FlatCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
}
#[test]
fn pack_constant_product_curve() {
let fee_numerator = 1;
let fee_denominator = 4;
let curve = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
let mut packed = [0u8; ConstantProductCurve::LEN];
Pack::pack_into_slice(&curve, &mut packed[..]);
let unpacked = ConstantProductCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
let mut packed = vec![];
packed.extend_from_slice(&fee_numerator.to_le_bytes());
packed.extend_from_slice(&fee_denominator.to_le_bytes());
let unpacked = ConstantProductCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
}
#[test]
fn pack_swap_curve() {
let fee_numerator = 1;
let fee_denominator = 4;
let curve = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
let curve_type = CurveType::ConstantProduct;
let swap_curve = SwapCurve {
curve_type,
calculator: Box::new(curve),
};
let mut packed = [0u8; SwapCurve::LEN];
Pack::pack_into_slice(&swap_curve, &mut packed[..]);
let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap();
assert_eq!(swap_curve, unpacked);
let mut packed = vec![];
packed.push(curve_type as u8);
packed.extend_from_slice(&fee_numerator.to_le_bytes());
packed.extend_from_slice(&fee_denominator.to_le_bytes());
packed.extend_from_slice(&[0u8; 48]); // padding
let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap();
assert_eq!(swap_curve, unpacked);
}
}

View File

@ -2,10 +2,12 @@
#![allow(clippy::too_many_arguments)]
use crate::curve::{ConstantProductCurve, CurveType, FlatCurve, SwapCurve};
use crate::error::SwapError;
use solana_sdk::{
instruction::{AccountMeta, Instruction},
program_error::ProgramError,
program_pack::Pack,
pubkey::Pubkey,
};
use std::convert::TryInto;
@ -13,7 +15,7 @@ use std::mem::size_of;
/// Instructions supported by the SwapInfo program.
#[repr(C)]
#[derive(Clone, Debug, PartialEq)]
#[derive(Debug, PartialEq)]
pub enum SwapInstruction {
/// Initializes a new SwapInfo.
///
@ -25,12 +27,11 @@ pub enum SwapInstruction {
/// 5. `[writable]` Pool Token Account to deposit the minted tokens. Must be empty, owned by user.
/// 6. '[]` Token program id
Initialize {
/// swap pool fee numerator
fee_numerator: u64,
/// swap pool fee denominator
fee_denominator: u64,
/// nonce used to create valid program address
nonce: u8,
/// swap curve info for pool, including CurveType, fees, and anything
/// else that may be required
swap_curve: SwapCurve,
},
/// Swap the tokens in the pool.
@ -99,14 +100,9 @@ impl SwapInstruction {
let (&tag, rest) = input.split_first().ok_or(SwapError::InvalidInstruction)?;
Ok(match tag {
0 => {
let (fee_numerator, rest) = Self::unpack_u64(rest)?;
let (fee_denominator, rest) = Self::unpack_u64(rest)?;
let (&nonce, _rest) = rest.split_first().ok_or(SwapError::InvalidInstruction)?;
Self::Initialize {
fee_numerator,
fee_denominator,
nonce,
}
let (&nonce, rest) = rest.split_first().ok_or(SwapError::InvalidInstruction)?;
let swap_curve = SwapCurve::unpack_unchecked(rest)?;
Self::Initialize { nonce, swap_curve }
}
1 => {
let (amount_in, rest) = Self::unpack_u64(rest)?;
@ -157,16 +153,13 @@ impl SwapInstruction {
/// Packs a [SwapInstruction](enum.SwapInstruction.html) into a byte buffer.
pub fn pack(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(size_of::<Self>());
match *self {
Self::Initialize {
fee_numerator,
fee_denominator,
nonce,
} => {
match &*self {
Self::Initialize { nonce, swap_curve } => {
buf.push(0);
buf.extend_from_slice(&fee_numerator.to_le_bytes());
buf.extend_from_slice(&fee_denominator.to_le_bytes());
buf.push(nonce);
buf.push(*nonce);
let mut swap_curve_slice = [0u8; SwapCurve::LEN];
Pack::pack_into_slice(swap_curve, &mut swap_curve_slice[..]);
buf.extend_from_slice(&swap_curve_slice);
}
Self::Swap {
amount_in,
@ -212,14 +205,24 @@ pub fn initialize(
pool_pubkey: &Pubkey,
destination_pubkey: &Pubkey,
nonce: u8,
curve_type: CurveType,
fee_numerator: u64,
fee_denominator: u64,
) -> Result<Instruction, ProgramError> {
let init_data = SwapInstruction::Initialize {
fee_numerator,
fee_denominator,
nonce,
let swap_curve = SwapCurve {
curve_type,
calculator: match curve_type {
CurveType::ConstantProduct => Box::new(ConstantProductCurve {
fee_numerator,
fee_denominator,
}),
CurveType::Flat => Box::new(FlatCurve {
fee_numerator,
fee_denominator,
}),
},
};
let init_data = SwapInstruction::Initialize { nonce, swap_curve };
let data = init_data.pack();
let accounts = vec![
@ -379,17 +382,24 @@ mod tests {
let fee_numerator: u64 = 1;
let fee_denominator: u64 = 4;
let nonce: u8 = 255;
let check = SwapInstruction::Initialize {
let curve_type = CurveType::Flat;
let calculator = Box::new(FlatCurve {
fee_numerator,
fee_denominator,
nonce,
});
let swap_curve = SwapCurve {
curve_type,
calculator,
};
let check = SwapInstruction::Initialize { nonce, swap_curve };
let packed = check.pack();
let mut expect = vec![];
expect.push(0 as u8);
expect.push(nonce);
expect.push(curve_type as u8);
expect.extend_from_slice(&fee_numerator.to_le_bytes());
expect.extend_from_slice(&fee_denominator.to_le_bytes());
expect.push(nonce);
expect.extend_from_slice(&[0u8; 48]); // padding
assert_eq!(packed, expect);
let unpacked = SwapInstruction::unpack(&expect).unwrap();
assert_eq!(unpacked, check);

View File

@ -2,12 +2,7 @@
#![cfg(feature = "program")]
use crate::{
curve::{ConstantProduct, PoolTokenConverter},
error::SwapError,
instruction::SwapInstruction,
state::SwapInfo,
};
use crate::{curve::SwapCurve, error::SwapError, instruction::SwapInstruction, state::SwapInfo};
use num_traits::FromPrimitive;
#[cfg(not(target_arch = "bpf"))]
use solana_sdk::instruction::Instruction;
@ -142,8 +137,7 @@ impl Processor {
pub fn process_initialize(
program_id: &Pubkey,
nonce: u8,
fee_numerator: u64,
fee_denominator: u64,
swap_curve: SwapCurve,
accounts: &[AccountInfo],
) -> ProgramResult {
let account_info_iter = &mut accounts.iter();
@ -198,8 +192,7 @@ impl Processor {
return Err(SwapError::InvalidSupply.into());
}
let converter = PoolTokenConverter::new_pool(token_a.amount, token_b.amount);
let initial_amount = converter.supply;
let initial_amount = swap_curve.calculator.new_pool_supply();
Self::token_mint_to(
swap_info.key,
@ -218,8 +211,7 @@ impl Processor {
token_a: *token_a_info.key,
token_b: *token_b_info.key,
pool_mint: *pool_mint_info.key,
fee_numerator,
fee_denominator,
swap_curve,
};
SwapInfo::pack(obj, &mut swap_info.data.borrow_mut())?;
Ok(())
@ -262,28 +254,12 @@ impl Processor {
let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?;
let dest_account = Self::unpack_token_account(&swap_destination_info.data.borrow())?;
let amount_out = if *swap_source_info.key == token_swap.token_a {
let mut invariant = ConstantProduct {
token_a: source_account.amount,
token_b: dest_account.amount,
fee_numerator: token_swap.fee_numerator,
fee_denominator: token_swap.fee_denominator,
};
invariant
.swap_a_to_b(amount_in)
.ok_or(SwapError::CalculationFailure)?
} else {
let mut invariant = ConstantProduct {
token_a: dest_account.amount,
token_b: source_account.amount,
fee_numerator: token_swap.fee_numerator,
fee_denominator: token_swap.fee_denominator,
};
invariant
.swap_b_to_a(amount_in)
.ok_or(SwapError::CalculationFailure)?
};
if amount_out < minimum_amount_out {
let result = token_swap
.swap_curve
.calculator
.swap(amount_in, source_account.amount, dest_account.amount)
.ok_or(SwapError::CalculationFailure)?;
if result.amount_swapped < minimum_amount_out {
return Err(SwapError::ExceededSlippage.into());
}
Self::token_transfer(
@ -302,7 +278,7 @@ impl Processor {
destination_info.clone(),
authority_info.clone(),
token_swap.nonce,
amount_out,
result.amount_swapped,
)?;
Ok(())
}
@ -344,17 +320,16 @@ impl Processor {
let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?;
let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?;
let converter =
PoolTokenConverter::new_existing(pool_mint.supply, token_a.amount, token_b.amount);
let calculator = token_swap.swap_curve.calculator;
let a_amount = converter
.token_a_rate(pool_token_amount)
let a_amount = calculator
.liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount)
.ok_or(SwapError::CalculationFailure)?;
if a_amount > maximum_token_a_amount {
return Err(SwapError::ExceededSlippage.into());
}
let b_amount = converter
.token_b_rate(pool_token_amount)
let b_amount = calculator
.liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount)
.ok_or(SwapError::CalculationFailure)?;
if b_amount > maximum_token_b_amount {
return Err(SwapError::ExceededSlippage.into());
@ -428,17 +403,16 @@ impl Processor {
let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?;
let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?;
let converter =
PoolTokenConverter::new_existing(pool_mint.supply, token_a.amount, token_b.amount);
let calculator = token_swap.swap_curve.calculator;
let a_amount = converter
.token_a_rate(pool_token_amount)
let a_amount = calculator
.liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount)
.ok_or(SwapError::CalculationFailure)?;
if a_amount < minimum_token_a_amount {
return Err(SwapError::ExceededSlippage.into());
}
let b_amount = converter
.token_b_rate(pool_token_amount)
let b_amount = calculator
.liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount)
.ok_or(SwapError::CalculationFailure)?;
if b_amount < minimum_token_b_amount {
return Err(SwapError::ExceededSlippage.into());
@ -478,19 +452,9 @@ impl Processor {
pub fn process(program_id: &Pubkey, accounts: &[AccountInfo], input: &[u8]) -> ProgramResult {
let instruction = SwapInstruction::unpack(input)?;
match instruction {
SwapInstruction::Initialize {
fee_numerator,
fee_denominator,
nonce,
} => {
SwapInstruction::Initialize { nonce, swap_curve } => {
info!("Instruction: Init");
Self::process_initialize(
program_id,
nonce,
fee_numerator,
fee_denominator,
accounts,
)
Self::process_initialize(program_id, nonce, swap_curve, accounts)
}
SwapInstruction::Swap {
amount_in,
@ -618,7 +582,9 @@ solana_sdk::program_stubs!();
mod tests {
use super::*;
use crate::{
curve::{SwapResult, INITIAL_SWAP_POOL_AMOUNT},
curve::{
ConstantProductCurve, CurveCalculator, CurveType, FlatCurve, INITIAL_SWAP_POOL_AMOUNT,
},
instruction::{deposit, initialize, swap, withdraw},
};
use solana_sdk::{
@ -634,6 +600,7 @@ mod tests {
struct SwapAccountInfo {
nonce: u8,
curve_type: CurveType,
authority_key: Pubkey,
fee_numerator: u64,
fee_denominator: u64,
@ -656,6 +623,7 @@ mod tests {
impl SwapAccountInfo {
pub fn new(
user_key: &Pubkey,
curve_type: CurveType,
fee_numerator: u64,
fee_denominator: u64,
token_a_amount: u64,
@ -699,6 +667,7 @@ mod tests {
SwapAccountInfo {
nonce,
curve_type,
authority_key,
fee_numerator,
fee_denominator,
@ -731,6 +700,7 @@ mod tests {
&self.pool_mint_key,
&self.pool_token_key,
self.nonce,
self.curve_type,
self.fee_numerator,
self.fee_denominator,
)
@ -1179,9 +1149,11 @@ mod tests {
let token_a_amount = 1000;
let token_b_amount = 2000;
let pool_token_amount = 10;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new(
&user_key,
curve_type,
fee_numerator,
fee_denominator,
token_a_amount,
@ -1474,6 +1446,7 @@ mod tests {
&accounts.pool_mint_key,
&accounts.pool_token_key,
accounts.nonce,
accounts.curve_type,
accounts.fee_numerator,
accounts.fee_denominator,
)
@ -1513,6 +1486,19 @@ mod tests {
// create valid swap
accounts.initialize_swap().unwrap();
// create valid flat swap
{
let mut accounts = SwapAccountInfo::new(
&user_key,
CurveType::Flat,
fee_numerator,
fee_denominator,
token_a_amount,
token_b_amount,
);
accounts.initialize_swap().unwrap();
}
// create again
{
assert_eq!(
@ -1523,11 +1509,10 @@ mod tests {
let swap_info = SwapInfo::unpack(&accounts.swap_account.data).unwrap();
assert_eq!(swap_info.is_initialized, true);
assert_eq!(swap_info.nonce, accounts.nonce);
assert_eq!(swap_info.swap_curve.curve_type, accounts.curve_type);
assert_eq!(swap_info.token_a, accounts.token_a_key);
assert_eq!(swap_info.token_b, accounts.token_b_key);
assert_eq!(swap_info.pool_mint, accounts.pool_mint_key);
assert_eq!(swap_info.fee_denominator, fee_denominator);
assert_eq!(swap_info.fee_numerator, fee_numerator);
let token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
assert_eq!(token_a.amount, token_a_amount);
let token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
@ -1546,8 +1531,11 @@ mod tests {
let fee_denominator = 2;
let token_a_amount = 1000;
let token_b_amount = 9000;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new(
&user_key,
curve_type,
fee_numerator,
fee_denominator,
token_a_amount,
@ -2063,18 +2051,24 @@ mod tests {
let fee_denominator = 2;
let token_a_amount = 1000;
let token_b_amount = 2000;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new(
&user_key,
curve_type,
fee_numerator,
fee_denominator,
token_a_amount,
token_b_amount,
);
let withdrawer_key = pubkey_rand();
let pool_converter = PoolTokenConverter::new_pool(token_a_amount, token_b_amount);
let calculator = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
let initial_a = token_a_amount / 10;
let initial_b = token_b_amount / 10;
let initial_pool = pool_converter.supply / 10;
let initial_pool = calculator.new_pool_supply() / 10;
let withdraw_amount = initial_pool / 4;
let minimum_a_amount = initial_a / 40;
let minimum_b_amount = initial_b / 40;
@ -2583,15 +2577,18 @@ mod tests {
let swap_token_b =
Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap();
let pool_converter = PoolTokenConverter::new_existing(
pool_mint.supply,
swap_token_a.amount,
swap_token_b.amount,
);
let calculator = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
let withdrawn_a = pool_converter.token_a_rate(withdraw_amount).unwrap();
let withdrawn_a = calculator
.liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_a.amount)
.unwrap();
assert_eq!(swap_token_a.amount, token_a_amount - withdrawn_a);
let withdrawn_b = pool_converter.token_b_rate(withdraw_amount).unwrap();
let withdrawn_b = calculator
.liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_b.amount)
.unwrap();
assert_eq!(swap_token_b.amount, token_b_amount - withdrawn_b);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a + withdrawn_a);
@ -2602,16 +2599,150 @@ mod tests {
}
}
#[test]
fn test_swap() {
fn check_valid_swap_curve(curve_type: CurveType, calculator: Box<dyn CurveCalculator>) {
let user_key = pubkey_rand();
let swapper_key = pubkey_rand();
let fee_numerator = 1;
let fee_denominator = 10;
let token_a_amount = 1000;
let token_b_amount = 5000;
let swap_curve = SwapCurve {
curve_type,
calculator,
};
let mut accounts = SwapAccountInfo::new(
&user_key,
curve_type,
fee_numerator,
fee_denominator,
token_a_amount,
token_b_amount,
);
let initial_a = token_a_amount / 5;
let initial_b = token_b_amount / 5;
accounts.initialize_swap().unwrap();
let swap_token_a_key = accounts.token_a_key;
let swap_token_b_key = accounts.token_b_key;
let (
token_a_key,
mut token_a_account,
token_b_key,
mut token_b_account,
_pool_key,
_pool_account,
) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0);
// swap one way
let a_to_b_amount = initial_a / 10;
let minimum_b_amount = 0;
accounts
.swap(
&swapper_key,
&token_a_key,
&mut token_a_account,
&swap_token_a_key,
&swap_token_b_key,
&token_b_key,
&mut token_b_account,
a_to_b_amount,
minimum_b_amount,
)
.unwrap();
let results = swap_curve
.calculator
.swap(a_to_b_amount, token_a_amount, token_b_amount)
.unwrap();
let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
let token_a_amount = swap_token_a.amount;
assert_eq!(token_a_amount, results.new_source_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a - a_to_b_amount);
let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let token_b_amount = swap_token_b.amount;
assert_eq!(token_b_amount, results.new_destination_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(token_b.amount, initial_b + results.amount_swapped);
let first_swap_amount = results.amount_swapped;
// swap the other way
let b_to_a_amount = initial_b / 10;
let minimum_a_amount = 0;
accounts
.swap(
&swapper_key,
&token_b_key,
&mut token_b_account,
&swap_token_b_key,
&swap_token_a_key,
&token_a_key,
&mut token_a_account,
b_to_a_amount,
minimum_a_amount,
)
.unwrap();
let results = swap_curve
.calculator
.swap(b_to_a_amount, token_b_amount, token_a_amount)
.unwrap();
let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
assert_eq!(swap_token_a.amount, results.new_destination_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(
token_a.amount,
initial_a - a_to_b_amount + results.amount_swapped
);
let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
assert_eq!(swap_token_b.amount, results.new_source_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(
token_b.amount,
initial_b + first_swap_amount - b_to_a_amount
);
}
#[test]
fn test_valid_swap_curves() {
let fee_numerator = 1;
let fee_denominator = 10;
check_valid_swap_curve(
CurveType::ConstantProduct,
Box::new(ConstantProductCurve {
fee_numerator,
fee_denominator,
}),
);
check_valid_swap_curve(
CurveType::Flat,
Box::new(FlatCurve {
fee_numerator,
fee_denominator,
}),
);
}
#[test]
fn test_invalid_swap() {
let user_key = pubkey_rand();
let swapper_key = pubkey_rand();
let fee_numerator = 1;
let fee_denominator = 10;
let token_a_amount = 1000;
let token_b_amount = 5000;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new(
&user_key,
curve_type,
fee_numerator,
fee_denominator,
token_a_amount,
@ -2932,102 +3063,5 @@ mod tests {
)
);
}
// correct swap
{
let (
token_a_key,
mut token_a_account,
token_b_key,
mut token_b_account,
_pool_key,
_pool_account,
) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0);
// swap one way
let a_to_b_amount = initial_a / 10;
let minimum_b_amount = initial_b / 20;
accounts
.swap(
&swapper_key,
&token_a_key,
&mut token_a_account,
&swap_token_a_key,
&swap_token_b_key,
&token_b_key,
&mut token_b_account,
a_to_b_amount,
minimum_b_amount,
)
.unwrap();
let results = SwapResult::swap_to(
a_to_b_amount,
token_a_amount,
token_b_amount,
fee_numerator,
fee_denominator,
)
.unwrap();
let swap_token_a =
Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
let token_a_amount = swap_token_a.amount;
assert_eq!(token_a_amount, results.new_source_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a - a_to_b_amount);
let swap_token_b =
Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let token_b_amount = swap_token_b.amount;
assert_eq!(token_b_amount, results.new_destination_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(token_b.amount, initial_b + results.amount_swapped);
let first_swap_amount = results.amount_swapped;
// swap the other way
let b_to_a_amount = initial_b / 10;
let minimum_a_amount = initial_a / 20;
accounts
.swap(
&swapper_key,
&token_b_key,
&mut token_b_account,
&swap_token_b_key,
&swap_token_a_key,
&token_a_key,
&mut token_a_account,
b_to_a_amount,
minimum_a_amount,
)
.unwrap();
let results = SwapResult::swap_to(
b_to_a_amount,
token_b_amount,
token_a_amount,
fee_numerator,
fee_denominator,
)
.unwrap();
let swap_token_a =
Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
assert_eq!(swap_token_a.amount, results.new_destination_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(
token_a.amount,
initial_a - a_to_b_amount + results.amount_swapped
);
let swap_token_b =
Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
assert_eq!(swap_token_b.amount, results.new_source_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(
token_b.amount,
initial_b + first_swap_amount - b_to_a_amount
);
}
}
}

View File

@ -1,5 +1,6 @@
//! State transition types
use crate::curve::SwapCurve;
use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs};
use solana_sdk::{
program_error::ProgramError,
@ -9,7 +10,7 @@ use solana_sdk::{
/// Program states.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
#[derive(Debug, Default, PartialEq)]
pub struct SwapInfo {
/// Initialized state.
pub is_initialized: bool,
@ -32,10 +33,9 @@ pub struct SwapInfo {
/// Pool tokens can be withdrawn back to the original A or B token.
pub pool_mint: Pubkey,
/// Numerator of fee applied to the input token amount prior to output calculation.
pub fee_numerator: u64,
/// Denominator of fee applied to the input token amount prior to output calculation.
pub fee_denominator: u64,
/// Swap curve parameters, to be unpacked and used by the SwapCurve, which
/// calculates swaps, deposits, and withdrawals
pub swap_curve: SwapCurve,
}
impl Sealed for SwapInfo {}
@ -46,22 +46,14 @@ impl IsInitialized for SwapInfo {
}
impl Pack for SwapInfo {
const LEN: usize = 146;
const LEN: usize = 195;
/// Unpacks a byte buffer into a [SwapInfo](struct.SwapInfo.html).
fn unpack_from_slice(input: &[u8]) -> Result<Self, ProgramError> {
let input = array_ref![input, 0, 146];
let input = array_ref![input, 0, 195];
#[allow(clippy::ptr_offset_with_cast)]
let (
is_initialized,
nonce,
token_program_id,
token_a,
token_b,
pool_mint,
fee_numerator,
fee_denominator,
) = array_refs![input, 1, 1, 32, 32, 32, 32, 8, 8];
let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) =
array_refs![input, 1, 1, 32, 32, 32, 32, 65];
Ok(Self {
is_initialized: match is_initialized {
[0] => false,
@ -73,41 +65,36 @@ impl Pack for SwapInfo {
token_a: Pubkey::new_from_array(*token_a),
token_b: Pubkey::new_from_array(*token_b),
pool_mint: Pubkey::new_from_array(*pool_mint),
fee_numerator: u64::from_le_bytes(*fee_numerator),
fee_denominator: u64::from_le_bytes(*fee_denominator),
swap_curve: SwapCurve::unpack_from_slice(swap_curve)?,
})
}
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 146];
let (
is_initialized,
nonce,
token_program_id,
token_a,
token_b,
pool_mint,
fee_numerator,
fee_denominator,
) = mut_array_refs![output, 1, 1, 32, 32, 32, 32, 8, 8];
let output = array_mut_ref![output, 0, 195];
let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) =
mut_array_refs![output, 1, 1, 32, 32, 32, 32, 65];
is_initialized[0] = self.is_initialized as u8;
nonce[0] = self.nonce;
token_program_id.copy_from_slice(self.token_program_id.as_ref());
token_a.copy_from_slice(self.token_a.as_ref());
token_b.copy_from_slice(self.token_b.as_ref());
pool_mint.copy_from_slice(self.pool_mint.as_ref());
*fee_numerator = self.fee_numerator.to_le_bytes();
*fee_denominator = self.fee_denominator.to_le_bytes();
self.swap_curve.pack_into_slice(&mut swap_curve[..]);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::curve::FlatCurve;
use std::convert::TryInto;
#[test]
fn test_swap_info_packing() {
let nonce = 255;
let curve_type_raw: u8 = 1;
let curve_type = curve_type_raw.try_into().unwrap();
let token_program_id_raw = [1u8; 32];
let token_a_raw = [1u8; 32];
let token_b_raw = [2u8; 32];
@ -118,6 +105,14 @@ mod tests {
let pool_mint = Pubkey::new_from_array(pool_mint_raw);
let fee_numerator = 1;
let fee_denominator = 4;
let calculator = Box::new(FlatCurve {
fee_numerator,
fee_denominator,
});
let swap_curve = SwapCurve {
curve_type,
calculator,
};
let is_initialized = true;
let swap_info = SwapInfo {
is_initialized,
@ -126,12 +121,11 @@ mod tests {
token_a,
token_b,
pool_mint,
fee_numerator,
fee_denominator,
swap_curve,
};
let mut packed = [0u8; SwapInfo::LEN];
SwapInfo::pack(swap_info, &mut packed).unwrap();
SwapInfo::pack_into_slice(&swap_info, &mut packed);
let unpacked = SwapInfo::unpack(&packed).unwrap();
assert_eq!(swap_info, unpacked);
@ -142,10 +136,12 @@ mod tests {
packed.extend_from_slice(&token_a_raw);
packed.extend_from_slice(&token_b_raw);
packed.extend_from_slice(&pool_mint_raw);
packed.push(curve_type_raw);
packed.push(fee_numerator as u8);
packed.extend_from_slice(&[0u8; 7]); // padding
packed.push(fee_denominator as u8);
packed.extend_from_slice(&[0u8; 7]); // padding
packed.extend_from_slice(&[0u8; 48]); // padding
let unpacked = SwapInfo::unpack(&packed).unwrap();
assert_eq!(swap_info, unpacked);