token-swap: Add curve and pool token trait and integrate in processor (#624)

* Add SwapCurve trait and integrate in processor

* Add PoolTokenConverter trait to correspond

* Add curve type parameter to initialization and JS

* Refactor for flat curve test, fmt

* Update token-swap/program/src/curve.rs

Co-authored-by: Tyera Eulberg <teulberg@gmail.com>

* Refactor swap curve to allow for any implementation

* Rename SwapCurveWrapper -> SwapCurve

* Run cargo fmt

* Update CurveType to enum in JS

Co-authored-by: Tyera Eulberg <teulberg@gmail.com>
This commit is contained in:
Jon Cinque 2020-10-21 20:46:50 +02:00 committed by GitHub
parent ed11438ec5
commit db42f7abbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 739 additions and 356 deletions

View File

@ -10,7 +10,7 @@ import {
} from '@solana/web3.js'; } from '@solana/web3.js';
import {Token} from '../../../token/js/client/token'; import {Token} from '../../../token/js/client/token';
import {TokenSwap} from '../client/token-swap'; import {TokenSwap, CurveType} from '../client/token-swap';
import {Store} from '../client/util/store'; import {Store} from '../client/util/store';
import {newAccountWithLamports} from '../client/util/new-account-with-lamports'; import {newAccountWithLamports} from '../client/util/new-account-with-lamports';
import {url} from '../url'; import {url} from '../url';
@ -34,6 +34,8 @@ let mintB: Token;
let tokenAccountA: PublicKey; let tokenAccountA: PublicKey;
let tokenAccountB: PublicKey; let tokenAccountB: PublicKey;
// curve type used to calculate swaps and deposits
const CURVE_TYPE = CurveType.ConstantProduct;
// Initial amount in each swap token // Initial amount in each swap token
const BASE_AMOUNT = 1000; const BASE_AMOUNT = 1000;
// Amount passed to swap instruction // Amount passed to swap instruction
@ -194,13 +196,14 @@ export async function createTokenSwap(): Promise<void> {
swapPayer, swapPayer,
tokenSwapAccount, tokenSwapAccount,
authority, authority,
nonce,
tokenAccountA, tokenAccountA,
tokenAccountB, tokenAccountB,
tokenPool.publicKey, tokenPool.publicKey,
tokenAccountPool, tokenAccountPool,
tokenSwapProgramId, tokenSwapProgramId,
tokenProgramId, tokenProgramId,
nonce,
CURVE_TYPE,
1, 1,
4, 4,
); );
@ -217,6 +220,7 @@ export async function createTokenSwap(): Promise<void> {
assert(fetchedTokenSwap.tokenAccountA.equals(tokenAccountA)); assert(fetchedTokenSwap.tokenAccountA.equals(tokenAccountA));
assert(fetchedTokenSwap.tokenAccountB.equals(tokenAccountB)); assert(fetchedTokenSwap.tokenAccountB.equals(tokenAccountB));
assert(fetchedTokenSwap.poolToken.equals(tokenPool.publicKey)); assert(fetchedTokenSwap.poolToken.equals(tokenPool.publicKey));
assert(CURVE_TYPE == fetchedTokenSwap.curveType);
assert(1 == fetchedTokenSwap.feeNumerator.toNumber()); assert(1 == fetchedTokenSwap.feeNumerator.toNumber());
assert(4 == fetchedTokenSwap.feeDenominator.toNumber()); assert(4 == fetchedTokenSwap.feeDenominator.toNumber());
} }

View File

@ -64,11 +64,18 @@ export const TokenSwapLayout: typeof BufferLayout.Structure = BufferLayout.struc
Layout.publicKey('tokenAccountA'), Layout.publicKey('tokenAccountA'),
Layout.publicKey('tokenAccountB'), Layout.publicKey('tokenAccountB'),
Layout.publicKey('tokenPool'), Layout.publicKey('tokenPool'),
BufferLayout.u8('curveType'),
Layout.uint64('feeNumerator'), Layout.uint64('feeNumerator'),
Layout.uint64('feeDenominator'), Layout.uint64('feeDenominator'),
BufferLayout.blob(48, 'padding'),
], ],
); );
export const CurveType = Object.freeze({
ConstantProduct: 0, // Constant product curve, Uniswap-style
Flat: 1, // Flat curve, always 1:1 trades
});
/** /**
* A program to exchange tokens against a pool of liquidity * A program to exchange tokens against a pool of liquidity
*/ */
@ -123,6 +130,11 @@ export class TokenSwap {
*/ */
feeDenominator: Numberu64; feeDenominator: Numberu64;
/**
* CurveType, current options are:
*/
curveType: number;
/** /**
* Fee payer * Fee payer
*/ */
@ -150,6 +162,7 @@ export class TokenSwap {
authority: PublicKey, authority: PublicKey,
tokenAccountA: PublicKey, tokenAccountA: PublicKey,
tokenAccountB: PublicKey, tokenAccountB: PublicKey,
curveType: number,
feeNumerator: Numberu64, feeNumerator: Numberu64,
feeDenominator: Numberu64, feeDenominator: Numberu64,
payer: Account, payer: Account,
@ -163,6 +176,7 @@ export class TokenSwap {
authority, authority,
tokenAccountA, tokenAccountA,
tokenAccountB, tokenAccountB,
curveType,
feeNumerator, feeNumerator,
feeDenominator, feeDenominator,
payer, payer,
@ -185,13 +199,14 @@ export class TokenSwap {
static createInitSwapInstruction( static createInitSwapInstruction(
tokenSwapAccount: Account, tokenSwapAccount: Account,
authority: PublicKey, authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey, tokenAccountA: PublicKey,
tokenAccountB: PublicKey, tokenAccountB: PublicKey,
tokenPool: PublicKey, tokenPool: PublicKey,
tokenAccountPool: PublicKey, tokenAccountPool: PublicKey,
tokenProgramId: PublicKey, tokenProgramId: PublicKey,
swapProgramId: PublicKey, swapProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number, feeNumerator: number,
feeDenominator: number, feeDenominator: number,
): TransactionInstruction { ): TransactionInstruction {
@ -206,18 +221,21 @@ export class TokenSwap {
]; ];
const commandDataLayout = BufferLayout.struct([ const commandDataLayout = BufferLayout.struct([
BufferLayout.u8('instruction'), BufferLayout.u8('instruction'),
BufferLayout.u8('nonce'),
BufferLayout.u8('curveType'),
BufferLayout.nu64('feeNumerator'), BufferLayout.nu64('feeNumerator'),
BufferLayout.nu64('feeDenominator'), BufferLayout.nu64('feeDenominator'),
BufferLayout.u8('nonce'), BufferLayout.blob(48, 'padding'),
]); ]);
let data = Buffer.alloc(1024); let data = Buffer.alloc(1024);
{ {
const encodeLength = commandDataLayout.encode( const encodeLength = commandDataLayout.encode(
{ {
instruction: 0, // InitializeSwap instruction instruction: 0, // InitializeSwap instruction
nonce,
curveType,
feeNumerator, feeNumerator,
feeDenominator, feeDenominator,
nonce,
}, },
data, data,
); );
@ -254,6 +272,7 @@ export class TokenSwap {
const feeNumerator = Numberu64.fromBuffer(tokenSwapData.feeNumerator); const feeNumerator = Numberu64.fromBuffer(tokenSwapData.feeNumerator);
const feeDenominator = Numberu64.fromBuffer(tokenSwapData.feeDenominator); const feeDenominator = Numberu64.fromBuffer(tokenSwapData.feeDenominator);
const curveType = tokenSwapData.curveType;
return new TokenSwap( return new TokenSwap(
connection, connection,
@ -264,6 +283,7 @@ export class TokenSwap {
authority, authority,
tokenAccountA, tokenAccountA,
tokenAccountB, tokenAccountB,
curveType,
feeNumerator, feeNumerator,
feeDenominator, feeDenominator,
payer, payer,
@ -293,13 +313,14 @@ export class TokenSwap {
payer: Account, payer: Account,
tokenSwapAccount: Account, tokenSwapAccount: Account,
authority: PublicKey, authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey, tokenAccountA: PublicKey,
tokenAccountB: PublicKey, tokenAccountB: PublicKey,
poolToken: PublicKey, poolToken: PublicKey,
tokenAccountPool: PublicKey, tokenAccountPool: PublicKey,
swapProgramId: PublicKey, swapProgramId: PublicKey,
tokenProgramId: PublicKey, tokenProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number, feeNumerator: number,
feeDenominator: number, feeDenominator: number,
): Promise<TokenSwap> { ): Promise<TokenSwap> {
@ -313,6 +334,7 @@ export class TokenSwap {
authority, authority,
tokenAccountA, tokenAccountA,
tokenAccountB, tokenAccountB,
curveType,
new Numberu64(feeNumerator), new Numberu64(feeNumerator),
new Numberu64(feeDenominator), new Numberu64(feeDenominator),
payer, payer,
@ -336,13 +358,14 @@ export class TokenSwap {
const instruction = TokenSwap.createInitSwapInstruction( const instruction = TokenSwap.createInitSwapInstruction(
tokenSwapAccount, tokenSwapAccount,
authority, authority,
nonce,
tokenAccountA, tokenAccountA,
tokenAccountB, tokenAccountB,
poolToken, poolToken,
tokenAccountPool, tokenAccountPool,
tokenProgramId, tokenProgramId,
swapProgramId, swapProgramId,
nonce,
curveType,
feeNumerator, feeNumerator,
feeDenominator, feeDenominator,
); );

View File

@ -17,6 +17,7 @@ declare module '@solana/spl-token-swap' {
} }
export const TokenSwapLayout: Layout; export const TokenSwapLayout: Layout;
export const CurveType: Object;
export class TokenSwap { export class TokenSwap {
constructor( constructor(
@ -28,6 +29,7 @@ declare module '@solana/spl-token-swap' {
authority: PublicKey, authority: PublicKey,
tokenAccountA: PublicKey, tokenAccountA: PublicKey,
tokenAccountB: PublicKey, tokenAccountB: PublicKey,
curveType: number,
feeNumerator: Numberu64, feeNumerator: Numberu64,
feeDenominator: Numberu64, feeDenominator: Numberu64,
payer: Account, payer: Account,
@ -40,13 +42,14 @@ declare module '@solana/spl-token-swap' {
static createInitSwapInstruction( static createInitSwapInstruction(
tokenSwapAccount: Account, tokenSwapAccount: Account,
authority: PublicKey, authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey, tokenAccountA: PublicKey,
tokenAccountB: PublicKey, tokenAccountB: PublicKey,
tokenPool: PublicKey, tokenPool: PublicKey,
tokenAccountPool: PublicKey, tokenAccountPool: PublicKey,
tokenProgramId: PublicKey, tokenProgramId: PublicKey,
swapProgramId: PublicKey, swapProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number, feeNumerator: number,
feeDenominator: number, feeDenominator: number,
): TransactionInstruction; ): TransactionInstruction;
@ -69,6 +72,7 @@ declare module '@solana/spl-token-swap' {
tokenAccountPool: PublicKey, tokenAccountPool: PublicKey,
tokenProgramId: PublicKey, tokenProgramId: PublicKey,
nonce: number, nonce: number,
curveType: number,
feeNumerator: number, feeNumerator: number,
feeDenominator: number, feeDenominator: number,
swapProgramId: PublicKey, swapProgramId: PublicKey,

View File

@ -14,6 +14,8 @@ declare module '@solana/spl-token-swap' {
declare export var TokenSwapLayout: Layout; declare export var TokenSwapLayout: Layout;
declare export var CurveType: Object;
declare export class TokenSwap { declare export class TokenSwap {
constructor( constructor(
connection: Connection, connection: Connection,
@ -24,6 +26,7 @@ declare module '@solana/spl-token-swap' {
authority: PublicKey, authority: PublicKey,
tokenAccountA: PublicKey, tokenAccountA: PublicKey,
tokenAccountB: PublicKey, tokenAccountB: PublicKey,
curveType: number,
feeNumerator: Numberu64, feeNumerator: Numberu64,
feeDenominator: Numberu64, feeDenominator: Numberu64,
payer: Account, payer: Account,
@ -37,12 +40,13 @@ declare module '@solana/spl-token-swap' {
programId: PublicKey, programId: PublicKey,
tokenSwapAccount: Account, tokenSwapAccount: Account,
authority: PublicKey, authority: PublicKey,
nonce: number,
tokenAccountA: PublicKey, tokenAccountA: PublicKey,
tokenAccountB: PublicKey, tokenAccountB: PublicKey,
tokenPool: PublicKey, tokenPool: PublicKey,
tokenAccountPool: PublicKey, tokenAccountPool: PublicKey,
tokenProgramId: PublicKey, tokenProgramId: PublicKey,
nonce: number,
curveType: number,
feeNumerator: number, feeNumerator: number,
feeDenominator: number, feeDenominator: number,
): TransactionInstruction; ): TransactionInstruction;
@ -65,6 +69,7 @@ declare module '@solana/spl-token-swap' {
tokenAccountPool: PublicKey, tokenAccountPool: PublicKey,
tokenProgramId: PublicKey, tokenProgramId: PublicKey,
nonce: number, nonce: number,
curveType: number,
feeNumerator: number, feeNumerator: number,
feeDenominator: number, feeDenominator: number,
programId: PublicKey, programId: PublicKey,

View File

@ -1,11 +1,153 @@
//! Swap calculations and curve implementations //! Swap calculations and curve implementations
use solana_sdk::{
program_error::ProgramError,
program_pack::{IsInitialized, Pack, Sealed},
};
use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs};
use std::convert::{TryFrom, TryInto};
use std::fmt::Debug;
/// Initial amount of pool tokens for swap contract, hard-coded to something /// Initial amount of pool tokens for swap contract, hard-coded to something
/// "sensible" given a maximum of u64. /// "sensible" given a maximum of u64.
/// Note that on Ethereum, Uniswap uses the geometric mean of all provided /// Note that on Ethereum, Uniswap uses the geometric mean of all provided
/// input amounts, and Balancer uses 100 * 10 ^ 18. /// input amounts, and Balancer uses 100 * 10 ^ 18.
pub const INITIAL_SWAP_POOL_AMOUNT: u64 = 1_000_000_000; pub const INITIAL_SWAP_POOL_AMOUNT: u64 = 1_000_000_000;
/// Curve types supported by the token-swap program.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CurveType {
/// Uniswap-style constant product curve, invariant = token_a_amount * token_b_amount
ConstantProduct,
/// Flat line, always providing 1:1 from one token to another
Flat,
}
/// Concrete struct to wrap around the trait object which performs calculation.
#[repr(C)]
#[derive(Debug)]
pub struct SwapCurve {
/// The type of curve contained in the calculator, helpful for outside
/// queries
pub curve_type: CurveType,
/// The actual calculator, represented as a trait object to allow for many
/// different types of curves
pub calculator: Box<dyn CurveCalculator>,
}
/// Default implementation for SwapCurve cannot be derived because of
/// the contained Box.
impl Default for SwapCurve {
fn default() -> Self {
let curve_type: CurveType = Default::default();
let calculator: ConstantProductCurve = Default::default();
Self {
curve_type,
calculator: Box::new(calculator),
}
}
}
/// Simple implementation for PartialEq which assumes that the output of
/// `Pack` is enough to guarantee equality
impl PartialEq for SwapCurve {
fn eq(&self, other: &Self) -> bool {
let mut packed_self = [0u8; Self::LEN];
Self::pack_into_slice(self, &mut packed_self);
let mut packed_other = [0u8; Self::LEN];
Self::pack_into_slice(other, &mut packed_other);
packed_self[..] == packed_other[..]
}
}
impl Sealed for SwapCurve {}
impl Pack for SwapCurve {
/// Size of encoding of all curve parameters, which include fees and any other
/// constants used to calculate swaps, deposits, and withdrawals.
/// This includes 1 byte for the type, and 64 for the calculator to use as
/// it needs. Some calculators may be smaller than 64 bytes.
const LEN: usize = 65;
/// Unpacks a byte buffer into a SwapCurve
fn unpack_from_slice(input: &[u8]) -> Result<Self, ProgramError> {
let input = array_ref![input, 0, 65];
#[allow(clippy::ptr_offset_with_cast)]
let (curve_type, calculator) = array_refs![input, 1, 64];
let curve_type = curve_type[0].try_into()?;
Ok(Self {
curve_type,
calculator: match curve_type {
CurveType::ConstantProduct => {
Box::new(ConstantProductCurve::unpack_from_slice(calculator)?)
}
CurveType::Flat => Box::new(FlatCurve::unpack_from_slice(calculator)?),
},
})
}
/// Pack SwapCurve into a byte buffer
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 65];
let (curve_type, calculator) = mut_array_refs![output, 1, 64];
curve_type[0] = self.curve_type as u8;
self.calculator.pack_into_slice(&mut calculator[..]);
}
}
/// Sensible default of CurveType to ConstantProduct, the most popular and
/// well-known curve type.
impl Default for CurveType {
fn default() -> Self {
CurveType::ConstantProduct
}
}
impl TryFrom<u8> for CurveType {
type Error = ProgramError;
fn try_from(curve_type: u8) -> Result<Self, Self::Error> {
match curve_type {
0 => Ok(CurveType::ConstantProduct),
1 => Ok(CurveType::Flat),
_ => Err(ProgramError::InvalidAccountData),
}
}
}
/// Trait for packing of trait objects, required because structs that implement
/// `Pack` cannot be used as trait objects (as `dyn Pack`).
pub trait DynPack {
/// Only required function is to pack given a trait object
fn pack_into_slice(&self, dst: &mut [u8]);
}
/// Trait representing operations required on a swap curve
pub trait CurveCalculator: Debug + DynPack {
/// Calculate how much destination token will be provided given an amount
/// of source token.
fn swap(
&self,
source_amount: u64,
swap_source_amount: u64,
swap_destination_amount: u64,
) -> Option<SwapResult>;
/// Get the supply of a new pool (can be a default amount or calculated
/// based on parameters)
fn new_pool_supply(&self) -> u64;
/// Get the amount of liquidity tokens for pool tokens given the total amount
/// of liquidity tokens in the pool
fn liquidity_tokens(
&self,
pool_tokens: u64,
pool_token_supply: u64,
total_liquidity_tokens: u64,
) -> Option<u64>;
}
/// Encodes all results of swapping from a source token to a destination token /// Encodes all results of swapping from a source token to a destination token
pub struct SwapResult { pub struct SwapResult {
/// New amount of source token /// New amount of source token
@ -16,30 +158,42 @@ pub struct SwapResult {
pub amount_swapped: u64, pub amount_swapped: u64,
} }
impl SwapResult { /// Helper function for mapping to SwapError::CalculationFailure
/// SwapResult for swap from one currency into another, given pool information fn map_zero_to_none(x: u64) -> Option<u64> {
/// and fee if x == 0 {
pub fn swap_to( None
} else {
Some(x)
}
}
/// Simple constant 1:1 swap curve, example of different swap curve implementations
#[derive(Clone, Debug, Default, PartialEq)]
pub struct FlatCurve {
/// Fee numerator
pub fee_numerator: u64,
/// Fee denominator
pub fee_denominator: u64,
}
impl CurveCalculator for FlatCurve {
/// Flat curve swap always returns 1:1 (minus fee)
fn swap(
&self,
source_amount: u64, source_amount: u64,
swap_source_amount: u64, swap_source_amount: u64,
swap_destination_amount: u64, swap_destination_amount: u64,
fee_numerator: u64,
fee_denominator: u64,
) -> Option<SwapResult> { ) -> Option<SwapResult> {
let invariant = swap_source_amount.checked_mul(swap_destination_amount)?;
// debit the fee to calculate the amount swapped // debit the fee to calculate the amount swapped
let mut fee = source_amount let mut fee = source_amount
.checked_mul(fee_numerator)? .checked_mul(self.fee_numerator)?
.checked_div(fee_denominator)?; .checked_div(self.fee_denominator)?;
if fee == 0 { if fee == 0 {
fee = 1; // minimum fee of one token fee = 1; // minimum fee of one token
} }
let new_source_amount_less_fee = swap_source_amount
.checked_add(source_amount)? let amount_swapped = source_amount.checked_sub(fee)?;
.checked_sub(fee)?; let new_destination_amount = swap_destination_amount.checked_sub(amount_swapped)?;
let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?;
let amount_swapped = swap_destination_amount.checked_sub(new_destination_amount)?;
// actually add the whole amount coming in // actually add the whole amount coming in
let new_source_amount = swap_source_amount.checked_add(source_amount)?; let new_source_amount = swap_source_amount.checked_add(source_amount)?;
@ -49,105 +203,154 @@ impl SwapResult {
amount_swapped, amount_swapped,
}) })
} }
/// Balancer-style fixed initial supply
fn new_pool_supply(&self) -> u64 {
INITIAL_SWAP_POOL_AMOUNT
}
/// Simple ratio calculation for how many liquidity tokens correspond to
/// a certain number of pool tokens
fn liquidity_tokens(
&self,
pool_tokens: u64,
pool_token_supply: u64,
total_liquidity_tokens: u64,
) -> Option<u64> {
pool_tokens
.checked_mul(total_liquidity_tokens)?
.checked_div(pool_token_supply)
.and_then(map_zero_to_none)
}
} }
fn map_zero_to_none(x: u64) -> Option<u64> { /// IsInitialized is required to use `Pack::pack` and `Pack::unpack`
if x == 0 { impl IsInitialized for FlatCurve {
None fn is_initialized(&self) -> bool {
} else { true
Some(x) }
}
impl Sealed for FlatCurve {}
impl Pack for FlatCurve {
const LEN: usize = 16;
/// Unpacks a byte buffer into a SwapCurve
fn unpack_from_slice(input: &[u8]) -> Result<FlatCurve, ProgramError> {
let input = array_ref![input, 0, 16];
#[allow(clippy::ptr_offset_with_cast)]
let (fee_numerator, fee_denominator) = array_refs![input, 8, 8];
Ok(Self {
fee_numerator: u64::from_le_bytes(*fee_numerator),
fee_denominator: u64::from_le_bytes(*fee_denominator),
})
}
fn pack_into_slice(&self, output: &mut [u8]) {
(self as &dyn DynPack).pack_into_slice(output);
}
}
impl DynPack for FlatCurve {
fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 16];
let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8];
*fee_numerator = self.fee_numerator.to_le_bytes();
*fee_denominator = self.fee_denominator.to_le_bytes();
} }
} }
/// The Uniswap invariant calculator. /// The Uniswap invariant calculator.
pub struct ConstantProduct { #[derive(Clone, Debug, Default, PartialEq)]
/// Token A pub struct ConstantProductCurve {
pub token_a: u64,
/// Token B
pub token_b: u64,
/// Fee numerator /// Fee numerator
pub fee_numerator: u64, pub fee_numerator: u64,
/// Fee denominator /// Fee denominator
pub fee_denominator: u64, pub fee_denominator: u64,
} }
impl ConstantProduct { impl CurveCalculator for ConstantProductCurve {
/// Swap token a to b /// Constant product swap ensures x * y = constant
pub fn swap_a_to_b(&mut self, token_a: u64) -> Option<u64> { fn swap(
let result = SwapResult::swap_to( &self,
token_a, source_amount: u64,
self.token_a, swap_source_amount: u64,
self.token_b, swap_destination_amount: u64,
self.fee_numerator, ) -> Option<SwapResult> {
self.fee_denominator, let invariant = swap_source_amount.checked_mul(swap_destination_amount)?;
)?;
self.token_a = result.new_source_amount; // debit the fee to calculate the amount swapped
self.token_b = result.new_destination_amount; let mut fee = source_amount
map_zero_to_none(result.amount_swapped) .checked_mul(self.fee_numerator)?
.checked_div(self.fee_denominator)?;
if fee == 0 {
fee = 1; // minimum fee of one token
}
let new_source_amount_less_fee = swap_source_amount
.checked_add(source_amount)?
.checked_sub(fee)?;
let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?;
let amount_swapped =
map_zero_to_none(swap_destination_amount.checked_sub(new_destination_amount)?)?;
// actually add the whole amount coming in
let new_source_amount = swap_source_amount.checked_add(source_amount)?;
Some(SwapResult {
new_source_amount,
new_destination_amount,
amount_swapped,
})
} }
/// Swap token b to a /// Balancer-style supply starts at a constant. This could be modified to
pub fn swap_b_to_a(&mut self, token_b: u64) -> Option<u64> { /// follow the geometric mean, as done in Uniswap v2.
let result = SwapResult::swap_to( fn new_pool_supply(&self) -> u64 {
token_b, INITIAL_SWAP_POOL_AMOUNT
self.token_b, }
self.token_a,
self.fee_numerator, /// Simple ratio calculation to get the amount of liquidity tokens given
self.fee_denominator, /// pool information
)?; fn liquidity_tokens(
self.token_b = result.new_source_amount; &self,
self.token_a = result.new_destination_amount; pool_tokens: u64,
map_zero_to_none(result.amount_swapped) pool_token_supply: u64,
total_liquidity_tokens: u64,
) -> Option<u64> {
pool_tokens
.checked_mul(total_liquidity_tokens)?
.checked_div(pool_token_supply)
.and_then(map_zero_to_none)
} }
} }
/// Conversions for pool tokens, how much to deposit / withdraw, along with /// IsInitialized is required to use `Pack::pack` and `Pack::unpack`
/// proper initialization impl IsInitialized for ConstantProductCurve {
pub struct PoolTokenConverter { fn is_initialized(&self) -> bool {
/// Total supply true
pub supply: u64, }
/// Token A amount }
pub token_a: u64, impl Sealed for ConstantProductCurve {}
/// Token B amount impl Pack for ConstantProductCurve {
pub token_b: u64, const LEN: usize = 16;
fn unpack_from_slice(input: &[u8]) -> Result<ConstantProductCurve, ProgramError> {
let input = array_ref![input, 0, 16];
#[allow(clippy::ptr_offset_with_cast)]
let (fee_numerator, fee_denominator) = array_refs![input, 8, 8];
Ok(Self {
fee_numerator: u64::from_le_bytes(*fee_numerator),
fee_denominator: u64::from_le_bytes(*fee_denominator),
})
}
fn pack_into_slice(&self, output: &mut [u8]) {
(self as &dyn DynPack).pack_into_slice(output);
}
} }
impl PoolTokenConverter { impl DynPack for ConstantProductCurve {
/// Create a converter based on existing market information fn pack_into_slice(&self, output: &mut [u8]) {
pub fn new_existing(supply: u64, token_a: u64, token_b: u64) -> Self { let output = array_mut_ref![output, 0, 16];
Self { let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8];
supply, *fee_numerator = self.fee_numerator.to_le_bytes();
token_a, *fee_denominator = self.fee_denominator.to_le_bytes();
token_b,
}
}
/// Create a converter for a new pool token, no supply present yet.
/// According to Uniswap, the geometric mean protects the pool creator
/// in case the initial ratio is off the market.
pub fn new_pool(token_a: u64, token_b: u64) -> Self {
let supply = INITIAL_SWAP_POOL_AMOUNT;
Self {
supply,
token_a,
token_b,
}
}
/// A tokens for pool tokens, returns None if output is less than 0
pub fn token_a_rate(&self, pool_tokens: u64) -> Option<u64> {
pool_tokens
.checked_mul(self.token_a)?
.checked_div(self.supply)
.and_then(map_zero_to_none)
}
/// B tokens for pool tokens, returns None is output is less than 0
pub fn token_b_rate(&self, pool_tokens: u64) -> Option<u64> {
pool_tokens
.checked_mul(self.token_b)?
.checked_div(self.supply)
.and_then(map_zero_to_none)
} }
} }
@ -157,48 +360,152 @@ mod tests {
#[test] #[test]
fn initial_pool_amount() { fn initial_pool_amount() {
let token_converter = PoolTokenConverter::new_pool(1, 5); let fee_numerator = 0;
assert_eq!(token_converter.supply, INITIAL_SWAP_POOL_AMOUNT); let fee_denominator = 1;
let calculator = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
assert_eq!(calculator.new_pool_supply(), INITIAL_SWAP_POOL_AMOUNT);
} }
fn check_pool_token_a_rate( fn check_liquidity_pool_token_rate(
token_a: u64, token_a: u64,
token_b: u64,
deposit: u64, deposit: u64,
supply: u64, supply: u64,
expected: Option<u64>, expected: Option<u64>,
) { ) {
let calculator = PoolTokenConverter::new_existing(supply, token_a, token_b); let fee_numerator = 0;
assert_eq!(calculator.token_a_rate(deposit), expected); let fee_denominator = 1;
let calculator = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
assert_eq!(
calculator.liquidity_tokens(deposit, supply, token_a),
expected
);
} }
#[test] #[test]
fn issued_tokens() { fn issued_tokens() {
check_pool_token_a_rate(2, 50, 5, 10, Some(1)); check_liquidity_pool_token_rate(2, 5, 10, Some(1));
check_pool_token_a_rate(10, 10, 5, 10, Some(5)); check_liquidity_pool_token_rate(10, 5, 10, Some(5));
check_pool_token_a_rate(5, 100, 5, 10, Some(2)); check_liquidity_pool_token_rate(5, 5, 10, Some(2));
check_pool_token_a_rate(5, u64::MAX, 5, 10, Some(2)); check_liquidity_pool_token_rate(5, 5, 10, Some(2));
check_pool_token_a_rate(u64::MAX, u64::MAX, 5, 10, None); check_liquidity_pool_token_rate(u64::MAX, 5, 10, None);
} }
#[test] #[test]
fn swap_calculation() { fn constant_product_swap_calculation() {
// calculation on https://github.com/solana-labs/solana-program-library/issues/341 // calculation on https://github.com/solana-labs/solana-program-library/issues/341
let swap_source_amount: u64 = 1000; let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000; let swap_destination_amount: u64 = 50000;
let fee_numerator: u64 = 1; let fee_numerator: u64 = 1;
let fee_denominator: u64 = 100; let fee_denominator: u64 = 100;
let source_amount: u64 = 100; let source_amount: u64 = 100;
let result = SwapResult::swap_to( let curve = ConstantProductCurve {
source_amount,
swap_source_amount,
swap_destination_amount,
fee_numerator, fee_numerator,
fee_denominator, fee_denominator,
) };
.unwrap(); let result = curve
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
assert_eq!(result.new_source_amount, 1100); assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, 4505); assert_eq!(result.amount_swapped, 4505);
assert_eq!(result.new_destination_amount, 45495); assert_eq!(result.new_destination_amount, 45495);
} }
#[test]
fn flat_swap_calculation() {
let swap_source_amount: u64 = 1000;
let swap_destination_amount: u64 = 50000;
let fee_numerator: u64 = 1;
let fee_denominator: u64 = 100;
let source_amount: u64 = 100;
let curve = FlatCurve {
fee_numerator,
fee_denominator,
};
let result = curve
.swap(source_amount, swap_source_amount, swap_destination_amount)
.unwrap();
let amount_swapped = 99;
assert_eq!(result.new_source_amount, 1100);
assert_eq!(result.amount_swapped, amount_swapped);
assert_eq!(
result.new_destination_amount,
swap_destination_amount - amount_swapped
);
}
#[test]
fn pack_flat_curve() {
let fee_numerator = 1;
let fee_denominator = 4;
let curve = FlatCurve {
fee_numerator,
fee_denominator,
};
let mut packed = [0u8; FlatCurve::LEN];
Pack::pack_into_slice(&curve, &mut packed[..]);
let unpacked = FlatCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
let mut packed = vec![];
packed.extend_from_slice(&fee_numerator.to_le_bytes());
packed.extend_from_slice(&fee_denominator.to_le_bytes());
let unpacked = FlatCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
}
#[test]
fn pack_constant_product_curve() {
let fee_numerator = 1;
let fee_denominator = 4;
let curve = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
let mut packed = [0u8; ConstantProductCurve::LEN];
Pack::pack_into_slice(&curve, &mut packed[..]);
let unpacked = ConstantProductCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
let mut packed = vec![];
packed.extend_from_slice(&fee_numerator.to_le_bytes());
packed.extend_from_slice(&fee_denominator.to_le_bytes());
let unpacked = ConstantProductCurve::unpack(&packed).unwrap();
assert_eq!(curve, unpacked);
}
#[test]
fn pack_swap_curve() {
let fee_numerator = 1;
let fee_denominator = 4;
let curve = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
let curve_type = CurveType::ConstantProduct;
let swap_curve = SwapCurve {
curve_type,
calculator: Box::new(curve),
};
let mut packed = [0u8; SwapCurve::LEN];
Pack::pack_into_slice(&swap_curve, &mut packed[..]);
let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap();
assert_eq!(swap_curve, unpacked);
let mut packed = vec![];
packed.push(curve_type as u8);
packed.extend_from_slice(&fee_numerator.to_le_bytes());
packed.extend_from_slice(&fee_denominator.to_le_bytes());
packed.extend_from_slice(&[0u8; 48]); // padding
let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap();
assert_eq!(swap_curve, unpacked);
}
} }

View File

@ -2,10 +2,12 @@
#![allow(clippy::too_many_arguments)] #![allow(clippy::too_many_arguments)]
use crate::curve::{ConstantProductCurve, CurveType, FlatCurve, SwapCurve};
use crate::error::SwapError; use crate::error::SwapError;
use solana_sdk::{ use solana_sdk::{
instruction::{AccountMeta, Instruction}, instruction::{AccountMeta, Instruction},
program_error::ProgramError, program_error::ProgramError,
program_pack::Pack,
pubkey::Pubkey, pubkey::Pubkey,
}; };
use std::convert::TryInto; use std::convert::TryInto;
@ -13,7 +15,7 @@ use std::mem::size_of;
/// Instructions supported by the SwapInfo program. /// Instructions supported by the SwapInfo program.
#[repr(C)] #[repr(C)]
#[derive(Clone, Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum SwapInstruction { pub enum SwapInstruction {
/// Initializes a new SwapInfo. /// Initializes a new SwapInfo.
/// ///
@ -25,12 +27,11 @@ pub enum SwapInstruction {
/// 5. `[writable]` Pool Token Account to deposit the minted tokens. Must be empty, owned by user. /// 5. `[writable]` Pool Token Account to deposit the minted tokens. Must be empty, owned by user.
/// 6. '[]` Token program id /// 6. '[]` Token program id
Initialize { Initialize {
/// swap pool fee numerator
fee_numerator: u64,
/// swap pool fee denominator
fee_denominator: u64,
/// nonce used to create valid program address /// nonce used to create valid program address
nonce: u8, nonce: u8,
/// swap curve info for pool, including CurveType, fees, and anything
/// else that may be required
swap_curve: SwapCurve,
}, },
/// Swap the tokens in the pool. /// Swap the tokens in the pool.
@ -99,14 +100,9 @@ impl SwapInstruction {
let (&tag, rest) = input.split_first().ok_or(SwapError::InvalidInstruction)?; let (&tag, rest) = input.split_first().ok_or(SwapError::InvalidInstruction)?;
Ok(match tag { Ok(match tag {
0 => { 0 => {
let (fee_numerator, rest) = Self::unpack_u64(rest)?; let (&nonce, rest) = rest.split_first().ok_or(SwapError::InvalidInstruction)?;
let (fee_denominator, rest) = Self::unpack_u64(rest)?; let swap_curve = SwapCurve::unpack_unchecked(rest)?;
let (&nonce, _rest) = rest.split_first().ok_or(SwapError::InvalidInstruction)?; Self::Initialize { nonce, swap_curve }
Self::Initialize {
fee_numerator,
fee_denominator,
nonce,
}
} }
1 => { 1 => {
let (amount_in, rest) = Self::unpack_u64(rest)?; let (amount_in, rest) = Self::unpack_u64(rest)?;
@ -157,16 +153,13 @@ impl SwapInstruction {
/// Packs a [SwapInstruction](enum.SwapInstruction.html) into a byte buffer. /// Packs a [SwapInstruction](enum.SwapInstruction.html) into a byte buffer.
pub fn pack(&self) -> Vec<u8> { pub fn pack(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(size_of::<Self>()); let mut buf = Vec::with_capacity(size_of::<Self>());
match *self { match &*self {
Self::Initialize { Self::Initialize { nonce, swap_curve } => {
fee_numerator,
fee_denominator,
nonce,
} => {
buf.push(0); buf.push(0);
buf.extend_from_slice(&fee_numerator.to_le_bytes()); buf.push(*nonce);
buf.extend_from_slice(&fee_denominator.to_le_bytes()); let mut swap_curve_slice = [0u8; SwapCurve::LEN];
buf.push(nonce); Pack::pack_into_slice(swap_curve, &mut swap_curve_slice[..]);
buf.extend_from_slice(&swap_curve_slice);
} }
Self::Swap { Self::Swap {
amount_in, amount_in,
@ -212,14 +205,24 @@ pub fn initialize(
pool_pubkey: &Pubkey, pool_pubkey: &Pubkey,
destination_pubkey: &Pubkey, destination_pubkey: &Pubkey,
nonce: u8, nonce: u8,
curve_type: CurveType,
fee_numerator: u64, fee_numerator: u64,
fee_denominator: u64, fee_denominator: u64,
) -> Result<Instruction, ProgramError> { ) -> Result<Instruction, ProgramError> {
let init_data = SwapInstruction::Initialize { let swap_curve = SwapCurve {
fee_numerator, curve_type,
fee_denominator, calculator: match curve_type {
nonce, CurveType::ConstantProduct => Box::new(ConstantProductCurve {
fee_numerator,
fee_denominator,
}),
CurveType::Flat => Box::new(FlatCurve {
fee_numerator,
fee_denominator,
}),
},
}; };
let init_data = SwapInstruction::Initialize { nonce, swap_curve };
let data = init_data.pack(); let data = init_data.pack();
let accounts = vec![ let accounts = vec![
@ -379,17 +382,24 @@ mod tests {
let fee_numerator: u64 = 1; let fee_numerator: u64 = 1;
let fee_denominator: u64 = 4; let fee_denominator: u64 = 4;
let nonce: u8 = 255; let nonce: u8 = 255;
let check = SwapInstruction::Initialize { let curve_type = CurveType::Flat;
let calculator = Box::new(FlatCurve {
fee_numerator, fee_numerator,
fee_denominator, fee_denominator,
nonce, });
let swap_curve = SwapCurve {
curve_type,
calculator,
}; };
let check = SwapInstruction::Initialize { nonce, swap_curve };
let packed = check.pack(); let packed = check.pack();
let mut expect = vec![]; let mut expect = vec![];
expect.push(0 as u8); expect.push(0 as u8);
expect.push(nonce);
expect.push(curve_type as u8);
expect.extend_from_slice(&fee_numerator.to_le_bytes()); expect.extend_from_slice(&fee_numerator.to_le_bytes());
expect.extend_from_slice(&fee_denominator.to_le_bytes()); expect.extend_from_slice(&fee_denominator.to_le_bytes());
expect.push(nonce); expect.extend_from_slice(&[0u8; 48]); // padding
assert_eq!(packed, expect); assert_eq!(packed, expect);
let unpacked = SwapInstruction::unpack(&expect).unwrap(); let unpacked = SwapInstruction::unpack(&expect).unwrap();
assert_eq!(unpacked, check); assert_eq!(unpacked, check);

View File

@ -2,12 +2,7 @@
#![cfg(feature = "program")] #![cfg(feature = "program")]
use crate::{ use crate::{curve::SwapCurve, error::SwapError, instruction::SwapInstruction, state::SwapInfo};
curve::{ConstantProduct, PoolTokenConverter},
error::SwapError,
instruction::SwapInstruction,
state::SwapInfo,
};
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
#[cfg(not(target_arch = "bpf"))] #[cfg(not(target_arch = "bpf"))]
use solana_sdk::instruction::Instruction; use solana_sdk::instruction::Instruction;
@ -142,8 +137,7 @@ impl Processor {
pub fn process_initialize( pub fn process_initialize(
program_id: &Pubkey, program_id: &Pubkey,
nonce: u8, nonce: u8,
fee_numerator: u64, swap_curve: SwapCurve,
fee_denominator: u64,
accounts: &[AccountInfo], accounts: &[AccountInfo],
) -> ProgramResult { ) -> ProgramResult {
let account_info_iter = &mut accounts.iter(); let account_info_iter = &mut accounts.iter();
@ -198,8 +192,7 @@ impl Processor {
return Err(SwapError::InvalidSupply.into()); return Err(SwapError::InvalidSupply.into());
} }
let converter = PoolTokenConverter::new_pool(token_a.amount, token_b.amount); let initial_amount = swap_curve.calculator.new_pool_supply();
let initial_amount = converter.supply;
Self::token_mint_to( Self::token_mint_to(
swap_info.key, swap_info.key,
@ -218,8 +211,7 @@ impl Processor {
token_a: *token_a_info.key, token_a: *token_a_info.key,
token_b: *token_b_info.key, token_b: *token_b_info.key,
pool_mint: *pool_mint_info.key, pool_mint: *pool_mint_info.key,
fee_numerator, swap_curve,
fee_denominator,
}; };
SwapInfo::pack(obj, &mut swap_info.data.borrow_mut())?; SwapInfo::pack(obj, &mut swap_info.data.borrow_mut())?;
Ok(()) Ok(())
@ -262,28 +254,12 @@ impl Processor {
let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?; let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?;
let dest_account = Self::unpack_token_account(&swap_destination_info.data.borrow())?; let dest_account = Self::unpack_token_account(&swap_destination_info.data.borrow())?;
let amount_out = if *swap_source_info.key == token_swap.token_a { let result = token_swap
let mut invariant = ConstantProduct { .swap_curve
token_a: source_account.amount, .calculator
token_b: dest_account.amount, .swap(amount_in, source_account.amount, dest_account.amount)
fee_numerator: token_swap.fee_numerator, .ok_or(SwapError::CalculationFailure)?;
fee_denominator: token_swap.fee_denominator, if result.amount_swapped < minimum_amount_out {
};
invariant
.swap_a_to_b(amount_in)
.ok_or(SwapError::CalculationFailure)?
} else {
let mut invariant = ConstantProduct {
token_a: dest_account.amount,
token_b: source_account.amount,
fee_numerator: token_swap.fee_numerator,
fee_denominator: token_swap.fee_denominator,
};
invariant
.swap_b_to_a(amount_in)
.ok_or(SwapError::CalculationFailure)?
};
if amount_out < minimum_amount_out {
return Err(SwapError::ExceededSlippage.into()); return Err(SwapError::ExceededSlippage.into());
} }
Self::token_transfer( Self::token_transfer(
@ -302,7 +278,7 @@ impl Processor {
destination_info.clone(), destination_info.clone(),
authority_info.clone(), authority_info.clone(),
token_swap.nonce, token_swap.nonce,
amount_out, result.amount_swapped,
)?; )?;
Ok(()) Ok(())
} }
@ -344,17 +320,16 @@ impl Processor {
let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?;
let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?;
let converter = let calculator = token_swap.swap_curve.calculator;
PoolTokenConverter::new_existing(pool_mint.supply, token_a.amount, token_b.amount);
let a_amount = converter let a_amount = calculator
.token_a_rate(pool_token_amount) .liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount)
.ok_or(SwapError::CalculationFailure)?; .ok_or(SwapError::CalculationFailure)?;
if a_amount > maximum_token_a_amount { if a_amount > maximum_token_a_amount {
return Err(SwapError::ExceededSlippage.into()); return Err(SwapError::ExceededSlippage.into());
} }
let b_amount = converter let b_amount = calculator
.token_b_rate(pool_token_amount) .liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount)
.ok_or(SwapError::CalculationFailure)?; .ok_or(SwapError::CalculationFailure)?;
if b_amount > maximum_token_b_amount { if b_amount > maximum_token_b_amount {
return Err(SwapError::ExceededSlippage.into()); return Err(SwapError::ExceededSlippage.into());
@ -428,17 +403,16 @@ impl Processor {
let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?;
let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?;
let converter = let calculator = token_swap.swap_curve.calculator;
PoolTokenConverter::new_existing(pool_mint.supply, token_a.amount, token_b.amount);
let a_amount = converter let a_amount = calculator
.token_a_rate(pool_token_amount) .liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount)
.ok_or(SwapError::CalculationFailure)?; .ok_or(SwapError::CalculationFailure)?;
if a_amount < minimum_token_a_amount { if a_amount < minimum_token_a_amount {
return Err(SwapError::ExceededSlippage.into()); return Err(SwapError::ExceededSlippage.into());
} }
let b_amount = converter let b_amount = calculator
.token_b_rate(pool_token_amount) .liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount)
.ok_or(SwapError::CalculationFailure)?; .ok_or(SwapError::CalculationFailure)?;
if b_amount < minimum_token_b_amount { if b_amount < minimum_token_b_amount {
return Err(SwapError::ExceededSlippage.into()); return Err(SwapError::ExceededSlippage.into());
@ -478,19 +452,9 @@ impl Processor {
pub fn process(program_id: &Pubkey, accounts: &[AccountInfo], input: &[u8]) -> ProgramResult { pub fn process(program_id: &Pubkey, accounts: &[AccountInfo], input: &[u8]) -> ProgramResult {
let instruction = SwapInstruction::unpack(input)?; let instruction = SwapInstruction::unpack(input)?;
match instruction { match instruction {
SwapInstruction::Initialize { SwapInstruction::Initialize { nonce, swap_curve } => {
fee_numerator,
fee_denominator,
nonce,
} => {
info!("Instruction: Init"); info!("Instruction: Init");
Self::process_initialize( Self::process_initialize(program_id, nonce, swap_curve, accounts)
program_id,
nonce,
fee_numerator,
fee_denominator,
accounts,
)
} }
SwapInstruction::Swap { SwapInstruction::Swap {
amount_in, amount_in,
@ -618,7 +582,9 @@ solana_sdk::program_stubs!();
mod tests { mod tests {
use super::*; use super::*;
use crate::{ use crate::{
curve::{SwapResult, INITIAL_SWAP_POOL_AMOUNT}, curve::{
ConstantProductCurve, CurveCalculator, CurveType, FlatCurve, INITIAL_SWAP_POOL_AMOUNT,
},
instruction::{deposit, initialize, swap, withdraw}, instruction::{deposit, initialize, swap, withdraw},
}; };
use solana_sdk::{ use solana_sdk::{
@ -634,6 +600,7 @@ mod tests {
struct SwapAccountInfo { struct SwapAccountInfo {
nonce: u8, nonce: u8,
curve_type: CurveType,
authority_key: Pubkey, authority_key: Pubkey,
fee_numerator: u64, fee_numerator: u64,
fee_denominator: u64, fee_denominator: u64,
@ -656,6 +623,7 @@ mod tests {
impl SwapAccountInfo { impl SwapAccountInfo {
pub fn new( pub fn new(
user_key: &Pubkey, user_key: &Pubkey,
curve_type: CurveType,
fee_numerator: u64, fee_numerator: u64,
fee_denominator: u64, fee_denominator: u64,
token_a_amount: u64, token_a_amount: u64,
@ -699,6 +667,7 @@ mod tests {
SwapAccountInfo { SwapAccountInfo {
nonce, nonce,
curve_type,
authority_key, authority_key,
fee_numerator, fee_numerator,
fee_denominator, fee_denominator,
@ -731,6 +700,7 @@ mod tests {
&self.pool_mint_key, &self.pool_mint_key,
&self.pool_token_key, &self.pool_token_key,
self.nonce, self.nonce,
self.curve_type,
self.fee_numerator, self.fee_numerator,
self.fee_denominator, self.fee_denominator,
) )
@ -1179,9 +1149,11 @@ mod tests {
let token_a_amount = 1000; let token_a_amount = 1000;
let token_b_amount = 2000; let token_b_amount = 2000;
let pool_token_amount = 10; let pool_token_amount = 10;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new( let mut accounts = SwapAccountInfo::new(
&user_key, &user_key,
curve_type,
fee_numerator, fee_numerator,
fee_denominator, fee_denominator,
token_a_amount, token_a_amount,
@ -1474,6 +1446,7 @@ mod tests {
&accounts.pool_mint_key, &accounts.pool_mint_key,
&accounts.pool_token_key, &accounts.pool_token_key,
accounts.nonce, accounts.nonce,
accounts.curve_type,
accounts.fee_numerator, accounts.fee_numerator,
accounts.fee_denominator, accounts.fee_denominator,
) )
@ -1513,6 +1486,19 @@ mod tests {
// create valid swap // create valid swap
accounts.initialize_swap().unwrap(); accounts.initialize_swap().unwrap();
// create valid flat swap
{
let mut accounts = SwapAccountInfo::new(
&user_key,
CurveType::Flat,
fee_numerator,
fee_denominator,
token_a_amount,
token_b_amount,
);
accounts.initialize_swap().unwrap();
}
// create again // create again
{ {
assert_eq!( assert_eq!(
@ -1523,11 +1509,10 @@ mod tests {
let swap_info = SwapInfo::unpack(&accounts.swap_account.data).unwrap(); let swap_info = SwapInfo::unpack(&accounts.swap_account.data).unwrap();
assert_eq!(swap_info.is_initialized, true); assert_eq!(swap_info.is_initialized, true);
assert_eq!(swap_info.nonce, accounts.nonce); assert_eq!(swap_info.nonce, accounts.nonce);
assert_eq!(swap_info.swap_curve.curve_type, accounts.curve_type);
assert_eq!(swap_info.token_a, accounts.token_a_key); assert_eq!(swap_info.token_a, accounts.token_a_key);
assert_eq!(swap_info.token_b, accounts.token_b_key); assert_eq!(swap_info.token_b, accounts.token_b_key);
assert_eq!(swap_info.pool_mint, accounts.pool_mint_key); assert_eq!(swap_info.pool_mint, accounts.pool_mint_key);
assert_eq!(swap_info.fee_denominator, fee_denominator);
assert_eq!(swap_info.fee_numerator, fee_numerator);
let token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); let token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
assert_eq!(token_a.amount, token_a_amount); assert_eq!(token_a.amount, token_a_amount);
let token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); let token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
@ -1546,8 +1531,11 @@ mod tests {
let fee_denominator = 2; let fee_denominator = 2;
let token_a_amount = 1000; let token_a_amount = 1000;
let token_b_amount = 9000; let token_b_amount = 9000;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new( let mut accounts = SwapAccountInfo::new(
&user_key, &user_key,
curve_type,
fee_numerator, fee_numerator,
fee_denominator, fee_denominator,
token_a_amount, token_a_amount,
@ -2063,18 +2051,24 @@ mod tests {
let fee_denominator = 2; let fee_denominator = 2;
let token_a_amount = 1000; let token_a_amount = 1000;
let token_b_amount = 2000; let token_b_amount = 2000;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new( let mut accounts = SwapAccountInfo::new(
&user_key, &user_key,
curve_type,
fee_numerator, fee_numerator,
fee_denominator, fee_denominator,
token_a_amount, token_a_amount,
token_b_amount, token_b_amount,
); );
let withdrawer_key = pubkey_rand(); let withdrawer_key = pubkey_rand();
let pool_converter = PoolTokenConverter::new_pool(token_a_amount, token_b_amount); let calculator = ConstantProductCurve {
fee_numerator,
fee_denominator,
};
let initial_a = token_a_amount / 10; let initial_a = token_a_amount / 10;
let initial_b = token_b_amount / 10; let initial_b = token_b_amount / 10;
let initial_pool = pool_converter.supply / 10; let initial_pool = calculator.new_pool_supply() / 10;
let withdraw_amount = initial_pool / 4; let withdraw_amount = initial_pool / 4;
let minimum_a_amount = initial_a / 40; let minimum_a_amount = initial_a / 40;
let minimum_b_amount = initial_b / 40; let minimum_b_amount = initial_b / 40;
@ -2583,15 +2577,18 @@ mod tests {
let swap_token_b = let swap_token_b =
Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap();
let pool_converter = PoolTokenConverter::new_existing( let calculator = ConstantProductCurve {
pool_mint.supply, fee_numerator,
swap_token_a.amount, fee_denominator,
swap_token_b.amount, };
);
let withdrawn_a = pool_converter.token_a_rate(withdraw_amount).unwrap(); let withdrawn_a = calculator
.liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_a.amount)
.unwrap();
assert_eq!(swap_token_a.amount, token_a_amount - withdrawn_a); assert_eq!(swap_token_a.amount, token_a_amount - withdrawn_a);
let withdrawn_b = pool_converter.token_b_rate(withdraw_amount).unwrap(); let withdrawn_b = calculator
.liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_b.amount)
.unwrap();
assert_eq!(swap_token_b.amount, token_b_amount - withdrawn_b); assert_eq!(swap_token_b.amount, token_b_amount - withdrawn_b);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a + withdrawn_a); assert_eq!(token_a.amount, initial_a + withdrawn_a);
@ -2602,16 +2599,150 @@ mod tests {
} }
} }
#[test] fn check_valid_swap_curve(curve_type: CurveType, calculator: Box<dyn CurveCalculator>) {
fn test_swap() {
let user_key = pubkey_rand(); let user_key = pubkey_rand();
let swapper_key = pubkey_rand(); let swapper_key = pubkey_rand();
let fee_numerator = 1; let fee_numerator = 1;
let fee_denominator = 10; let fee_denominator = 10;
let token_a_amount = 1000; let token_a_amount = 1000;
let token_b_amount = 5000; let token_b_amount = 5000;
let swap_curve = SwapCurve {
curve_type,
calculator,
};
let mut accounts = SwapAccountInfo::new( let mut accounts = SwapAccountInfo::new(
&user_key, &user_key,
curve_type,
fee_numerator,
fee_denominator,
token_a_amount,
token_b_amount,
);
let initial_a = token_a_amount / 5;
let initial_b = token_b_amount / 5;
accounts.initialize_swap().unwrap();
let swap_token_a_key = accounts.token_a_key;
let swap_token_b_key = accounts.token_b_key;
let (
token_a_key,
mut token_a_account,
token_b_key,
mut token_b_account,
_pool_key,
_pool_account,
) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0);
// swap one way
let a_to_b_amount = initial_a / 10;
let minimum_b_amount = 0;
accounts
.swap(
&swapper_key,
&token_a_key,
&mut token_a_account,
&swap_token_a_key,
&swap_token_b_key,
&token_b_key,
&mut token_b_account,
a_to_b_amount,
minimum_b_amount,
)
.unwrap();
let results = swap_curve
.calculator
.swap(a_to_b_amount, token_a_amount, token_b_amount)
.unwrap();
let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
let token_a_amount = swap_token_a.amount;
assert_eq!(token_a_amount, results.new_source_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a - a_to_b_amount);
let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let token_b_amount = swap_token_b.amount;
assert_eq!(token_b_amount, results.new_destination_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(token_b.amount, initial_b + results.amount_swapped);
let first_swap_amount = results.amount_swapped;
// swap the other way
let b_to_a_amount = initial_b / 10;
let minimum_a_amount = 0;
accounts
.swap(
&swapper_key,
&token_b_key,
&mut token_b_account,
&swap_token_b_key,
&swap_token_a_key,
&token_a_key,
&mut token_a_account,
b_to_a_amount,
minimum_a_amount,
)
.unwrap();
let results = swap_curve
.calculator
.swap(b_to_a_amount, token_b_amount, token_a_amount)
.unwrap();
let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
assert_eq!(swap_token_a.amount, results.new_destination_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(
token_a.amount,
initial_a - a_to_b_amount + results.amount_swapped
);
let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
assert_eq!(swap_token_b.amount, results.new_source_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(
token_b.amount,
initial_b + first_swap_amount - b_to_a_amount
);
}
#[test]
fn test_valid_swap_curves() {
let fee_numerator = 1;
let fee_denominator = 10;
check_valid_swap_curve(
CurveType::ConstantProduct,
Box::new(ConstantProductCurve {
fee_numerator,
fee_denominator,
}),
);
check_valid_swap_curve(
CurveType::Flat,
Box::new(FlatCurve {
fee_numerator,
fee_denominator,
}),
);
}
#[test]
fn test_invalid_swap() {
let user_key = pubkey_rand();
let swapper_key = pubkey_rand();
let fee_numerator = 1;
let fee_denominator = 10;
let token_a_amount = 1000;
let token_b_amount = 5000;
let curve_type = CurveType::ConstantProduct;
let mut accounts = SwapAccountInfo::new(
&user_key,
curve_type,
fee_numerator, fee_numerator,
fee_denominator, fee_denominator,
token_a_amount, token_a_amount,
@ -2932,102 +3063,5 @@ mod tests {
) )
); );
} }
// correct swap
{
let (
token_a_key,
mut token_a_account,
token_b_key,
mut token_b_account,
_pool_key,
_pool_account,
) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0);
// swap one way
let a_to_b_amount = initial_a / 10;
let minimum_b_amount = initial_b / 20;
accounts
.swap(
&swapper_key,
&token_a_key,
&mut token_a_account,
&swap_token_a_key,
&swap_token_b_key,
&token_b_key,
&mut token_b_account,
a_to_b_amount,
minimum_b_amount,
)
.unwrap();
let results = SwapResult::swap_to(
a_to_b_amount,
token_a_amount,
token_b_amount,
fee_numerator,
fee_denominator,
)
.unwrap();
let swap_token_a =
Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
let token_a_amount = swap_token_a.amount;
assert_eq!(token_a_amount, results.new_source_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(token_a.amount, initial_a - a_to_b_amount);
let swap_token_b =
Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
let token_b_amount = swap_token_b.amount;
assert_eq!(token_b_amount, results.new_destination_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(token_b.amount, initial_b + results.amount_swapped);
let first_swap_amount = results.amount_swapped;
// swap the other way
let b_to_a_amount = initial_b / 10;
let minimum_a_amount = initial_a / 20;
accounts
.swap(
&swapper_key,
&token_b_key,
&mut token_b_account,
&swap_token_b_key,
&swap_token_a_key,
&token_a_key,
&mut token_a_account,
b_to_a_amount,
minimum_a_amount,
)
.unwrap();
let results = SwapResult::swap_to(
b_to_a_amount,
token_b_amount,
token_a_amount,
fee_numerator,
fee_denominator,
)
.unwrap();
let swap_token_a =
Processor::unpack_token_account(&accounts.token_a_account.data).unwrap();
assert_eq!(swap_token_a.amount, results.new_destination_amount);
let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap();
assert_eq!(
token_a.amount,
initial_a - a_to_b_amount + results.amount_swapped
);
let swap_token_b =
Processor::unpack_token_account(&accounts.token_b_account.data).unwrap();
assert_eq!(swap_token_b.amount, results.new_source_amount);
let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap();
assert_eq!(
token_b.amount,
initial_b + first_swap_amount - b_to_a_amount
);
}
} }
} }

View File

@ -1,5 +1,6 @@
//! State transition types //! State transition types
use crate::curve::SwapCurve;
use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs}; use arrayref::{array_mut_ref, array_ref, array_refs, mut_array_refs};
use solana_sdk::{ use solana_sdk::{
program_error::ProgramError, program_error::ProgramError,
@ -9,7 +10,7 @@ use solana_sdk::{
/// Program states. /// Program states.
#[repr(C)] #[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)] #[derive(Debug, Default, PartialEq)]
pub struct SwapInfo { pub struct SwapInfo {
/// Initialized state. /// Initialized state.
pub is_initialized: bool, pub is_initialized: bool,
@ -32,10 +33,9 @@ pub struct SwapInfo {
/// Pool tokens can be withdrawn back to the original A or B token. /// Pool tokens can be withdrawn back to the original A or B token.
pub pool_mint: Pubkey, pub pool_mint: Pubkey,
/// Numerator of fee applied to the input token amount prior to output calculation. /// Swap curve parameters, to be unpacked and used by the SwapCurve, which
pub fee_numerator: u64, /// calculates swaps, deposits, and withdrawals
/// Denominator of fee applied to the input token amount prior to output calculation. pub swap_curve: SwapCurve,
pub fee_denominator: u64,
} }
impl Sealed for SwapInfo {} impl Sealed for SwapInfo {}
@ -46,22 +46,14 @@ impl IsInitialized for SwapInfo {
} }
impl Pack for SwapInfo { impl Pack for SwapInfo {
const LEN: usize = 146; const LEN: usize = 195;
/// Unpacks a byte buffer into a [SwapInfo](struct.SwapInfo.html). /// Unpacks a byte buffer into a [SwapInfo](struct.SwapInfo.html).
fn unpack_from_slice(input: &[u8]) -> Result<Self, ProgramError> { fn unpack_from_slice(input: &[u8]) -> Result<Self, ProgramError> {
let input = array_ref![input, 0, 146]; let input = array_ref![input, 0, 195];
#[allow(clippy::ptr_offset_with_cast)] #[allow(clippy::ptr_offset_with_cast)]
let ( let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) =
is_initialized, array_refs![input, 1, 1, 32, 32, 32, 32, 65];
nonce,
token_program_id,
token_a,
token_b,
pool_mint,
fee_numerator,
fee_denominator,
) = array_refs![input, 1, 1, 32, 32, 32, 32, 8, 8];
Ok(Self { Ok(Self {
is_initialized: match is_initialized { is_initialized: match is_initialized {
[0] => false, [0] => false,
@ -73,41 +65,36 @@ impl Pack for SwapInfo {
token_a: Pubkey::new_from_array(*token_a), token_a: Pubkey::new_from_array(*token_a),
token_b: Pubkey::new_from_array(*token_b), token_b: Pubkey::new_from_array(*token_b),
pool_mint: Pubkey::new_from_array(*pool_mint), pool_mint: Pubkey::new_from_array(*pool_mint),
fee_numerator: u64::from_le_bytes(*fee_numerator), swap_curve: SwapCurve::unpack_from_slice(swap_curve)?,
fee_denominator: u64::from_le_bytes(*fee_denominator),
}) })
} }
fn pack_into_slice(&self, output: &mut [u8]) { fn pack_into_slice(&self, output: &mut [u8]) {
let output = array_mut_ref![output, 0, 146]; let output = array_mut_ref![output, 0, 195];
let ( let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) =
is_initialized, mut_array_refs![output, 1, 1, 32, 32, 32, 32, 65];
nonce,
token_program_id,
token_a,
token_b,
pool_mint,
fee_numerator,
fee_denominator,
) = mut_array_refs![output, 1, 1, 32, 32, 32, 32, 8, 8];
is_initialized[0] = self.is_initialized as u8; is_initialized[0] = self.is_initialized as u8;
nonce[0] = self.nonce; nonce[0] = self.nonce;
token_program_id.copy_from_slice(self.token_program_id.as_ref()); token_program_id.copy_from_slice(self.token_program_id.as_ref());
token_a.copy_from_slice(self.token_a.as_ref()); token_a.copy_from_slice(self.token_a.as_ref());
token_b.copy_from_slice(self.token_b.as_ref()); token_b.copy_from_slice(self.token_b.as_ref());
pool_mint.copy_from_slice(self.pool_mint.as_ref()); pool_mint.copy_from_slice(self.pool_mint.as_ref());
*fee_numerator = self.fee_numerator.to_le_bytes(); self.swap_curve.pack_into_slice(&mut swap_curve[..]);
*fee_denominator = self.fee_denominator.to_le_bytes();
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::curve::FlatCurve;
use std::convert::TryInto;
#[test] #[test]
fn test_swap_info_packing() { fn test_swap_info_packing() {
let nonce = 255; let nonce = 255;
let curve_type_raw: u8 = 1;
let curve_type = curve_type_raw.try_into().unwrap();
let token_program_id_raw = [1u8; 32]; let token_program_id_raw = [1u8; 32];
let token_a_raw = [1u8; 32]; let token_a_raw = [1u8; 32];
let token_b_raw = [2u8; 32]; let token_b_raw = [2u8; 32];
@ -118,6 +105,14 @@ mod tests {
let pool_mint = Pubkey::new_from_array(pool_mint_raw); let pool_mint = Pubkey::new_from_array(pool_mint_raw);
let fee_numerator = 1; let fee_numerator = 1;
let fee_denominator = 4; let fee_denominator = 4;
let calculator = Box::new(FlatCurve {
fee_numerator,
fee_denominator,
});
let swap_curve = SwapCurve {
curve_type,
calculator,
};
let is_initialized = true; let is_initialized = true;
let swap_info = SwapInfo { let swap_info = SwapInfo {
is_initialized, is_initialized,
@ -126,12 +121,11 @@ mod tests {
token_a, token_a,
token_b, token_b,
pool_mint, pool_mint,
fee_numerator, swap_curve,
fee_denominator,
}; };
let mut packed = [0u8; SwapInfo::LEN]; let mut packed = [0u8; SwapInfo::LEN];
SwapInfo::pack(swap_info, &mut packed).unwrap(); SwapInfo::pack_into_slice(&swap_info, &mut packed);
let unpacked = SwapInfo::unpack(&packed).unwrap(); let unpacked = SwapInfo::unpack(&packed).unwrap();
assert_eq!(swap_info, unpacked); assert_eq!(swap_info, unpacked);
@ -142,10 +136,12 @@ mod tests {
packed.extend_from_slice(&token_a_raw); packed.extend_from_slice(&token_a_raw);
packed.extend_from_slice(&token_b_raw); packed.extend_from_slice(&token_b_raw);
packed.extend_from_slice(&pool_mint_raw); packed.extend_from_slice(&pool_mint_raw);
packed.push(curve_type_raw);
packed.push(fee_numerator as u8); packed.push(fee_numerator as u8);
packed.extend_from_slice(&[0u8; 7]); // padding packed.extend_from_slice(&[0u8; 7]); // padding
packed.push(fee_denominator as u8); packed.push(fee_denominator as u8);
packed.extend_from_slice(&[0u8; 7]); // padding packed.extend_from_slice(&[0u8; 7]); // padding
packed.extend_from_slice(&[0u8; 48]); // padding
let unpacked = SwapInfo::unpack(&packed).unwrap(); let unpacked = SwapInfo::unpack(&packed).unwrap();
assert_eq!(swap_info, unpacked); assert_eq!(swap_info, unpacked);