From b0867c7e28cc03f51e6a3558d46f1fbae0bf5cc9 Mon Sep 17 00:00:00 2001 From: Jon Cinque Date: Fri, 23 Oct 2020 18:31:58 +0200 Subject: [PATCH] token-swap: Add fee account to receive withdraw / trading fees, trading token mints (#695) * Add mints to swap info * Add mints to JS * Add fee account in SwapInfo / init * Add test for 0 fee, init test fully * Add withdraw command interface * Add fee accounts to swap instruction * Add calculations for swap and withdraw fees * Run cargo fmt * Add new fees to JS and test * Review feedback: fixup instruction doc and clone * Update order of accounts in instructions * Run cargo fmt * Fix owner fee pool token calculation to include trading fee * Add owner fees to flat curve, per request * Fix instruction comment numbering * Add more errors types for clearer calculation errors Add a check for withdrawing from fee account * Cargo fmt --- token-swap/js/cli/token-swap-test.js | 109 +++-- token-swap/js/client/token-swap.js | 180 ++++++-- token-swap/js/module.d.ts | 40 +- token-swap/js/module.flow.js | 41 +- token-swap/program/src/curve.rs | 506 ++++++++++++++------ token-swap/program/src/error.rs | 13 +- token-swap/program/src/instruction.rs | 69 +-- token-swap/program/src/processor.rs | 634 +++++++++++++++++++++----- token-swap/program/src/state.rs | 90 +++- 9 files changed, 1322 insertions(+), 360 deletions(-) diff --git a/token-swap/js/cli/token-swap-test.js b/token-swap/js/cli/token-swap-test.js index 6806eb9a..dfb9f9e9 100644 --- a/token-swap/js/cli/token-swap-test.js +++ b/token-swap/js/cli/token-swap-test.js @@ -28,6 +28,7 @@ let owner: Account; // Token pool let tokenPool: Token; let tokenAccountPool: PublicKey; +let feeAccount: PublicKey; // Tokens swapped let mintA: Token; let mintB: Token; @@ -37,14 +38,25 @@ let tokenAccountB: PublicKey; // curve type used to calculate swaps and deposits const CURVE_TYPE = CurveType.ConstantProduct; // Initial amount in each swap token -const BASE_AMOUNT = 1000; +let currentSwapTokenA = 1000; +let currentSwapTokenB = 1000; +let currentFeeAmount = 0; // Amount passed to swap instruction const SWAP_AMOUNT_IN = 100; -const SWAP_AMOUNT_OUT = 70; +const SWAP_AMOUNT_OUT = 53; +const SWAP_FEE = 6817150; // Pool token amount minted on init const DEFAULT_POOL_TOKEN_AMOUNT = 1000000000; // Pool token amount to withdraw / deposit -const POOL_TOKEN_AMOUNT = 1000000; +const POOL_TOKEN_AMOUNT = 10000000; + +// Pool fees +const TRADING_FEE_NUMERATOR = 1; +const TRADING_FEE_DENOMINATOR = 4; +const OWNER_TRADING_FEE_NUMERATOR = 1; +const OWNER_TRADING_FEE_DENOMINATOR = 5; +const OWNER_WITHDRAW_FEE_NUMERATOR = 1; +const OWNER_WITHDRAW_FEE_DENOMINATOR = 6; function assert(condition, message) { if (!condition) { @@ -131,11 +143,8 @@ export async function loadPrograms(): Promise { export async function createTokenSwap(): Promise { const connection = await getConnection(); const [tokenProgramId, tokenSwapProgramId] = await GetPrograms(connection); - const payer = await newAccountWithLamports( - connection, - 100000000000 /* wag */, - ); - owner = await newAccountWithLamports(connection, 100000000000 /* wag */); + const payer = await newAccountWithLamports(connection, 1000000000); + owner = await newAccountWithLamports(connection, 1000000000); const tokenSwapAccount = new Account(); [authority, nonce] = await PublicKey.findProgramAddress( @@ -155,6 +164,7 @@ export async function createTokenSwap(): Promise { console.log('creating pool account'); tokenAccountPool = await tokenPool.createAccount(owner.publicKey); + feeAccount = await tokenPool.createAccount(owner.publicKey); console.log('creating token A'); mintA = await Token.createMint( @@ -169,7 +179,7 @@ export async function createTokenSwap(): Promise { console.log('creating token A account'); tokenAccountA = await mintA.createAccount(authority); console.log('minting token A to swap'); - await mintA.mintTo(tokenAccountA, owner, [], BASE_AMOUNT); + await mintA.mintTo(tokenAccountA, owner, [], currentSwapTokenA); console.log('creating token B'); mintB = await Token.createMint( @@ -184,13 +194,10 @@ export async function createTokenSwap(): Promise { console.log('creating token B account'); tokenAccountB = await mintB.createAccount(authority); console.log('minting token B to swap'); - await mintB.mintTo(tokenAccountB, owner, [], BASE_AMOUNT); + await mintB.mintTo(tokenAccountB, owner, [], currentSwapTokenB); console.log('creating token swap'); - const swapPayer = await newAccountWithLamports( - connection, - 100000000000 /* wag */, - ); + const swapPayer = await newAccountWithLamports(connection, 10000000000); tokenSwap = await TokenSwap.createTokenSwap( connection, swapPayer, @@ -199,13 +206,20 @@ export async function createTokenSwap(): Promise { tokenAccountA, tokenAccountB, tokenPool.publicKey, + mintA.publicKey, + mintB.publicKey, + feeAccount, tokenAccountPool, tokenSwapProgramId, tokenProgramId, nonce, CURVE_TYPE, - 1, - 4, + TRADING_FEE_NUMERATOR, + TRADING_FEE_DENOMINATOR, + OWNER_TRADING_FEE_NUMERATOR, + OWNER_TRADING_FEE_DENOMINATOR, + OWNER_WITHDRAW_FEE_NUMERATOR, + OWNER_WITHDRAW_FEE_DENOMINATOR, ); console.log('loading token swap'); @@ -219,10 +233,33 @@ export async function createTokenSwap(): Promise { assert(fetchedTokenSwap.tokenProgramId.equals(tokenProgramId)); assert(fetchedTokenSwap.tokenAccountA.equals(tokenAccountA)); assert(fetchedTokenSwap.tokenAccountB.equals(tokenAccountB)); + assert(fetchedTokenSwap.mintA.equals(mintA.publicKey)); + assert(fetchedTokenSwap.mintB.equals(mintB.publicKey)); assert(fetchedTokenSwap.poolToken.equals(tokenPool.publicKey)); + assert(fetchedTokenSwap.feeAccount.equals(feeAccount)); assert(CURVE_TYPE == fetchedTokenSwap.curveType); - assert(1 == fetchedTokenSwap.feeNumerator.toNumber()); - assert(4 == fetchedTokenSwap.feeDenominator.toNumber()); + assert( + TRADING_FEE_NUMERATOR == fetchedTokenSwap.tradeFeeNumerator.toNumber(), + ); + assert( + TRADING_FEE_DENOMINATOR == fetchedTokenSwap.tradeFeeDenominator.toNumber(), + ); + assert( + OWNER_TRADING_FEE_NUMERATOR == + fetchedTokenSwap.ownerTradeFeeNumerator.toNumber(), + ); + assert( + OWNER_TRADING_FEE_DENOMINATOR == + fetchedTokenSwap.ownerTradeFeeDenominator.toNumber(), + ); + assert( + OWNER_WITHDRAW_FEE_NUMERATOR == + fetchedTokenSwap.ownerWithdrawFeeNumerator.toNumber(), + ); + assert( + OWNER_WITHDRAW_FEE_DENOMINATOR == + fetchedTokenSwap.ownerWithdrawFeeDenominator.toNumber(), + ); } export async function deposit(): Promise { @@ -260,9 +297,11 @@ export async function deposit(): Promise { info = await mintB.getAccountInfo(userAccountB); assert(info.amount.toNumber() == 0); info = await mintA.getAccountInfo(tokenAccountA); - assert(info.amount.toNumber() == BASE_AMOUNT + tokenA); + assert(info.amount.toNumber() == currentSwapTokenA + tokenA); + currentSwapTokenA += tokenA; info = await mintB.getAccountInfo(tokenAccountB); - assert(info.amount.toNumber() == BASE_AMOUNT + tokenB); + assert(info.amount.toNumber() == currentSwapTokenB + tokenB); + currentSwapTokenB += tokenB; info = await tokenPool.getAccountInfo(newAccountPool); assert(info.amount.toNumber() == POOL_TOKEN_AMOUNT); } @@ -272,8 +311,17 @@ export async function withdraw(): Promise { const supply = poolMintInfo.supply.toNumber(); let swapTokenA = await mintA.getAccountInfo(tokenAccountA); let swapTokenB = await mintB.getAccountInfo(tokenAccountB); - const tokenA = (swapTokenA.amount.toNumber() * POOL_TOKEN_AMOUNT) / supply; - const tokenB = (swapTokenB.amount.toNumber() * POOL_TOKEN_AMOUNT) / supply; + const feeAmount = Math.floor( + (POOL_TOKEN_AMOUNT * OWNER_WITHDRAW_FEE_NUMERATOR) / + OWNER_WITHDRAW_FEE_DENOMINATOR, + ); + const poolTokenAmount = POOL_TOKEN_AMOUNT - feeAmount; + const tokenA = Math.floor( + (swapTokenA.amount.toNumber() * poolTokenAmount) / supply, + ); + const tokenB = Math.floor( + (swapTokenB.amount.toNumber() * poolTokenAmount) / supply, + ); console.log('Creating withdraw token A account'); let userAccountA = await mintA.createAccount(owner.publicKey); @@ -307,12 +355,17 @@ export async function withdraw(): Promise { assert( info.amount.toNumber() == DEFAULT_POOL_TOKEN_AMOUNT - POOL_TOKEN_AMOUNT, ); - assert(swapTokenA.amount.toNumber() == BASE_AMOUNT); - assert(swapTokenB.amount.toNumber() == BASE_AMOUNT); + assert(swapTokenA.amount.toNumber() == currentSwapTokenA - tokenA); + currentSwapTokenA -= tokenA; + assert(swapTokenB.amount.toNumber() == currentSwapTokenB - tokenB); + currentSwapTokenB -= tokenB; info = await mintA.getAccountInfo(userAccountA); assert(info.amount.toNumber() == tokenA); info = await mintB.getAccountInfo(userAccountB); assert(info.amount.toNumber() == tokenB); + info = await tokenPool.getAccountInfo(feeAccount); + assert(info.amount.toNumber() == feeAmount); + currentFeeAmount = feeAmount; } export async function swap(): Promise { @@ -337,13 +390,17 @@ export async function swap(): Promise { info = await mintA.getAccountInfo(userAccountA); assert(info.amount.toNumber() == 0); info = await mintA.getAccountInfo(tokenAccountA); - assert(info.amount.toNumber() == BASE_AMOUNT + SWAP_AMOUNT_IN); + assert(info.amount.toNumber() == currentSwapTokenA + SWAP_AMOUNT_IN); + currentSwapTokenA -= SWAP_AMOUNT_IN; info = await mintB.getAccountInfo(tokenAccountB); - assert(info.amount.toNumber() == BASE_AMOUNT - SWAP_AMOUNT_OUT); + assert(info.amount.toNumber() == currentSwapTokenB - SWAP_AMOUNT_OUT); + currentSwapTokenB -= SWAP_AMOUNT_OUT; info = await mintB.getAccountInfo(userAccountB); assert(info.amount.toNumber() == SWAP_AMOUNT_OUT); info = await tokenPool.getAccountInfo(tokenAccountPool); assert( info.amount.toNumber() == DEFAULT_POOL_TOKEN_AMOUNT - POOL_TOKEN_AMOUNT, ); + info = await tokenPool.getAccountInfo(feeAccount); + assert(info.amount.toNumber() == currentFeeAmount + SWAP_FEE); } diff --git a/token-swap/js/client/token-swap.js b/token-swap/js/client/token-swap.js index a5f66653..71fb8ada 100644 --- a/token-swap/js/client/token-swap.js +++ b/token-swap/js/client/token-swap.js @@ -64,10 +64,17 @@ export const TokenSwapLayout: typeof BufferLayout.Structure = BufferLayout.struc Layout.publicKey('tokenAccountA'), Layout.publicKey('tokenAccountB'), Layout.publicKey('tokenPool'), + Layout.publicKey('mintA'), + Layout.publicKey('mintB'), + Layout.publicKey('feeAccount'), BufferLayout.u8('curveType'), - Layout.uint64('feeNumerator'), - Layout.uint64('feeDenominator'), - BufferLayout.blob(48, 'padding'), + Layout.uint64('tradeFeeNumerator'), + Layout.uint64('tradeFeeDenominator'), + Layout.uint64('ownerTradeFeeNumerator'), + Layout.uint64('ownerTradeFeeDenominator'), + Layout.uint64('ownerWithdrawFeeNumerator'), + Layout.uint64('ownerWithdrawFeeDenominator'), + BufferLayout.blob(16, 'padding'), ], ); @@ -105,6 +112,11 @@ export class TokenSwap { */ poolToken: PublicKey; + /** + * The public key for the fee account receiving trade and/or withdrawal fees + */ + feeAccount: PublicKey; + /** * Authority */ @@ -121,14 +133,44 @@ export class TokenSwap { tokenAccountB: PublicKey; /** - * Fee numerator + * The public key for the mint of the first token account of the trading pair */ - feeNumerator: Numberu64; + mintA: PublicKey; /** - * Fee denominator + * The public key for the mint of the second token account of the trading pair */ - feeDenominator: Numberu64; + mintB: PublicKey; + + /** + * Trading fee numerator + */ + tradeFeeNumerator: Numberu64; + + /** + * Trading fee denominator + */ + tradeFeeDenominator: Numberu64; + + /** + * Owner trading fee numerator + */ + ownerTradeFeeNumerator: Numberu64; + + /** + * Owner trading fee denominator + */ + ownerTradeFeeDenominator: Numberu64; + + /** + * Owner withdraw fee numerator + */ + ownerWithdrawFeeNumerator: Numberu64; + + /** + * Owner withdraw fee denominator + */ + ownerWithdrawFeeDenominator: Numberu64; /** * CurveType, current options are: @@ -159,12 +201,19 @@ export class TokenSwap { swapProgramId: PublicKey, tokenProgramId: PublicKey, poolToken: PublicKey, + feeAccount: PublicKey, authority: PublicKey, tokenAccountA: PublicKey, tokenAccountB: PublicKey, + mintA: PublicKey, + mintB: PublicKey, curveType: number, - feeNumerator: Numberu64, - feeDenominator: Numberu64, + tradeFeeNumerator: Numberu64, + tradeFeeDenominator: Numberu64, + ownerTradeFeeNumerator: Numberu64, + ownerTradeFeeDenominator: Numberu64, + ownerWithdrawFeeNumerator: Numberu64, + ownerWithdrawFeeDenominator: Numberu64, payer: Account, ) { Object.assign(this, { @@ -173,12 +222,19 @@ export class TokenSwap { swapProgramId, tokenProgramId, poolToken, + feeAccount, authority, tokenAccountA, tokenAccountB, + mintA, + mintB, curveType, - feeNumerator, - feeDenominator, + tradeFeeNumerator, + tradeFeeDenominator, + ownerTradeFeeNumerator, + ownerTradeFeeDenominator, + ownerWithdrawFeeNumerator, + ownerWithdrawFeeDenominator, payer, }); } @@ -202,13 +258,18 @@ export class TokenSwap { tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, + feeAccount: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, swapProgramId: PublicKey, nonce: number, curveType: number, - feeNumerator: number, - feeDenominator: number, + tradeFeeNumerator: number, + tradeFeeDenominator: number, + ownerTradeFeeNumerator: number, + ownerTradeFeeDenominator: number, + ownerWithdrawFeeNumerator: number, + ownerWithdrawFeeDenominator: number, ): TransactionInstruction { const keys = [ {pubkey: tokenSwapAccount.publicKey, isSigner: false, isWritable: true}, @@ -216,6 +277,7 @@ export class TokenSwap { {pubkey: tokenAccountA, isSigner: false, isWritable: false}, {pubkey: tokenAccountB, isSigner: false, isWritable: false}, {pubkey: tokenPool, isSigner: false, isWritable: true}, + {pubkey: feeAccount, isSigner: false, isWritable: false}, {pubkey: tokenAccountPool, isSigner: false, isWritable: true}, {pubkey: tokenProgramId, isSigner: false, isWritable: false}, ]; @@ -223,9 +285,13 @@ export class TokenSwap { BufferLayout.u8('instruction'), BufferLayout.u8('nonce'), BufferLayout.u8('curveType'), - BufferLayout.nu64('feeNumerator'), - BufferLayout.nu64('feeDenominator'), - BufferLayout.blob(48, 'padding'), + BufferLayout.nu64('tradeFeeNumerator'), + BufferLayout.nu64('tradeFeeDenominator'), + BufferLayout.nu64('ownerTradeFeeNumerator'), + BufferLayout.nu64('ownerTradeFeeDenominator'), + BufferLayout.nu64('ownerWithdrawFeeNumerator'), + BufferLayout.nu64('ownerWithdrawFeeDenominator'), + BufferLayout.blob(16, 'padding'), ]); let data = Buffer.alloc(1024); { @@ -234,8 +300,12 @@ export class TokenSwap { instruction: 0, // InitializeSwap instruction nonce, curveType, - feeNumerator, - feeDenominator, + tradeFeeNumerator, + tradeFeeDenominator, + ownerTradeFeeNumerator, + ownerTradeFeeDenominator, + ownerWithdrawFeeNumerator, + ownerWithdrawFeeDenominator, }, data, ); @@ -266,12 +336,31 @@ export class TokenSwap { ); const poolToken = new PublicKey(tokenSwapData.tokenPool); + const feeAccount = new PublicKey(tokenSwapData.feeAccount); const tokenAccountA = new PublicKey(tokenSwapData.tokenAccountA); const tokenAccountB = new PublicKey(tokenSwapData.tokenAccountB); + const mintA = new PublicKey(tokenSwapData.mintA); + const mintB = new PublicKey(tokenSwapData.mintB); const tokenProgramId = new PublicKey(tokenSwapData.tokenProgramId); - const feeNumerator = Numberu64.fromBuffer(tokenSwapData.feeNumerator); - const feeDenominator = Numberu64.fromBuffer(tokenSwapData.feeDenominator); + const tradeFeeNumerator = Numberu64.fromBuffer( + tokenSwapData.tradeFeeNumerator, + ); + const tradeFeeDenominator = Numberu64.fromBuffer( + tokenSwapData.tradeFeeDenominator, + ); + const ownerTradeFeeNumerator = Numberu64.fromBuffer( + tokenSwapData.ownerTradeFeeNumerator, + ); + const ownerTradeFeeDenominator = Numberu64.fromBuffer( + tokenSwapData.ownerTradeFeeDenominator, + ); + const ownerWithdrawFeeNumerator = Numberu64.fromBuffer( + tokenSwapData.ownerWithdrawFeeNumerator, + ); + const ownerWithdrawFeeDenominator = Numberu64.fromBuffer( + tokenSwapData.ownerWithdrawFeeDenominator, + ); const curveType = tokenSwapData.curveType; return new TokenSwap( @@ -280,12 +369,19 @@ export class TokenSwap { programId, tokenProgramId, poolToken, + feeAccount, authority, tokenAccountA, tokenAccountB, + mintA, + mintB, curveType, - feeNumerator, - feeDenominator, + tradeFeeNumerator, + tradeFeeDenominator, + ownerTradeFeeNumerator, + ownerTradeFeeDenominator, + ownerWithdrawFeeNumerator, + ownerWithdrawFeeDenominator, payer, ); } @@ -316,13 +412,20 @@ export class TokenSwap { tokenAccountA: PublicKey, tokenAccountB: PublicKey, poolToken: PublicKey, + mintA: PublicKey, + mintB: PublicKey, + feeAccount: PublicKey, tokenAccountPool: PublicKey, swapProgramId: PublicKey, tokenProgramId: PublicKey, nonce: number, curveType: number, - feeNumerator: number, - feeDenominator: number, + tradeFeeNumerator: number, + tradeFeeDenominator: number, + ownerTradeFeeNumerator: number, + ownerTradeFeeDenominator: number, + ownerWithdrawFeeNumerator: number, + ownerWithdrawFeeDenominator: number, ): Promise { let transaction; const tokenSwap = new TokenSwap( @@ -331,12 +434,19 @@ export class TokenSwap { swapProgramId, tokenProgramId, poolToken, + feeAccount, authority, tokenAccountA, tokenAccountB, + mintA, + mintB, curveType, - new Numberu64(feeNumerator), - new Numberu64(feeDenominator), + new Numberu64(tradeFeeNumerator), + new Numberu64(tradeFeeDenominator), + new Numberu64(ownerTradeFeeNumerator), + new Numberu64(ownerTradeFeeDenominator), + new Numberu64(ownerWithdrawFeeNumerator), + new Numberu64(ownerWithdrawFeeDenominator), payer, ); @@ -361,13 +471,18 @@ export class TokenSwap { tokenAccountA, tokenAccountB, poolToken, + feeAccount, tokenAccountPool, tokenProgramId, swapProgramId, nonce, curveType, - feeNumerator, - feeDenominator, + tradeFeeNumerator, + tradeFeeDenominator, + ownerTradeFeeNumerator, + ownerTradeFeeDenominator, + ownerWithdrawFeeNumerator, + ownerWithdrawFeeDenominator, ); transaction.add(instruction); @@ -411,6 +526,8 @@ export class TokenSwap { poolSource, poolDestination, userDestination, + this.poolToken, + this.feeAccount, this.swapProgramId, this.tokenProgramId, amountIn, @@ -428,6 +545,8 @@ export class TokenSwap { poolSource: PublicKey, poolDestination: PublicKey, userDestination: PublicKey, + poolMint: PublicKey, + feeAccount: PublicKey, swapProgramId: PublicKey, tokenProgramId: PublicKey, amountIn: number | Numberu64, @@ -456,6 +575,8 @@ export class TokenSwap { {pubkey: poolSource, isSigner: false, isWritable: true}, {pubkey: poolDestination, isSigner: false, isWritable: true}, {pubkey: userDestination, isSigner: false, isWritable: true}, + {pubkey: poolMint, isSigner: false, isWritable: true}, + {pubkey: feeAccount, isSigner: false, isWritable: true}, {pubkey: tokenProgramId, isSigner: false, isWritable: false}, ]; return new TransactionInstruction({ @@ -583,6 +704,7 @@ export class TokenSwap { this.tokenSwap, this.authority, this.poolToken, + this.feeAccount, poolAccount, this.tokenAccountA, this.tokenAccountB, @@ -603,6 +725,7 @@ export class TokenSwap { tokenSwap: PublicKey, authority: PublicKey, poolMint: PublicKey, + feeAccount: PublicKey, sourcePoolAccount: PublicKey, fromA: PublicKey, fromB: PublicKey, @@ -641,6 +764,7 @@ export class TokenSwap { {pubkey: fromB, isSigner: false, isWritable: true}, {pubkey: userAccountA, isSigner: false, isWritable: true}, {pubkey: userAccountB, isSigner: false, isWritable: true}, + {pubkey: feeAccount, isSigner: false, isWritable: true}, {pubkey: tokenProgramId, isSigner: false, isWritable: false}, ]; return new TransactionInstruction({ diff --git a/token-swap/js/module.d.ts b/token-swap/js/module.d.ts index 862c5da5..05233f72 100644 --- a/token-swap/js/module.d.ts +++ b/token-swap/js/module.d.ts @@ -26,12 +26,19 @@ declare module '@solana/spl-token-swap' { tokenProgramId: PublicKey, tokenSwap: PublicKey, poolToken: PublicKey, + feeAccount: PublicKey, authority: PublicKey, tokenAccountA: PublicKey, tokenAccountB: PublicKey, + mintA: PublicKey, + mintB: PublicKey, curveType: number, - feeNumerator: Numberu64, - feeDenominator: Numberu64, + tradeFeeNumerator: Numberu64, + tradeFeeDenominator: Numberu64, + ownerTradeFeeNumerator: Numberu64, + ownerTradeFeeDenominator: Numberu64, + ownerWithdrawFeeNumerator: Numberu64, + ownerWithdrawFeeDenominator: Numberu64, payer: Account, ); @@ -45,13 +52,18 @@ declare module '@solana/spl-token-swap' { tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, + feeAccount: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, swapProgramId: PublicKey, nonce: number, curveType: number, - feeNumerator: number, - feeDenominator: number, + tradeFeeNumerator: number, + tradeFeeDenominator: number, + ownerTradeFeeNumerator: number, + ownerTradeFeeDenominator: number, + ownerWithdrawFeeNumerator: number, + ownerWithdrawFeeDenominator: number, ): TransactionInstruction; static loadTokenSwap( @@ -69,12 +81,19 @@ declare module '@solana/spl-token-swap' { tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, + mintA: PublicKey, + mintB: PublicKey, + feeAccount: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, nonce: number, curveType: number, - feeNumerator: number, - feeDenominator: number, + tradeFeeNumerator: number, + tradeFeeDenominator: number, + ownerTradeFeeNumerator: number, + ownerTradeFeeDenominator: number, + ownerWithdrawFeeNumerator: number, + ownerWithdrawFeeDenominator: number, swapProgramId: PublicKey, ): Promise; @@ -94,6 +113,8 @@ declare module '@solana/spl-token-swap' { poolSource: PublicKey, poolDestination: PublicKey, userDestination: PublicKey, + poolMint: PublicKey, + feeAccount: PublicKey, swapProgramId: PublicKey, tokenProgramId: PublicKey, amountIn: number | Numberu64, @@ -131,14 +152,8 @@ declare module '@solana/spl-token-swap' { ): TransactionInstruction; withdraw( - authority: PublicKey, - poolMint: PublicKey, - sourcePoolAccount: PublicKey, - fromA: PublicKey, - fromB: PublicKey, userAccountA: PublicKey, userAccountB: PublicKey, - tokenProgramId: PublicKey, poolTokenAmount: number | Numberu64, minimumTokenA: number | Numberu64, minimumTokenB: number | Numberu64, @@ -148,6 +163,7 @@ declare module '@solana/spl-token-swap' { tokenSwap: PublicKey, authority: PublicKey, poolMint: PublicKey, + feeAccount: PublicKey, sourcePoolAccount: PublicKey, fromA: PublicKey, fromB: PublicKey, diff --git a/token-swap/js/module.flow.js b/token-swap/js/module.flow.js index a977521f..df3ca88d 100644 --- a/token-swap/js/module.flow.js +++ b/token-swap/js/module.flow.js @@ -23,12 +23,19 @@ declare module '@solana/spl-token-swap' { tokenProgramId: PublicKey, tokenSwap: PublicKey, poolToken: PublicKey, + feeAccount: PublicKey, authority: PublicKey, tokenAccountA: PublicKey, tokenAccountB: PublicKey, + mintA: PublicKey, + mintB: PublicKey, curveType: number, - feeNumerator: Numberu64, - feeDenominator: Numberu64, + tradeFeeNumerator: Numberu64, + tradeFeeDenominator: Numberu64, + ownerTradeFeeNumerator: Numberu64, + ownerTradeFeeDenominator: Numberu64, + ownerWithdrawFeeNumerator: Numberu64, + ownerWithdrawFeeDenominator: Numberu64, payer: Account, ): TokenSwap; @@ -43,12 +50,17 @@ declare module '@solana/spl-token-swap' { tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, + feeAccount: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, nonce: number, curveType: number, - feeNumerator: number, - feeDenominator: number, + tradeFeeNumerator: number, + tradeFeeDenominator: number, + ownerTradeFeeNumerator: number, + ownerTradeFeeDenominator: number, + ownerWithdrawFeeNumerator: number, + ownerWithdrawFeeDenominator: number, ): TransactionInstruction; static loadTokenSwap( @@ -66,12 +78,19 @@ declare module '@solana/spl-token-swap' { tokenAccountA: PublicKey, tokenAccountB: PublicKey, tokenPool: PublicKey, + mintA: PublicKey, + mintB: PublicKey, + feeAccount: PublicKey, tokenAccountPool: PublicKey, tokenProgramId: PublicKey, nonce: number, curveType: number, - feeNumerator: number, - feeDenominator: number, + tradeFeeNumerator: number, + tradeFeeDenominator: number, + ownerTradeFeeNumerator: number, + ownerTradeFeeDenominator: number, + ownerWithdrawFeeNumerator: number, + ownerWithdrawFeeDenominator: number, programId: PublicKey, ): Promise; @@ -91,6 +110,8 @@ declare module '@solana/spl-token-swap' { poolSource: PublicKey, poolDestination: PublicKey, userDestination: PublicKey, + poolMint: PublicKey, + feeAccount: PublicKey, swapProgramId: PublicKey, tokenProgramId: PublicKey, amountIn: number | Numberu64, @@ -128,14 +149,9 @@ declare module '@solana/spl-token-swap' { ): TransactionInstruction; withdraw( - authority: PublicKey, - poolMint: PublicKey, - sourcePoolAccount: PublicKey, - fromA: PublicKey, - fromB: PublicKey, userAccountA: PublicKey, userAccountB: PublicKey, - tokenProgramId: PublicKey, + poolAccount: PublicKey, poolTokenAmount: number | Numberu64, minimumTokenA: number | Numberu64, minimumTokenB: number | Numberu64, @@ -145,6 +161,7 @@ declare module '@solana/spl-token-swap' { tokenSwap: PublicKey, authority: PublicKey, poolMint: PublicKey, + feeAccount: PublicKey, sourcePoolAccount: PublicKey, fromA: PublicKey, fromB: PublicKey, diff --git a/token-swap/program/src/curve.rs b/token-swap/program/src/curve.rs index b2bf4739..b4c4c0b5 100644 --- a/token-swap/program/src/curve.rs +++ b/token-swap/program/src/curve.rs @@ -50,6 +50,18 @@ impl Default for SwapCurve { } } +/// Clone takes advantage of pack / unpack to get around the difficulty of +/// cloning dynamic objects. +/// Note that this is only to be used for testing. +#[cfg(test)] +impl Clone for SwapCurve { + fn clone(&self) -> Self { + let mut packed_self = [0u8; Self::LEN]; + Self::pack_into_slice(self, &mut packed_self); + Self::unpack_from_slice(&packed_self).unwrap() + } +} + /// Simple implementation for PartialEq which assumes that the output of /// `Pack` is enough to guarantee equality impl PartialEq for SwapCurve { @@ -134,18 +146,60 @@ pub trait CurveCalculator: Debug + DynPack { swap_destination_amount: u64, ) -> Option; - /// Get the supply of a new pool (can be a default amount or calculated - /// based on parameters) - fn new_pool_supply(&self) -> u64; + /// Calculate the withdraw fee in pool tokens + /// Default implementation assumes no fee + fn owner_withdraw_fee(&self, _pool_tokens: u64) -> Option { + Some(0) + } - /// Get the amount of liquidity tokens for pool tokens given the total amount - /// of liquidity tokens in the pool - fn liquidity_tokens( + /// Calculate the trading fee in trading tokens + /// Default implementation assumes no fee + fn trading_fee(&self, _trading_tokens: u64) -> Option { + Some(0) + } + + /// Calculate the pool token equivalent of the owner fee on trade + /// See the math at: https://balancer.finance/whitepaper/#single-asset-deposit + /// For the moment, we do an approximation for the square root. For numbers + /// just above 1, simply dividing by 2 brings you very close to the correct + /// value. + fn owner_fee_to_pool_tokens( + &self, + owner_fee: u64, + trading_token_amount: u64, + pool_supply: u64, + tokens_in_pool: u64, + ) -> Option { + // Get the trading fee incurred if the owner fee is swapped for the other side + let trade_fee = self.trading_fee(owner_fee)?; + let owner_fee = owner_fee.checked_sub(trade_fee)?; + pool_supply + .checked_mul(owner_fee)? + .checked_div(trading_token_amount)? + .checked_div(tokens_in_pool) + } + + /// Get the supply for a new pool + /// The default implementation is a Balancer-style fixed initial supply + fn new_pool_supply(&self) -> u64 { + INITIAL_SWAP_POOL_AMOUNT + } + + /// Get the amount of trading tokens for the given amount of pool tokens, + /// provided the total trading tokens and supply of pool tokens. + /// The default implementation is a simple ratio calculation for how many + /// trading tokens correspond to a certain number of pool tokens + fn pool_tokens_to_trading_tokens( &self, pool_tokens: u64, pool_token_supply: u64, - total_liquidity_tokens: u64, - ) -> Option; + total_trading_tokens: u64, + ) -> Option { + pool_tokens + .checked_mul(total_trading_tokens)? + .checked_div(pool_token_supply) + .and_then(map_zero_to_none) + } } /// Encodes all results of swapping from a source token to a destination token @@ -156,6 +210,10 @@ pub struct SwapResult { pub new_destination_amount: u64, /// Amount of destination token swapped pub amount_swapped: u64, + /// Amount of source tokens going to pool holders + pub trade_fee: u64, + /// Amount of source tokens going to owner + pub owner_fee: u64, } /// Helper function for mapping to SwapError::CalculationFailure @@ -171,9 +229,32 @@ fn map_zero_to_none(x: u64) -> Option { #[derive(Clone, Debug, Default, PartialEq)] pub struct FlatCurve { /// Fee numerator - pub fee_numerator: u64, + pub trade_fee_numerator: u64, /// Fee denominator - pub fee_denominator: u64, + pub trade_fee_denominator: u64, + /// Owner trade fee numerator + pub owner_trade_fee_numerator: u64, + /// Owner trade fee denominator + pub owner_trade_fee_denominator: u64, + /// Owner withdraw fee numerator + pub owner_withdraw_fee_numerator: u64, + /// Owner withdraw fee denominator + pub owner_withdraw_fee_denominator: u64, +} + +fn calculate_fee(token_amount: u64, fee_numerator: u64, fee_denominator: u64) -> Option { + if fee_numerator == 0 { + Some(0) + } else { + let fee = token_amount + .checked_mul(fee_numerator)? + .checked_div(fee_denominator)?; + if fee == 0 { + Some(1) // minimum fee of one token + } else { + Some(fee) + } + } } impl CurveCalculator for FlatCurve { @@ -185,14 +266,20 @@ impl CurveCalculator for FlatCurve { swap_destination_amount: u64, ) -> Option { // debit the fee to calculate the amount swapped - let mut fee = source_amount - .checked_mul(self.fee_numerator)? - .checked_div(self.fee_denominator)?; - if fee == 0 { - fee = 1; // minimum fee of one token - } + let trade_fee = calculate_fee( + source_amount, + self.trade_fee_numerator, + self.trade_fee_denominator, + )?; + let owner_fee = calculate_fee( + source_amount, + self.owner_trade_fee_numerator, + self.owner_trade_fee_denominator, + )?; - let amount_swapped = source_amount.checked_sub(fee)?; + let amount_swapped = source_amount + .checked_sub(trade_fee)? + .checked_sub(owner_fee)?; let new_destination_amount = swap_destination_amount.checked_sub(amount_swapped)?; // actually add the whole amount coming in @@ -201,26 +288,18 @@ impl CurveCalculator for FlatCurve { new_source_amount, new_destination_amount, amount_swapped, + trade_fee, + owner_fee, }) } - /// Balancer-style fixed initial supply - fn new_pool_supply(&self) -> u64 { - INITIAL_SWAP_POOL_AMOUNT - } - - /// Simple ratio calculation for how many liquidity tokens correspond to - /// a certain number of pool tokens - fn liquidity_tokens( - &self, - pool_tokens: u64, - pool_token_supply: u64, - total_liquidity_tokens: u64, - ) -> Option { - pool_tokens - .checked_mul(total_liquidity_tokens)? - .checked_div(pool_token_supply) - .and_then(map_zero_to_none) + /// Calculate the withdraw fee in pool tokens + fn owner_withdraw_fee(&self, pool_tokens: u64) -> Option { + calculate_fee( + pool_tokens, + self.owner_withdraw_fee_numerator, + self.owner_withdraw_fee_denominator, + ) } } @@ -232,15 +311,25 @@ impl IsInitialized for FlatCurve { } impl Sealed for FlatCurve {} impl Pack for FlatCurve { - const LEN: usize = 16; - /// Unpacks a byte buffer into a SwapCurve + const LEN: usize = 48; fn unpack_from_slice(input: &[u8]) -> Result { - let input = array_ref![input, 0, 16]; + let input = array_ref![input, 0, 48]; #[allow(clippy::ptr_offset_with_cast)] - let (fee_numerator, fee_denominator) = array_refs![input, 8, 8]; + let ( + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + ) = array_refs![input, 8, 8, 8, 8, 8, 8]; Ok(Self { - fee_numerator: u64::from_le_bytes(*fee_numerator), - fee_denominator: u64::from_le_bytes(*fee_denominator), + trade_fee_numerator: u64::from_le_bytes(*trade_fee_numerator), + trade_fee_denominator: u64::from_le_bytes(*trade_fee_denominator), + owner_trade_fee_numerator: u64::from_le_bytes(*owner_trade_fee_numerator), + owner_trade_fee_denominator: u64::from_le_bytes(*owner_trade_fee_denominator), + owner_withdraw_fee_numerator: u64::from_le_bytes(*owner_withdraw_fee_numerator), + owner_withdraw_fee_denominator: u64::from_le_bytes(*owner_withdraw_fee_denominator), }) } @@ -251,20 +340,39 @@ impl Pack for FlatCurve { impl DynPack for FlatCurve { fn pack_into_slice(&self, output: &mut [u8]) { - let output = array_mut_ref![output, 0, 16]; - let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8]; - *fee_numerator = self.fee_numerator.to_le_bytes(); - *fee_denominator = self.fee_denominator.to_le_bytes(); + let output = array_mut_ref![output, 0, 48]; + let ( + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + ) = mut_array_refs![output, 8, 8, 8, 8, 8, 8]; + *trade_fee_numerator = self.trade_fee_numerator.to_le_bytes(); + *trade_fee_denominator = self.trade_fee_denominator.to_le_bytes(); + *owner_trade_fee_numerator = self.owner_trade_fee_numerator.to_le_bytes(); + *owner_trade_fee_denominator = self.owner_trade_fee_denominator.to_le_bytes(); + *owner_withdraw_fee_numerator = self.owner_withdraw_fee_numerator.to_le_bytes(); + *owner_withdraw_fee_denominator = self.owner_withdraw_fee_denominator.to_le_bytes(); } } /// The Uniswap invariant calculator. #[derive(Clone, Debug, Default, PartialEq)] pub struct ConstantProductCurve { - /// Fee numerator - pub fee_numerator: u64, - /// Fee denominator - pub fee_denominator: u64, + /// Trade fee numerator + pub trade_fee_numerator: u64, + /// Trade fee denominator + pub trade_fee_denominator: u64, + /// Owner trade fee numerator + pub owner_trade_fee_numerator: u64, + /// Owner trade fee denominator + pub owner_trade_fee_denominator: u64, + /// Owner withdraw fee numerator + pub owner_withdraw_fee_numerator: u64, + /// Owner withdraw fee denominator + pub owner_withdraw_fee_denominator: u64, } impl CurveCalculator for ConstantProductCurve { @@ -275,18 +383,19 @@ impl CurveCalculator for ConstantProductCurve { swap_source_amount: u64, swap_destination_amount: u64, ) -> Option { - let invariant = swap_source_amount.checked_mul(swap_destination_amount)?; - // debit the fee to calculate the amount swapped - let mut fee = source_amount - .checked_mul(self.fee_numerator)? - .checked_div(self.fee_denominator)?; - if fee == 0 { - fee = 1; // minimum fee of one token - } + let trade_fee = self.trading_fee(source_amount)?; + let owner_fee = calculate_fee( + source_amount, + self.owner_trade_fee_numerator, + self.owner_trade_fee_denominator, + )?; + + let invariant = swap_source_amount.checked_mul(swap_destination_amount)?; let new_source_amount_less_fee = swap_source_amount .checked_add(source_amount)? - .checked_sub(fee)?; + .checked_sub(trade_fee)? + .checked_sub(owner_fee)?; let new_destination_amount = invariant.checked_div(new_source_amount_less_fee)?; let amount_swapped = map_zero_to_none(swap_destination_amount.checked_sub(new_destination_amount)?)?; @@ -297,27 +406,27 @@ impl CurveCalculator for ConstantProductCurve { new_source_amount, new_destination_amount, amount_swapped, + trade_fee, + owner_fee, }) } - /// Balancer-style supply starts at a constant. This could be modified to - /// follow the geometric mean, as done in Uniswap v2. - fn new_pool_supply(&self) -> u64 { - INITIAL_SWAP_POOL_AMOUNT + /// Calculate the withdraw fee in pool tokens + fn owner_withdraw_fee(&self, pool_tokens: u64) -> Option { + calculate_fee( + pool_tokens, + self.owner_withdraw_fee_numerator, + self.owner_withdraw_fee_denominator, + ) } - /// Simple ratio calculation to get the amount of liquidity tokens given - /// pool information - fn liquidity_tokens( - &self, - pool_tokens: u64, - pool_token_supply: u64, - total_liquidity_tokens: u64, - ) -> Option { - pool_tokens - .checked_mul(total_liquidity_tokens)? - .checked_div(pool_token_supply) - .and_then(map_zero_to_none) + /// Calculate the trading fee in trading tokens + fn trading_fee(&self, trading_tokens: u64) -> Option { + calculate_fee( + trading_tokens, + self.trade_fee_numerator, + self.trade_fee_denominator, + ) } } @@ -329,14 +438,25 @@ impl IsInitialized for ConstantProductCurve { } impl Sealed for ConstantProductCurve {} impl Pack for ConstantProductCurve { - const LEN: usize = 16; + const LEN: usize = 48; fn unpack_from_slice(input: &[u8]) -> Result { - let input = array_ref![input, 0, 16]; + let input = array_ref![input, 0, 48]; #[allow(clippy::ptr_offset_with_cast)] - let (fee_numerator, fee_denominator) = array_refs![input, 8, 8]; + let ( + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + ) = array_refs![input, 8, 8, 8, 8, 8, 8]; Ok(Self { - fee_numerator: u64::from_le_bytes(*fee_numerator), - fee_denominator: u64::from_le_bytes(*fee_denominator), + trade_fee_numerator: u64::from_le_bytes(*trade_fee_numerator), + trade_fee_denominator: u64::from_le_bytes(*trade_fee_denominator), + owner_trade_fee_numerator: u64::from_le_bytes(*owner_trade_fee_numerator), + owner_trade_fee_denominator: u64::from_le_bytes(*owner_trade_fee_denominator), + owner_withdraw_fee_numerator: u64::from_le_bytes(*owner_withdraw_fee_numerator), + owner_withdraw_fee_denominator: u64::from_le_bytes(*owner_withdraw_fee_denominator), }) } @@ -347,10 +467,21 @@ impl Pack for ConstantProductCurve { impl DynPack for ConstantProductCurve { fn pack_into_slice(&self, output: &mut [u8]) { - let output = array_mut_ref![output, 0, 16]; - let (fee_numerator, fee_denominator) = mut_array_refs![output, 8, 8]; - *fee_numerator = self.fee_numerator.to_le_bytes(); - *fee_denominator = self.fee_denominator.to_le_bytes(); + let output = array_mut_ref![output, 0, 48]; + let ( + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + ) = mut_array_refs![output, 8, 8, 8, 8, 8, 8]; + *trade_fee_numerator = self.trade_fee_numerator.to_le_bytes(); + *trade_fee_denominator = self.trade_fee_denominator.to_le_bytes(); + *owner_trade_fee_numerator = self.owner_trade_fee_numerator.to_le_bytes(); + *owner_trade_fee_denominator = self.owner_trade_fee_denominator.to_le_bytes(); + *owner_withdraw_fee_numerator = self.owner_withdraw_fee_numerator.to_le_bytes(); + *owner_withdraw_fee_denominator = self.owner_withdraw_fee_denominator.to_le_bytes(); } } @@ -360,53 +491,72 @@ mod tests { #[test] fn initial_pool_amount() { - let fee_numerator = 0; - let fee_denominator = 1; + let trade_fee_numerator = 0; + let trade_fee_denominator = 1; + let owner_trade_fee_numerator = 0; + let owner_trade_fee_denominator = 1; + let owner_withdraw_fee_numerator = 0; + let owner_withdraw_fee_denominator = 1; let calculator = ConstantProductCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }; assert_eq!(calculator.new_pool_supply(), INITIAL_SWAP_POOL_AMOUNT); } - fn check_liquidity_pool_token_rate( - token_a: u64, - deposit: u64, - supply: u64, - expected: Option, - ) { - let fee_numerator = 0; - let fee_denominator = 1; + fn check_pool_token_rate(token_a: u64, deposit: u64, supply: u64, expected: Option) { + let trade_fee_numerator = 0; + let trade_fee_denominator = 1; + let owner_trade_fee_numerator = 0; + let owner_trade_fee_denominator = 1; + let owner_withdraw_fee_numerator = 0; + let owner_withdraw_fee_denominator = 1; let calculator = ConstantProductCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }; assert_eq!( - calculator.liquidity_tokens(deposit, supply, token_a), + calculator.pool_tokens_to_trading_tokens(deposit, supply, token_a), expected ); } #[test] - fn issued_tokens() { - check_liquidity_pool_token_rate(2, 5, 10, Some(1)); - check_liquidity_pool_token_rate(10, 5, 10, Some(5)); - check_liquidity_pool_token_rate(5, 5, 10, Some(2)); - check_liquidity_pool_token_rate(5, 5, 10, Some(2)); - check_liquidity_pool_token_rate(u64::MAX, 5, 10, None); + fn trading_token_conversion() { + check_pool_token_rate(2, 5, 10, Some(1)); + check_pool_token_rate(10, 5, 10, Some(5)); + check_pool_token_rate(5, 5, 10, Some(2)); + check_pool_token_rate(5, 5, 10, Some(2)); + check_pool_token_rate(u64::MAX, 5, 10, None); } #[test] - fn constant_product_swap_calculation() { + fn constant_product_swap_calculation_trade_fee() { // calculation on https://github.com/solana-labs/solana-program-library/issues/341 let swap_source_amount: u64 = 1000; let swap_destination_amount: u64 = 50000; - let fee_numerator: u64 = 1; - let fee_denominator: u64 = 100; + let trade_fee_numerator: u64 = 1; + let trade_fee_denominator: u64 = 100; + let owner_trade_fee_numerator: u64 = 0; + let owner_trade_fee_denominator: u64 = 0; + let owner_withdraw_fee_numerator: u64 = 0; + let owner_withdraw_fee_denominator: u64 = 0; let source_amount: u64 = 100; let curve = ConstantProductCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }; let result = curve .swap(source_amount, swap_source_amount, swap_destination_amount) @@ -414,25 +564,81 @@ mod tests { assert_eq!(result.new_source_amount, 1100); assert_eq!(result.amount_swapped, 4505); assert_eq!(result.new_destination_amount, 45495); + assert_eq!(result.trade_fee, 1); + assert_eq!(result.owner_fee, 0); + } + + #[test] + fn constant_product_swap_calculation_owner_fee() { + // calculation on https://github.com/solana-labs/solana-program-library/issues/341 + let swap_source_amount: u64 = 1000; + let swap_destination_amount: u64 = 50000; + let trade_fee_numerator: u64 = 0; + let trade_fee_denominator: u64 = 0; + let owner_trade_fee_numerator: u64 = 1; + let owner_trade_fee_denominator: u64 = 100; + let owner_withdraw_fee_numerator: u64 = 0; + let owner_withdraw_fee_denominator: u64 = 0; + let source_amount: u64 = 100; + let curve = ConstantProductCurve { + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + }; + let result = curve + .swap(source_amount, swap_source_amount, swap_destination_amount) + .unwrap(); + assert_eq!(result.new_source_amount, 1100); + assert_eq!(result.amount_swapped, 4505); + assert_eq!(result.new_destination_amount, 45495); + assert_eq!(result.trade_fee, 0); + assert_eq!(result.owner_fee, 1); + } + + #[test] + fn constant_product_swap_no_fee() { + let swap_source_amount: u64 = 1000; + let swap_destination_amount: u64 = 50000; + let source_amount: u64 = 100; + let curve = ConstantProductCurve::default(); + let result = curve + .swap(source_amount, swap_source_amount, swap_destination_amount) + .unwrap(); + assert_eq!(result.new_source_amount, 1100); + assert_eq!(result.amount_swapped, 4546); + assert_eq!(result.new_destination_amount, 45454); } #[test] fn flat_swap_calculation() { let swap_source_amount: u64 = 1000; let swap_destination_amount: u64 = 50000; - let fee_numerator: u64 = 1; - let fee_denominator: u64 = 100; + let trade_fee_numerator: u64 = 1; + let trade_fee_denominator: u64 = 100; + let owner_trade_fee_numerator: u64 = 2; + let owner_trade_fee_denominator: u64 = 100; + let owner_withdraw_fee_numerator: u64 = 2; + let owner_withdraw_fee_denominator: u64 = 100; let source_amount: u64 = 100; let curve = FlatCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }; let result = curve .swap(source_amount, swap_source_amount, swap_destination_amount) .unwrap(); - let amount_swapped = 99; + let amount_swapped = 97; assert_eq!(result.new_source_amount, 1100); assert_eq!(result.amount_swapped, amount_swapped); + assert_eq!(result.trade_fee, 1); + assert_eq!(result.owner_fee, 2); assert_eq!( result.new_destination_amount, swap_destination_amount - amount_swapped @@ -441,11 +647,19 @@ mod tests { #[test] fn pack_flat_curve() { - let fee_numerator = 1; - let fee_denominator = 4; + let trade_fee_numerator = 1; + let trade_fee_denominator = 4; + let owner_trade_fee_numerator = 2; + let owner_trade_fee_denominator = 5; + let owner_withdraw_fee_numerator = 4; + let owner_withdraw_fee_denominator = 10; let curve = FlatCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }; let mut packed = [0u8; FlatCurve::LEN]; @@ -454,19 +668,31 @@ mod tests { assert_eq!(curve, unpacked); let mut packed = vec![]; - packed.extend_from_slice(&fee_numerator.to_le_bytes()); - packed.extend_from_slice(&fee_denominator.to_le_bytes()); + packed.extend_from_slice(&trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_denominator.to_le_bytes()); let unpacked = FlatCurve::unpack(&packed).unwrap(); assert_eq!(curve, unpacked); } #[test] fn pack_constant_product_curve() { - let fee_numerator = 1; - let fee_denominator = 4; + let trade_fee_numerator = 1; + let trade_fee_denominator = 4; + let owner_trade_fee_numerator = 2; + let owner_trade_fee_denominator = 5; + let owner_withdraw_fee_numerator = 4; + let owner_withdraw_fee_denominator = 10; let curve = ConstantProductCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }; let mut packed = [0u8; ConstantProductCurve::LEN]; @@ -475,19 +701,31 @@ mod tests { assert_eq!(curve, unpacked); let mut packed = vec![]; - packed.extend_from_slice(&fee_numerator.to_le_bytes()); - packed.extend_from_slice(&fee_denominator.to_le_bytes()); + packed.extend_from_slice(&trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_denominator.to_le_bytes()); let unpacked = ConstantProductCurve::unpack(&packed).unwrap(); assert_eq!(curve, unpacked); } #[test] fn pack_swap_curve() { - let fee_numerator = 1; - let fee_denominator = 4; + let trade_fee_numerator = 1; + let trade_fee_denominator = 4; + let owner_trade_fee_numerator = 2; + let owner_trade_fee_denominator = 5; + let owner_withdraw_fee_numerator = 4; + let owner_withdraw_fee_denominator = 10; let curve = ConstantProductCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }; let curve_type = CurveType::ConstantProduct; let swap_curve = SwapCurve { @@ -502,9 +740,13 @@ mod tests { let mut packed = vec![]; packed.push(curve_type as u8); - packed.extend_from_slice(&fee_numerator.to_le_bytes()); - packed.extend_from_slice(&fee_denominator.to_le_bytes()); - packed.extend_from_slice(&[0u8; 48]); // padding + packed.extend_from_slice(&trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&[0u8; 16]); // padding let unpacked = SwapCurve::unpack_from_slice(&packed).unwrap(); assert_eq!(swap_curve, unpacked); } diff --git a/token-swap/program/src/error.rs b/token-swap/program/src/error.rs index b6bc7cd6..21f9ac03 100644 --- a/token-swap/program/src/error.rs +++ b/token-swap/program/src/error.rs @@ -46,8 +46,8 @@ pub enum SwapError { /// The output token is invalid for swap. #[error("InvalidOutput")] InvalidOutput, - /// The calculation failed. - #[error("CalculationFailure")] + /// General calculation failure due to overflow or underflow + #[error("General calculation failure due to overflow or underflow")] CalculationFailure, /// Invalid instruction number passed in. #[error("Invalid instruction")] @@ -64,6 +64,15 @@ pub enum SwapError { /// The pool token mint has a freeze authority. #[error("Pool token mint has a freeze authority")] InvalidFreezeAuthority, + /// The pool fee token account is incorrect + #[error("Pool fee token account incorrect")] + IncorrectFeeAccount, + /// Given pool token amount results in zero trading tokens + #[error("Given pool token amount results in zero trading tokens")] + ZeroTradingTokens, + /// The fee calculation failed due to overflow, underflow, or unexpected 0 + #[error("Fee calculation failed due to overflow, underflow, or unexpected 0")] + FeeCalculationFailure, } impl From for ProgramError { fn from(e: SwapError) -> Self { diff --git a/token-swap/program/src/instruction.rs b/token-swap/program/src/instruction.rs index cff4f273..0227ee01 100644 --- a/token-swap/program/src/instruction.rs +++ b/token-swap/program/src/instruction.rs @@ -2,7 +2,7 @@ #![allow(clippy::too_many_arguments)] -use crate::curve::{ConstantProductCurve, CurveType, FlatCurve, SwapCurve}; +use crate::curve::SwapCurve; use crate::error::SwapError; use solana_sdk::{ instruction::{AccountMeta, Instruction}, @@ -24,8 +24,11 @@ pub enum SwapInstruction { /// 2. `[]` token_a Account. Must be non zero, owned by $authority. /// 3. `[]` token_b Account. Must be non zero, owned by $authority. /// 4. `[writable]` Pool Token Mint. Must be empty, owned by $authority. - /// 5. `[writable]` Pool Token Account to deposit the minted tokens. Must be empty, owned by user. - /// 6. '[]` Token program id + /// 5. `[]` Pool Token Account to deposit trading and withdraw fees. + /// Must be empty, not owned by $authority + /// 6. `[writable]` Pool Token Account to deposit the initial pool token + /// supply. Must be empty, not owned by $authority. + /// 7. '[]` Token program id Initialize { /// nonce used to create valid program address nonce: u8, @@ -42,7 +45,9 @@ pub enum SwapInstruction { /// 3. `[writable]` token_(A|B) Base Account to swap INTO. Must be the SOURCE token. /// 4. `[writable]` token_(A|B) Base Account to swap FROM. Must be the DESTINATION token. /// 5. `[writable]` token_(A|B) DESTINATION Account assigned to USER as the owner. - /// 6. '[]` Token program id + /// 6. `[writable]` Pool token mint, to generate trading fees + /// 7. `[writable]` Fee account, to receive trading fees + /// 8. '[]` Token program id Swap { /// SOURCE amount to transfer, output to DESTINATION is based on the exchange rate amount_in: u64, @@ -82,7 +87,8 @@ pub enum SwapInstruction { /// 5. `[writable]` token_b Swap Account to withdraw FROM. /// 6. `[writable]` token_a user Account to credit. /// 7. `[writable]` token_b user Account to credit. - /// 8. '[]` Token program id + /// 8. `[writable]` Fee account, to receive withdrawal fees + /// 9. '[]` Token program id Withdraw { /// Amount of pool tokens to burn. User receives an output of token a /// and b based on the percentage of the pool tokens that are returned. @@ -203,25 +209,11 @@ pub fn initialize( token_a_pubkey: &Pubkey, token_b_pubkey: &Pubkey, pool_pubkey: &Pubkey, + fee_pubkey: &Pubkey, destination_pubkey: &Pubkey, nonce: u8, - curve_type: CurveType, - fee_numerator: u64, - fee_denominator: u64, + swap_curve: SwapCurve, ) -> Result { - let swap_curve = SwapCurve { - curve_type, - calculator: match curve_type { - CurveType::ConstantProduct => Box::new(ConstantProductCurve { - fee_numerator, - fee_denominator, - }), - CurveType::Flat => Box::new(FlatCurve { - fee_numerator, - fee_denominator, - }), - }, - }; let init_data = SwapInstruction::Initialize { nonce, swap_curve }; let data = init_data.pack(); @@ -231,6 +223,7 @@ pub fn initialize( AccountMeta::new_readonly(*token_a_pubkey, false), AccountMeta::new_readonly(*token_b_pubkey, false), AccountMeta::new(*pool_pubkey, false), + AccountMeta::new_readonly(*fee_pubkey, false), AccountMeta::new(*destination_pubkey, false), AccountMeta::new_readonly(*token_program_id, false), ]; @@ -291,6 +284,7 @@ pub fn withdraw( swap_pubkey: &Pubkey, authority_pubkey: &Pubkey, pool_mint_pubkey: &Pubkey, + fee_account_pubkey: &Pubkey, source_pubkey: &Pubkey, swap_token_a_pubkey: &Pubkey, swap_token_b_pubkey: &Pubkey, @@ -316,6 +310,7 @@ pub fn withdraw( AccountMeta::new(*swap_token_b_pubkey, false), AccountMeta::new(*destination_token_a_pubkey, false), AccountMeta::new(*destination_token_b_pubkey, false), + AccountMeta::new(*fee_account_pubkey, false), AccountMeta::new_readonly(*token_program_id, false), ]; @@ -336,6 +331,8 @@ pub fn swap( swap_source_pubkey: &Pubkey, swap_destination_pubkey: &Pubkey, destination_pubkey: &Pubkey, + pool_mint_pubkey: &Pubkey, + pool_fee_pubkey: &Pubkey, amount_in: u64, minimum_amount_out: u64, ) -> Result { @@ -352,6 +349,8 @@ pub fn swap( AccountMeta::new(*swap_source_pubkey, false), AccountMeta::new(*swap_destination_pubkey, false), AccountMeta::new(*destination_pubkey, false), + AccountMeta::new(*pool_mint_pubkey, false), + AccountMeta::new(*pool_fee_pubkey, false), AccountMeta::new_readonly(*token_program_id, false), ]; @@ -377,15 +376,25 @@ pub fn unpack(input: &[u8]) -> Result<&T, ProgramError> { mod tests { use super::*; + use crate::curve::{CurveType, FlatCurve}; + #[test] fn test_instruction_packing() { - let fee_numerator: u64 = 1; - let fee_denominator: u64 = 4; + let trade_fee_numerator: u64 = 1; + let trade_fee_denominator: u64 = 4; + let owner_trade_fee_numerator: u64 = 2; + let owner_trade_fee_denominator: u64 = 5; + let owner_withdraw_fee_numerator: u64 = 1; + let owner_withdraw_fee_denominator: u64 = 3; let nonce: u8 = 255; let curve_type = CurveType::Flat; let calculator = Box::new(FlatCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }); let swap_curve = SwapCurve { curve_type, @@ -397,9 +406,13 @@ mod tests { expect.push(0 as u8); expect.push(nonce); expect.push(curve_type as u8); - expect.extend_from_slice(&fee_numerator.to_le_bytes()); - expect.extend_from_slice(&fee_denominator.to_le_bytes()); - expect.extend_from_slice(&[0u8; 48]); // padding + expect.extend_from_slice(&trade_fee_numerator.to_le_bytes()); + expect.extend_from_slice(&trade_fee_denominator.to_le_bytes()); + expect.extend_from_slice(&owner_trade_fee_numerator.to_le_bytes()); + expect.extend_from_slice(&owner_trade_fee_denominator.to_le_bytes()); + expect.extend_from_slice(&owner_withdraw_fee_numerator.to_le_bytes()); + expect.extend_from_slice(&owner_withdraw_fee_denominator.to_le_bytes()); + expect.extend_from_slice(&[0u8; 16]); // padding assert_eq!(packed, expect); let unpacked = SwapInstruction::unpack(&expect).unwrap(); assert_eq!(unpacked, check); diff --git a/token-swap/program/src/processor.rs b/token-swap/program/src/processor.rs index 20390e2d..45008dee 100644 --- a/token-swap/program/src/processor.rs +++ b/token-swap/program/src/processor.rs @@ -27,6 +27,10 @@ const SWAP_PROGRAM_ID: Pubkey = Pubkey::new_from_array([2u8; 32]); #[cfg(not(target_arch = "bpf"))] const TOKEN_PROGRAM_ID: Pubkey = Pubkey::new_from_array([1u8; 32]); +/// Hardcode the number of token types in a pool, used to calculate the +/// equivalent pool tokens for the owner trading fee. +const TOKENS_IN_POOL: u64 = 2; + /// Program state handler. pub struct Processor {} impl Processor { @@ -146,6 +150,7 @@ impl Processor { let token_a_info = next_account_info(account_info_iter)?; let token_b_info = next_account_info(account_info_iter)?; let pool_mint_info = next_account_info(account_info_iter)?; + let fee_account_info = next_account_info(account_info_iter)?; let destination_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; @@ -159,6 +164,7 @@ impl Processor { } let token_a = Self::unpack_token_account(&token_a_info.data.borrow())?; let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; + let fee_account = Self::unpack_token_account(&fee_account_info.data.borrow())?; let destination = Self::unpack_token_account(&destination_info.data.borrow())?; let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; if *authority_info.key != token_a.owner { @@ -170,6 +176,9 @@ impl Processor { if *authority_info.key == destination.owner { return Err(SwapError::InvalidOutputOwner.into()); } + if *authority_info.key == fee_account.owner { + return Err(SwapError::InvalidOutputOwner.into()); + } if COption::Some(*authority_info.key) != pool_mint.mint_authority { return Err(SwapError::InvalidOwner.into()); } @@ -202,6 +211,9 @@ impl Processor { if pool_mint.freeze_authority.is_some() { return Err(SwapError::InvalidFreezeAuthority.into()); } + if *pool_mint_info.key != fee_account.mint { + return Err(SwapError::IncorrectPoolMint.into()); + } let initial_amount = swap_curve.calculator.new_pool_supply(); @@ -222,6 +234,9 @@ impl Processor { token_a: *token_a_info.key, token_b: *token_b_info.key, pool_mint: *pool_mint_info.key, + token_a_mint: token_a.mint, + token_b_mint: token_b.mint, + pool_fee_account: *fee_account_info.key, swap_curve, }; SwapInfo::pack(obj, &mut swap_info.data.borrow_mut())?; @@ -242,6 +257,8 @@ impl Processor { let swap_source_info = next_account_info(account_info_iter)?; let swap_destination_info = next_account_info(account_info_iter)?; let destination_info = next_account_info(account_info_iter)?; + let pool_mint_info = next_account_info(account_info_iter)?; + let pool_fee_account_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; @@ -262,14 +279,22 @@ impl Processor { if *swap_source_info.key == *swap_destination_info.key { return Err(SwapError::InvalidInput.into()); } + if *pool_mint_info.key != token_swap.pool_mint { + return Err(SwapError::IncorrectPoolMint.into()); + } + if *pool_fee_account_info.key != token_swap.pool_fee_account { + return Err(SwapError::IncorrectFeeAccount.into()); + } + let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?; let dest_account = Self::unpack_token_account(&swap_destination_info.data.borrow())?; + let pool_mint = Self::unpack_mint(&pool_mint_info.data.borrow())?; let result = token_swap .swap_curve .calculator .swap(amount_in, source_account.amount, dest_account.amount) - .ok_or(SwapError::CalculationFailure)?; + .ok_or(SwapError::ZeroTradingTokens)?; if result.amount_swapped < minimum_amount_out { return Err(SwapError::ExceededSlippage.into()); } @@ -291,6 +316,30 @@ impl Processor { token_swap.nonce, result.amount_swapped, )?; + + // mint pool tokens equivalent to the owner fee + let source_account = Self::unpack_token_account(&swap_source_info.data.borrow())?; + let pool_token_amount = token_swap + .swap_curve + .calculator + .owner_fee_to_pool_tokens( + result.owner_fee, + source_account.amount, + pool_mint.supply, + TOKENS_IN_POOL, + ) + .ok_or(SwapError::FeeCalculationFailure)?; + if pool_token_amount > 0 { + Self::token_mint_to( + swap_info.key, + token_program_info.clone(), + pool_mint_info.clone(), + pool_fee_account_info.clone(), + authority_info.clone(), + token_swap.nonce, + pool_token_amount, + )?; + } Ok(()) } @@ -334,14 +383,14 @@ impl Processor { let calculator = token_swap.swap_curve.calculator; let a_amount = calculator - .liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount) - .ok_or(SwapError::CalculationFailure)?; + .pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_a.amount) + .ok_or(SwapError::ZeroTradingTokens)?; if a_amount > maximum_token_a_amount { return Err(SwapError::ExceededSlippage.into()); } let b_amount = calculator - .liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount) - .ok_or(SwapError::CalculationFailure)?; + .pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_b.amount) + .ok_or(SwapError::ZeroTradingTokens)?; if b_amount > maximum_token_b_amount { return Err(SwapError::ExceededSlippage.into()); } @@ -394,6 +443,7 @@ impl Processor { let token_b_info = next_account_info(account_info_iter)?; let dest_token_a_info = next_account_info(account_info_iter)?; let dest_token_b_info = next_account_info(account_info_iter)?; + let pool_fee_account_info = next_account_info(account_info_iter)?; let token_program_info = next_account_info(account_info_iter)?; let token_swap = SwapInfo::unpack(&swap_info.data.borrow())?; @@ -409,6 +459,9 @@ impl Processor { if *pool_mint_info.key != token_swap.pool_mint { return Err(SwapError::IncorrectPoolMint.into()); } + if *pool_fee_account_info.key != token_swap.pool_fee_account { + return Err(SwapError::IncorrectFeeAccount.into()); + } let token_a = Self::unpack_token_account(&token_a_info.data.borrow())?; let token_b = Self::unpack_token_account(&token_b_info.data.borrow())?; @@ -416,15 +469,27 @@ impl Processor { let calculator = token_swap.swap_curve.calculator; - let a_amount = calculator - .liquidity_tokens(pool_token_amount, pool_mint.supply, token_a.amount) + let withdraw_fee = if *pool_fee_account_info.key == *source_info.key { + // withdrawing from the fee account, don't assess withdraw fee + 0 + } else { + calculator + .owner_withdraw_fee(pool_token_amount) + .ok_or(SwapError::FeeCalculationFailure)? + }; + let pool_token_amount = pool_token_amount + .checked_sub(withdraw_fee) .ok_or(SwapError::CalculationFailure)?; + + let a_amount = calculator + .pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_a.amount) + .ok_or(SwapError::ZeroTradingTokens)?; if a_amount < minimum_token_a_amount { return Err(SwapError::ExceededSlippage.into()); } let b_amount = calculator - .liquidity_tokens(pool_token_amount, pool_mint.supply, token_b.amount) - .ok_or(SwapError::CalculationFailure)?; + .pool_tokens_to_trading_tokens(pool_token_amount, pool_mint.supply, token_b.amount) + .ok_or(SwapError::ZeroTradingTokens)?; if b_amount < minimum_token_b_amount { return Err(SwapError::ExceededSlippage.into()); } @@ -447,6 +512,17 @@ impl Processor { token_swap.nonce, b_amount, )?; + if withdraw_fee > 0 { + Self::token_transfer( + swap_info.key, + token_program_info.clone(), + source_info.clone(), + pool_fee_account_info.clone(), + authority_info.clone(), + token_swap.nonce, + withdraw_fee, + )?; + } Self::token_burn( swap_info.key, token_program_info.clone(), @@ -585,6 +661,13 @@ impl PrintProgramError for SwapError { SwapError::InvalidFreezeAuthority => { info!("Error: Pool token mint has a freeze authority") } + SwapError::IncorrectFeeAccount => info!("Error: Pool fee token account incorrect"), + SwapError::ZeroTradingTokens => { + info!("Error: Given pool token amount results in zero withdrawal amount") + } + SwapError::FeeCalculationFailure => info!( + "Error: The fee calculation failed due to overflow, underflow, or unexpected 0" + ), } } } @@ -618,14 +701,14 @@ mod tests { struct SwapAccountInfo { nonce: u8, - curve_type: CurveType, authority_key: Pubkey, - fee_numerator: u64, - fee_denominator: u64, + swap_curve: SwapCurve, swap_key: Pubkey, swap_account: Account, pool_mint_key: Pubkey, pool_mint_account: Account, + pool_fee_key: Pubkey, + pool_fee_account: Account, pool_token_key: Pubkey, pool_token_account: Account, token_a_key: Pubkey, @@ -641,9 +724,7 @@ mod tests { impl SwapAccountInfo { pub fn new( user_key: &Pubkey, - curve_type: CurveType, - fee_numerator: u64, - fee_denominator: u64, + swap_curve: SwapCurve, token_a_amount: u64, token_b_amount: u64, ) -> Self { @@ -662,6 +743,14 @@ mod tests { &user_key, 0, ); + let (pool_fee_key, pool_fee_account) = mint_token( + &TOKEN_PROGRAM_ID, + &pool_mint_key, + &mut pool_mint_account, + &authority_key, + &user_key, + 0, + ); let (token_a_mint_key, mut token_a_mint_account) = create_mint(&TOKEN_PROGRAM_ID, &user_key, None); let (token_a_key, token_a_account) = mint_token( @@ -685,14 +774,14 @@ mod tests { SwapAccountInfo { nonce, - curve_type, authority_key, - fee_numerator, - fee_denominator, + swap_curve, swap_key, swap_account, pool_mint_key, pool_mint_account, + pool_fee_key, + pool_fee_account, pool_token_key, pool_token_account, token_a_key, @@ -716,11 +805,10 @@ mod tests { &self.token_a_key, &self.token_b_key, &self.pool_mint_key, + &self.pool_fee_key, &self.pool_token_key, self.nonce, - self.curve_type, - self.fee_numerator, - self.fee_denominator, + self.swap_curve.clone(), ) .unwrap(), vec![ @@ -729,6 +817,7 @@ mod tests { &mut self.token_a_account, &mut self.token_b_account, &mut self.pool_mint_account, + &mut self.pool_fee_account, &mut self.pool_token_account, &mut Account::default(), ], @@ -842,6 +931,8 @@ mod tests { &swap_source_key, &swap_destination_key, &user_destination_key, + &self.pool_mint_key, + &self.pool_fee_key, amount_in, minimum_amount_out, ) @@ -853,6 +944,8 @@ mod tests { &mut swap_source_account, &mut swap_destination_account, &mut user_destination_account, + &mut self.pool_mint_account, + &mut self.pool_fee_account, &mut Account::default(), ], )?; @@ -983,6 +1076,7 @@ mod tests { &self.swap_key, &self.authority_key, &self.pool_mint_key, + &self.pool_fee_key, &pool_key, &self.token_a_key, &self.token_b_key, @@ -1002,6 +1096,7 @@ mod tests { &mut self.token_b_account, &mut token_a_account, &mut token_b_account, + &mut self.pool_fee_account, &mut Account::default(), ], ) @@ -1166,21 +1261,30 @@ mod tests { #[test] fn test_initialize() { let user_key = pubkey_rand(); - let fee_numerator = 1; - let fee_denominator = 2; + let trade_fee_numerator = 1; + let trade_fee_denominator = 2; + let owner_trade_fee_numerator = 1; + let owner_trade_fee_denominator = 10; + let owner_withdraw_fee_numerator = 1; + let owner_withdraw_fee_denominator = 5; let token_a_amount = 1000; let token_b_amount = 2000; let pool_token_amount = 10; let curve_type = CurveType::ConstantProduct; - - let mut accounts = SwapAccountInfo::new( - &user_key, + let swap_curve = SwapCurve { curve_type, - fee_numerator, - fee_denominator, - token_a_amount, - token_b_amount, - ); + calculator: Box::new(ConstantProductCurve { + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + }), + }; + + let mut accounts = + SwapAccountInfo::new(&user_key, swap_curve, token_a_amount, token_b_amount); // wrong nonce for authority_key { @@ -1283,6 +1387,25 @@ mod tests { accounts.pool_token_account = old_account; } + // pool fee account owner is swap authority + { + let (_pool_fee_key, pool_fee_account) = mint_token( + &TOKEN_PROGRAM_ID, + &accounts.pool_mint_key, + &mut accounts.pool_mint_account, + &accounts.authority_key, + &accounts.authority_key, + 0, + ); + let old_account = accounts.pool_fee_account; + accounts.pool_fee_account = pool_fee_account; + assert_eq!( + Err(SwapError::InvalidOutputOwner.into()), + accounts.initialize_swap() + ); + accounts.pool_fee_account = old_account; + } + // pool mint authority is not swap authority { let (_pool_mint_key, pool_mint_account) = @@ -1392,6 +1515,25 @@ mod tests { accounts.pool_token_account = old_pool_account; } + // pool fee account has wrong mint + { + let (_pool_fee_key, pool_fee_account) = mint_token( + &TOKEN_PROGRAM_ID, + &accounts.token_a_mint_key, + &mut accounts.token_a_mint_account, + &user_key, + &user_key, + 0, + ); + let old_account = accounts.pool_fee_account; + accounts.pool_fee_account = pool_fee_account; + assert_eq!( + Err(SwapError::IncorrectPoolMint.into()), + accounts.initialize_swap() + ); + accounts.pool_fee_account = old_account; + } + // token A account is delegated { do_process_instruction( @@ -1550,11 +1692,10 @@ mod tests { &accounts.token_a_key, &accounts.token_b_key, &accounts.pool_mint_key, + &accounts.pool_fee_key, &accounts.pool_token_key, accounts.nonce, - accounts.curve_type, - accounts.fee_numerator, - accounts.fee_denominator, + accounts.swap_curve.clone(), ) .unwrap(), vec![ @@ -1563,6 +1704,7 @@ mod tests { &mut accounts.token_a_account, &mut accounts.token_b_account, &mut accounts.pool_mint_account, + &mut accounts.pool_fee_account, &mut accounts.pool_token_account, &mut Account::default(), ], @@ -1594,14 +1736,19 @@ mod tests { // create valid flat swap { - let mut accounts = SwapAccountInfo::new( - &user_key, - CurveType::Flat, - fee_numerator, - fee_denominator, - token_a_amount, - token_b_amount, - ); + let swap_curve = SwapCurve { + curve_type: CurveType::Flat, + calculator: Box::new(FlatCurve { + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + }), + }; + let mut accounts = + SwapAccountInfo::new(&user_key, swap_curve, token_a_amount, token_b_amount); accounts.initialize_swap().unwrap(); } @@ -1615,10 +1762,16 @@ mod tests { let swap_info = SwapInfo::unpack(&accounts.swap_account.data).unwrap(); assert_eq!(swap_info.is_initialized, true); assert_eq!(swap_info.nonce, accounts.nonce); - assert_eq!(swap_info.swap_curve.curve_type, accounts.curve_type); + assert_eq!( + swap_info.swap_curve.curve_type, + accounts.swap_curve.curve_type + ); assert_eq!(swap_info.token_a, accounts.token_a_key); assert_eq!(swap_info.token_b, accounts.token_b_key); assert_eq!(swap_info.pool_mint, accounts.pool_mint_key); + assert_eq!(swap_info.token_a_mint, accounts.token_a_mint_key); + assert_eq!(swap_info.token_b_mint, accounts.token_b_mint_key); + assert_eq!(swap_info.pool_fee_account, accounts.pool_fee_key); let token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); assert_eq!(token_a.amount, token_a_amount); let token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); @@ -1633,20 +1786,29 @@ mod tests { fn test_deposit() { let user_key = pubkey_rand(); let depositor_key = pubkey_rand(); - let fee_numerator = 1; - let fee_denominator = 2; + let trade_fee_numerator = 1; + let trade_fee_denominator = 2; + let owner_trade_fee_numerator = 1; + let owner_trade_fee_denominator = 10; + let owner_withdraw_fee_numerator = 1; + let owner_withdraw_fee_denominator = 5; let token_a_amount = 1000; let token_b_amount = 9000; let curve_type = CurveType::ConstantProduct; - - let mut accounts = SwapAccountInfo::new( - &user_key, + let swap_curve = SwapCurve { curve_type, - fee_numerator, - fee_denominator, - token_a_amount, - token_b_amount, - ); + calculator: Box::new(ConstantProductCurve { + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + }), + }; + + let mut accounts = + SwapAccountInfo::new(&user_key, swap_curve, token_a_amount, token_b_amount); let deposit_a = token_a_amount / 10; let deposit_b = token_b_amount / 10; @@ -2044,7 +2206,7 @@ mod tests { mut pool_account, ) = accounts.setup_token_accounts(&user_key, &depositor_key, deposit_a, deposit_b, 0); assert_eq!( - Err(SwapError::CalculationFailure.into()), + Err(SwapError::ZeroTradingTokens.into()), accounts.deposit( &depositor_key, &pool_key, @@ -2153,32 +2315,38 @@ mod tests { #[test] fn test_withdraw() { let user_key = pubkey_rand(); - let fee_numerator = 1; - let fee_denominator = 2; + let trade_fee_numerator = 1; + let trade_fee_denominator = 2; + let owner_trade_fee_numerator = 1; + let owner_trade_fee_denominator = 10; + let owner_withdraw_fee_numerator = 1; + let owner_withdraw_fee_denominator = 5; let token_a_amount = 1000; let token_b_amount = 2000; let curve_type = CurveType::ConstantProduct; - - let mut accounts = SwapAccountInfo::new( - &user_key, + let swap_curve = SwapCurve { curve_type, - fee_numerator, - fee_denominator, - token_a_amount, - token_b_amount, - ); - let withdrawer_key = pubkey_rand(); - let calculator = ConstantProductCurve { - fee_numerator, - fee_denominator, + calculator: Box::new(ConstantProductCurve { + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + }), }; + + let withdrawer_key = pubkey_rand(); let initial_a = token_a_amount / 10; let initial_b = token_b_amount / 10; - let initial_pool = calculator.new_pool_supply() / 10; + let initial_pool = swap_curve.calculator.new_pool_supply() / 10; let withdraw_amount = initial_pool / 4; let minimum_a_amount = initial_a / 40; let minimum_b_amount = initial_b / 40; + let mut accounts = + SwapAccountInfo::new(&user_key, swap_curve, token_a_amount, token_b_amount); + // swap not initialized { let ( @@ -2355,6 +2523,59 @@ mod tests { ); } + // wrong pool fee account + { + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + wrong_pool_key, + wrong_pool_account, + ) = accounts.setup_token_accounts( + &user_key, + &withdrawer_key, + initial_a, + initial_b, + withdraw_amount, + ); + let ( + _token_a_key, + _token_a_account, + _token_b_key, + _token_b_account, + pool_key, + mut pool_account, + ) = accounts.setup_token_accounts( + &user_key, + &withdrawer_key, + initial_a, + initial_b, + withdraw_amount, + ); + let old_pool_fee_account = accounts.pool_fee_account; + let old_pool_fee_key = accounts.pool_fee_key; + accounts.pool_fee_account = wrong_pool_account; + accounts.pool_fee_key = wrong_pool_key; + assert_eq!( + Err(SwapError::IncorrectFeeAccount.into()), + accounts.withdraw( + &withdrawer_key, + &pool_key, + &mut pool_account, + &token_a_key, + &mut token_a_account, + &token_b_key, + &mut token_b_account, + withdraw_amount, + minimum_a_amount, + minimum_b_amount, + ), + ); + accounts.pool_fee_account = old_pool_fee_account; + accounts.pool_fee_key = old_pool_fee_key; + } + // no approval { let ( @@ -2374,6 +2595,7 @@ mod tests { &accounts.swap_key, &accounts.authority_key, &accounts.pool_mint_key, + &accounts.pool_fee_key, &pool_key, &accounts.token_a_key, &accounts.token_b_key, @@ -2393,6 +2615,7 @@ mod tests { &mut accounts.token_b_account, &mut token_a_account, &mut token_b_account, + &mut accounts.pool_fee_account, &mut Account::default(), ], ) @@ -2425,6 +2648,7 @@ mod tests { &accounts.swap_key, &accounts.authority_key, &accounts.pool_mint_key, + &accounts.pool_fee_key, &pool_key, &accounts.token_a_key, &accounts.token_b_key, @@ -2444,6 +2668,7 @@ mod tests { &mut accounts.token_b_account, &mut token_a_account, &mut token_b_account, + &mut accounts.pool_fee_account, &mut Account::default(), ], ) @@ -2580,7 +2805,7 @@ mod tests { initial_pool, ); assert_eq!( - Err(SwapError::CalculationFailure.into()), + Err(SwapError::ZeroTradingTokens.into()), accounts.withdraw( &withdrawer_key, &pool_key, @@ -2683,17 +2908,29 @@ mod tests { let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); - let calculator = ConstantProductCurve { - fee_numerator, - fee_denominator, - }; - - let withdrawn_a = calculator - .liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_a.amount) + let withdraw_fee = accounts + .swap_curve + .calculator + .owner_withdraw_fee(withdraw_amount) + .unwrap(); + let withdrawn_a = accounts + .swap_curve + .calculator + .pool_tokens_to_trading_tokens( + withdraw_amount - withdraw_fee, + pool_mint.supply, + swap_token_a.amount, + ) .unwrap(); assert_eq!(swap_token_a.amount, token_a_amount - withdrawn_a); - let withdrawn_b = calculator - .liquidity_tokens(withdraw_amount, pool_mint.supply, swap_token_b.amount) + let withdrawn_b = accounts + .swap_curve + .calculator + .pool_tokens_to_trading_tokens( + withdraw_amount - withdraw_fee, + pool_mint.supply, + swap_token_b.amount, + ) .unwrap(); assert_eq!(swap_token_b.amount, token_b_amount - withdrawn_b); let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); @@ -2702,14 +2939,75 @@ mod tests { assert_eq!(token_b.amount, initial_b + withdrawn_b); let pool_account = Processor::unpack_token_account(&pool_account.data).unwrap(); assert_eq!(pool_account.amount, initial_pool - withdraw_amount); + let fee_account = + Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap(); + assert_eq!(fee_account.amount, withdraw_fee); + } + + // correct withdrawal from fee account + { + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + _pool_key, + mut _pool_account, + ) = accounts.setup_token_accounts(&user_key, &withdrawer_key, 0, 0, 0); + + let pool_fee_key = accounts.pool_fee_key.clone(); + let mut pool_fee_account = accounts.pool_fee_account.clone(); + let fee_account = Processor::unpack_token_account(&pool_fee_account.data).unwrap(); + let pool_fee_amount = fee_account.amount; + + accounts + .withdraw( + &user_key, + &pool_fee_key, + &mut pool_fee_account, + &token_a_key, + &mut token_a_account, + &token_b_key, + &mut token_b_account, + pool_fee_amount, + 0, + 0, + ) + .unwrap(); + + let swap_token_a = + Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); + let swap_token_b = + Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); + let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + let withdrawn_a = accounts + .swap_curve + .calculator + .pool_tokens_to_trading_tokens( + pool_fee_amount, + pool_mint.supply, + swap_token_a.amount, + ) + .unwrap(); + let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); + assert_eq!(token_a.amount, withdrawn_a); + let withdrawn_b = accounts + .swap_curve + .calculator + .pool_tokens_to_trading_tokens( + pool_fee_amount, + pool_mint.supply, + swap_token_b.amount, + ) + .unwrap(); + let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); + assert_eq!(token_b.amount, withdrawn_b); } } fn check_valid_swap_curve(curve_type: CurveType, calculator: Box) { let user_key = pubkey_rand(); let swapper_key = pubkey_rand(); - let fee_numerator = 1; - let fee_denominator = 10; let token_a_amount = 1000; let token_b_amount = 5000; @@ -2720,9 +3018,7 @@ mod tests { let mut accounts = SwapAccountInfo::new( &user_key, - curve_type, - fee_numerator, - fee_denominator, + swap_curve.clone(), token_a_amount, token_b_amount, ); @@ -2744,6 +3040,8 @@ mod tests { // swap one way let a_to_b_amount = initial_a / 10; let minimum_b_amount = 0; + let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + let initial_supply = pool_mint.supply; accounts .swap( &swapper_key, @@ -2775,9 +3073,24 @@ mod tests { let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); assert_eq!(token_b.amount, initial_b + results.amount_swapped); + let first_fee = swap_curve + .calculator + .owner_fee_to_pool_tokens( + results.owner_fee, + token_a_amount, + initial_supply, + TOKENS_IN_POOL, + ) + .unwrap(); + let fee_account = Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap(); + assert_eq!(fee_account.amount, first_fee); + let first_swap_amount = results.amount_swapped; // swap the other way + let pool_mint = Processor::unpack_mint(&accounts.pool_mint_account.data).unwrap(); + let initial_supply = pool_mint.supply; + let b_to_a_amount = initial_b / 10; let minimum_a_amount = 0; accounts @@ -2800,7 +3113,8 @@ mod tests { .unwrap(); let swap_token_a = Processor::unpack_token_account(&accounts.token_a_account.data).unwrap(); - assert_eq!(swap_token_a.amount, results.new_destination_amount); + let token_a_amount = swap_token_a.amount; + assert_eq!(token_a_amount, results.new_destination_amount); let token_a = Processor::unpack_token_account(&token_a_account.data).unwrap(); assert_eq!( token_a.amount, @@ -2808,30 +3122,55 @@ mod tests { ); let swap_token_b = Processor::unpack_token_account(&accounts.token_b_account.data).unwrap(); - assert_eq!(swap_token_b.amount, results.new_source_amount); + let token_b_amount = swap_token_b.amount; + assert_eq!(token_b_amount, results.new_source_amount); let token_b = Processor::unpack_token_account(&token_b_account.data).unwrap(); assert_eq!( token_b.amount, initial_b + first_swap_amount - b_to_a_amount ); + + let second_fee = swap_curve + .calculator + .owner_fee_to_pool_tokens( + results.owner_fee, + token_b_amount, + initial_supply, + TOKENS_IN_POOL, + ) + .unwrap(); + let fee_account = Processor::unpack_token_account(&accounts.pool_fee_account.data).unwrap(); + assert_eq!(fee_account.amount, first_fee + second_fee); } #[test] fn test_valid_swap_curves() { - let fee_numerator = 1; - let fee_denominator = 10; + let trade_fee_numerator = 1; + let trade_fee_denominator = 10; + let owner_trade_fee_numerator = 1; + let owner_trade_fee_denominator = 30; + let owner_withdraw_fee_numerator = 1; + let owner_withdraw_fee_denominator = 30; check_valid_swap_curve( CurveType::ConstantProduct, Box::new(ConstantProductCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }), ); check_valid_swap_curve( CurveType::Flat, Box::new(FlatCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }), ); } @@ -2840,20 +3179,29 @@ mod tests { fn test_invalid_swap() { let user_key = pubkey_rand(); let swapper_key = pubkey_rand(); - let fee_numerator = 1; - let fee_denominator = 10; + let trade_fee_numerator = 1; + let trade_fee_denominator = 4; + let owner_trade_fee_numerator = 1; + let owner_trade_fee_denominator = 10; + let owner_withdraw_fee_numerator = 1; + let owner_withdraw_fee_denominator = 5; let token_a_amount = 1000; let token_b_amount = 5000; let curve_type = CurveType::ConstantProduct; - - let mut accounts = SwapAccountInfo::new( - &user_key, + let swap_curve = SwapCurve { curve_type, - fee_numerator, - fee_denominator, - token_a_amount, - token_b_amount, - ); + calculator: Box::new(ConstantProductCurve { + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, + }), + }; + let mut accounts = + SwapAccountInfo::new(&user_key, swap_curve, token_a_amount, token_b_amount); + let initial_a = token_a_amount / 5; let initial_b = token_b_amount / 5; let minimum_b_amount = initial_b / 2; @@ -2945,6 +3293,8 @@ mod tests { &accounts.token_a_key, &accounts.token_b_key, &token_b_key, + &accounts.pool_mint_key, + &accounts.pool_fee_key, initial_a, minimum_b_amount, ) @@ -2956,6 +3306,8 @@ mod tests { &mut accounts.token_a_account, &mut accounts.token_b_account, &mut token_b_account, + &mut accounts.pool_mint_account, + &mut accounts.pool_fee_account, &mut Account::default(), ], ), @@ -3010,6 +3362,8 @@ mod tests { &token_a_key, &token_b_key, &token_b_key, + &accounts.pool_mint_key, + &accounts.pool_fee_key, initial_a, minimum_b_amount, ) @@ -3021,6 +3375,8 @@ mod tests { &mut token_a_account, &mut token_b_account.clone(), &mut token_b_account, + &mut accounts.pool_mint_account, + &mut accounts.pool_fee_account, &mut Account::default(), ], ), @@ -3079,6 +3435,74 @@ mod tests { ); } + // incorrect mint provided + { + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + _pool_key, + _pool_account, + ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); + let (pool_mint_key, pool_mint_account) = + create_mint(&TOKEN_PROGRAM_ID, &accounts.authority_key, None); + let old_pool_key = accounts.pool_mint_key; + let old_pool_account = accounts.pool_mint_account; + accounts.pool_mint_key = pool_mint_key; + accounts.pool_mint_account = pool_mint_account; + + assert_eq!( + Err(SwapError::IncorrectPoolMint.into()), + accounts.swap( + &swapper_key, + &token_a_key, + &mut token_a_account, + &swap_token_a_key, + &swap_token_b_key, + &token_b_key, + &mut token_b_account, + initial_a, + minimum_b_amount, + ) + ); + + accounts.pool_mint_key = old_pool_key; + accounts.pool_mint_account = old_pool_account; + } + + // incorrect fee account provided + { + let ( + token_a_key, + mut token_a_account, + token_b_key, + mut token_b_account, + wrong_pool_key, + wrong_pool_account, + ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); + let old_pool_fee_account = accounts.pool_fee_account; + let old_pool_fee_key = accounts.pool_fee_key; + accounts.pool_fee_account = wrong_pool_account; + accounts.pool_fee_key = wrong_pool_key; + assert_eq!( + Err(SwapError::IncorrectFeeAccount.into()), + accounts.swap( + &swapper_key, + &token_a_key, + &mut token_a_account, + &swap_token_a_key, + &swap_token_b_key, + &token_b_key, + &mut token_b_account, + initial_a, + minimum_b_amount, + ) + ); + accounts.pool_fee_account = old_pool_fee_account; + accounts.pool_fee_key = old_pool_fee_key; + } + // no approval { let ( @@ -3101,6 +3525,8 @@ mod tests { &accounts.token_a_key, &accounts.token_b_key, &token_b_key, + &accounts.pool_mint_key, + &accounts.pool_fee_key, initial_a, minimum_b_amount, ) @@ -3112,6 +3538,8 @@ mod tests { &mut accounts.token_a_account, &mut accounts.token_b_account, &mut token_b_account, + &mut accounts.pool_mint_account, + &mut accounts.pool_fee_account, &mut Account::default(), ], ), @@ -3129,7 +3557,7 @@ mod tests { _pool_account, ) = accounts.setup_token_accounts(&user_key, &swapper_key, initial_a, initial_b, 0); assert_eq!( - Err(SwapError::CalculationFailure.into()), + Err(SwapError::ZeroTradingTokens.into()), accounts.swap( &swapper_key, &token_b_key, diff --git a/token-swap/program/src/state.rs b/token-swap/program/src/state.rs index e722a298..7b648fed 100644 --- a/token-swap/program/src/state.rs +++ b/token-swap/program/src/state.rs @@ -25,14 +25,22 @@ pub struct SwapInfo { pub token_program_id: Pubkey, /// Token A - /// The Liquidity token is issued against this value. pub token_a: Pubkey, /// Token B pub token_b: Pubkey, + /// Pool tokens are issued when A or B tokens are deposited. /// Pool tokens can be withdrawn back to the original A or B token. pub pool_mint: Pubkey, + /// Mint information for token A + pub token_a_mint: Pubkey, + /// Mint information for token B + pub token_b_mint: Pubkey, + + /// Pool token account to receive trading and / or withdrawal fees + pub pool_fee_account: Pubkey, + /// Swap curve parameters, to be unpacked and used by the SwapCurve, which /// calculates swaps, deposits, and withdrawals pub swap_curve: SwapCurve, @@ -46,14 +54,24 @@ impl IsInitialized for SwapInfo { } impl Pack for SwapInfo { - const LEN: usize = 195; + const LEN: usize = 291; /// Unpacks a byte buffer into a [SwapInfo](struct.SwapInfo.html). fn unpack_from_slice(input: &[u8]) -> Result { - let input = array_ref![input, 0, 195]; + let input = array_ref![input, 0, 291]; #[allow(clippy::ptr_offset_with_cast)] - let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) = - array_refs![input, 1, 1, 32, 32, 32, 32, 65]; + let ( + is_initialized, + nonce, + token_program_id, + token_a, + token_b, + pool_mint, + token_a_mint, + token_b_mint, + pool_fee_account, + swap_curve, + ) = array_refs![input, 1, 1, 32, 32, 32, 32, 32, 32, 32, 65]; Ok(Self { is_initialized: match is_initialized { [0] => false, @@ -65,20 +83,36 @@ impl Pack for SwapInfo { token_a: Pubkey::new_from_array(*token_a), token_b: Pubkey::new_from_array(*token_b), pool_mint: Pubkey::new_from_array(*pool_mint), + token_a_mint: Pubkey::new_from_array(*token_a_mint), + token_b_mint: Pubkey::new_from_array(*token_b_mint), + pool_fee_account: Pubkey::new_from_array(*pool_fee_account), swap_curve: SwapCurve::unpack_from_slice(swap_curve)?, }) } fn pack_into_slice(&self, output: &mut [u8]) { - let output = array_mut_ref![output, 0, 195]; - let (is_initialized, nonce, token_program_id, token_a, token_b, pool_mint, swap_curve) = - mut_array_refs![output, 1, 1, 32, 32, 32, 32, 65]; + let output = array_mut_ref![output, 0, 291]; + let ( + is_initialized, + nonce, + token_program_id, + token_a, + token_b, + pool_mint, + token_a_mint, + token_b_mint, + pool_fee_account, + swap_curve, + ) = mut_array_refs![output, 1, 1, 32, 32, 32, 32, 32, 32, 32, 65]; is_initialized[0] = self.is_initialized as u8; nonce[0] = self.nonce; token_program_id.copy_from_slice(self.token_program_id.as_ref()); token_a.copy_from_slice(self.token_a.as_ref()); token_b.copy_from_slice(self.token_b.as_ref()); pool_mint.copy_from_slice(self.pool_mint.as_ref()); + token_a_mint.copy_from_slice(self.token_a_mint.as_ref()); + token_b_mint.copy_from_slice(self.token_b_mint.as_ref()); + pool_fee_account.copy_from_slice(self.pool_fee_account.as_ref()); self.swap_curve.pack_into_slice(&mut swap_curve[..]); } } @@ -99,15 +133,29 @@ mod tests { let token_a_raw = [1u8; 32]; let token_b_raw = [2u8; 32]; let pool_mint_raw = [3u8; 32]; + let token_a_mint_raw = [4u8; 32]; + let token_b_mint_raw = [5u8; 32]; + let pool_fee_account_raw = [6u8; 32]; let token_program_id = Pubkey::new_from_array(token_program_id_raw); let token_a = Pubkey::new_from_array(token_a_raw); let token_b = Pubkey::new_from_array(token_b_raw); let pool_mint = Pubkey::new_from_array(pool_mint_raw); - let fee_numerator = 1; - let fee_denominator = 4; + let token_a_mint = Pubkey::new_from_array(token_a_mint_raw); + let token_b_mint = Pubkey::new_from_array(token_b_mint_raw); + let pool_fee_account = Pubkey::new_from_array(pool_fee_account_raw); + let trade_fee_numerator = 1; + let trade_fee_denominator = 4; + let owner_trade_fee_numerator = 3; + let owner_trade_fee_denominator = 10; + let owner_withdraw_fee_numerator = 2; + let owner_withdraw_fee_denominator = 7; let calculator = Box::new(FlatCurve { - fee_numerator, - fee_denominator, + trade_fee_numerator, + trade_fee_denominator, + owner_trade_fee_numerator, + owner_trade_fee_denominator, + owner_withdraw_fee_numerator, + owner_withdraw_fee_denominator, }); let swap_curve = SwapCurve { curve_type, @@ -121,6 +169,9 @@ mod tests { token_a, token_b, pool_mint, + token_a_mint, + token_b_mint, + pool_fee_account, swap_curve, }; @@ -136,12 +187,17 @@ mod tests { packed.extend_from_slice(&token_a_raw); packed.extend_from_slice(&token_b_raw); packed.extend_from_slice(&pool_mint_raw); + packed.extend_from_slice(&token_a_mint_raw); + packed.extend_from_slice(&token_b_mint_raw); + packed.extend_from_slice(&pool_fee_account_raw); packed.push(curve_type_raw); - packed.push(fee_numerator as u8); - packed.extend_from_slice(&[0u8; 7]); // padding - packed.push(fee_denominator as u8); - packed.extend_from_slice(&[0u8; 7]); // padding - packed.extend_from_slice(&[0u8; 48]); // padding + packed.extend_from_slice(&trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_trade_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_numerator.to_le_bytes()); + packed.extend_from_slice(&owner_withdraw_fee_denominator.to_le_bytes()); + packed.extend_from_slice(&[0u8; 16]); // padding let unpacked = SwapInfo::unpack(&packed).unwrap(); assert_eq!(swap_info, unpacked);