Improvements to margin_trade

- don't hard-code the group as the first passed account
- token::approve() banks for each token vault
- sign for each bank
- deal with using tokens without an existing position
- handle deactivation of token account if balance goes to 0
This commit is contained in:
Christian Kamm 2022-05-19 13:45:46 +02:00
parent 437f502c79
commit 21af012d1f
6 changed files with 284 additions and 133 deletions

View File

@ -1,10 +1,14 @@
use crate::error::MangoError;
use crate::state::{compute_health_from_fixed_accounts, Bank, Group, HealthType, MangoAccount};
use crate::{group_seeds, Mango};
use crate::util::LoadZeroCopy;
use crate::Mango;
use anchor_lang::prelude::*;
use anchor_spl::token::TokenAccount;
use fixed::types::I80F48;
use solana_program::instruction::Instruction;
use std::cell::Ref;
use std::collections::HashMap;
#[derive(Accounts)]
pub struct MarginTrade<'info> {
pub group: AccountLoader<'info, Group>,
@ -19,165 +23,175 @@ pub struct MarginTrade<'info> {
pub owner: Signer<'info>,
}
struct AllowedVault {
vault_cpi_ai_index: usize,
bank_health_ai_index: usize,
pre_amount: u64,
}
// TODO: add loan fees
pub fn margin_trade<'key, 'accounts, 'remaining, 'info>(
ctx: Context<'key, 'accounts, 'remaining, 'info, MarginTrade<'info>>,
banks_len: usize,
num_health_accounts: usize,
cpi_data: Vec<u8>,
) -> Result<()> {
let group = ctx.accounts.group.load()?;
let mut account = ctx.accounts.account.load_mut()?;
require!(account.is_bankrupt == 0, MangoError::IsBankrupt);
// remaining_accounts layout is expected as follows
// * banks_len number of banks
// * banks_len number of oracles
// * cpi_program
// * cpi_accounts
// Go over the banks passed as health accounts and:
// - Ensure that all banks that are passed in have activated positions.
// This is necessary because maybe the user wants to margin trade on a token
// that the account hasn't used before.
// - Collect the addresses of all banks to potentially sign for in cpi_ais.
// Note: This depends on the particular health account ordering.
let mut allowed_banks = HashMap::<Pubkey, Ref<Bank>>::new();
let mut allowed_vaults = HashMap::<Pubkey, usize>::new();
let health_ais = &ctx.remaining_accounts[0..num_health_accounts];
for (i, ai) in health_ais.iter().enumerate() {
match ai.load::<Bank>() {
Ok(bank) => {
require!(bank.group == account.group, MangoError::SomeError);
account.tokens.get_mut_or_create(bank.token_index)?;
allowed_vaults.insert(bank.vault, i);
allowed_banks.insert(*ai.key, bank);
}
Err(Error::AnchorError(error))
if error.error_code_number == ErrorCode::AccountDiscriminatorMismatch as u32 =>
{
break;
}
Err(error) => return Err(error),
};
}
// assert that user has passed in enough banks, this might be greater than his current
// total number of indexed positions, since
// user might end up withdrawing or depositing and activating a new indexed position
let cpi_program_id = *ctx.remaining_accounts[num_health_accounts].key;
// No self-calls via this method
require!(
banks_len >= account.tokens.iter_active().count(),
MangoError::SomeError // todo: SomeError
cpi_program_id != Mango::id(),
MangoError::InvalidMarginTradeTargetCpiProgram
);
// unpack remaining_accounts
let health_ais = &ctx.remaining_accounts[0..banks_len * 2];
// TODO: This relies on the particular shape of health_ais
let banks = &ctx.remaining_accounts[0..banks_len];
let cpi_program_id = *ctx.remaining_accounts[banks_len * 2].key;
// prepare account for cpi ix
let (cpi_ais, cpi_ams) = {
// we also need the group
let mut cpi_ais = [ctx.accounts.group.to_account_info()].to_vec();
// skip banks, oracles and cpi program from the remaining_accounts
let mut remaining_cpi_ais = ctx.remaining_accounts[banks_len * 2 + 1..].to_vec();
cpi_ais.append(&mut remaining_cpi_ais);
// todo: I'm wondering if there's a way to do this without putting cpi_ais on the heap.
// But fine to defer to the future
let mut cpi_ams = cpi_ais.to_account_metas(Option::None);
// we want group to be the signer, so that token vaults can be credited to or withdrawn from
cpi_ams[0].is_signer = true;
(cpi_ais, cpi_ams)
};
// sanity checks
for cpi_ai in &cpi_ais {
// since we are using group signer seeds to invoke cpi,
// assert that none of the cpi accounts is the mango program to prevent that invoker doesn't
// abuse this ix to do unwanted changes
// Validate the cpi accounts.
// - Collect the signers for each used mango bank, thereby allowing
// withdraws from the associated vaults.
// - Check that each group-owned token account is the vault of one of the allowed banks,
// and track its balance.
let cpi_ais = &ctx.remaining_accounts[num_health_accounts + 1..];
let mut cpi_ams = cpi_ais
.iter()
.flat_map(|item| item.to_account_metas(None))
.collect::<Vec<_>>();
require!(cpi_ais.len() == cpi_ams.len(), MangoError::SomeError);
let mut bank_signer_data = Vec::with_capacity(allowed_banks.len());
let mut used_vaults = Vec::with_capacity(allowed_vaults.len());
for (i, (ai, am)) in cpi_ais.iter().zip(cpi_ams.iter_mut()).enumerate() {
// The cpi is forbidden from calling back into mango indirectly
require!(
cpi_ai.key() != Mango::id(),
ai.key() != Mango::id(),
MangoError::InvalidMarginTradeTargetCpiProgram
);
// assert that user has passed in the bank for every
// token account he wants to deposit/withdraw from in cpi
if cpi_ai.owner == &TokenAccount::owner() {
let maybe_mango_vault_token_account =
Account::<TokenAccount>::try_from(cpi_ai).unwrap();
if maybe_mango_vault_token_account.owner == ctx.accounts.group.key() {
require!(
banks.iter().any(|bank_ai| {
let bank_loader = AccountLoader::<'_, Bank>::try_from(bank_ai).unwrap();
let bank = bank_loader.load().unwrap();
bank.mint == maybe_mango_vault_token_account.mint
}),
// todo: errorcode
MangoError::SomeError
)
// Each allowed bank used in the cpi becomes a signer
if ai.owner == &Mango::id() {
if let Some(bank) = allowed_banks.get(ai.key) {
am.is_signer = true;
// this is the data we'll need later to build the PDA account signer seeds
bank_signer_data.push((bank.token_index.to_le_bytes(), [bank.bump]));
}
}
// Every group-owned token account must be a vault of one of the banks.
if ai.owner == &TokenAccount::owner() {
let token_account = Account::<TokenAccount>::try_from(ai).unwrap();
if token_account.owner == ctx.accounts.group.key() {
if let Some(&bank_index) = allowed_vaults.get(&ai.key) {
used_vaults.push(AllowedVault {
vault_cpi_ai_index: i,
bank_health_ai_index: bank_index,
pre_amount: token_account.amount,
});
} else {
// This is to protect users, because if their cpi deposits to a vault and they forgot
// to pass in the bank for the vault, their account would not be credited.
require!(false, MangoError::SomeError);
}
}
}
}
// compute pre cpi health
// TODO: check maint type?
let pre_cpi_health =
compute_health_from_fixed_accounts(&account, HealthType::Init, health_ais)?;
require!(pre_cpi_health > 0, MangoError::HealthMustBePositive);
compute_health_from_fixed_accounts(&account, HealthType::Maint, health_ais)?;
require!(pre_cpi_health >= 0, MangoError::HealthMustBePositive);
msg!("pre_cpi_health {:?}", pre_cpi_health);
// get rid of Ref<> to avoid limiting the cpi call
drop(allowed_banks);
drop(group);
drop(account);
// prepare and invoke cpi
let cpi_ix = Instruction {
program_id: cpi_program_id,
data: cpi_data,
accounts: cpi_ams,
};
let group_seeds = group_seeds!(group);
let pre_cpi_amounts = get_pre_cpi_amounts(&ctx, &cpi_ais);
solana_program::program::invoke_signed(&cpi_ix, &cpi_ais, &[group_seeds])?;
adjust_for_post_cpi_amounts(
&ctx,
&cpi_ais,
pre_cpi_amounts,
&mut banks.to_vec(),
&mut account,
)?;
let group_key = ctx.accounts.group.key();
let signers = bank_signer_data
.iter()
.map(|(token_index, bump)| {
[
group_key.as_ref(),
b"Bank".as_ref(),
&token_index[..],
&bump[..],
]
})
.collect::<Vec<_>>();
let signers_ref = signers.iter().map(|v| &v[..]).collect::<Vec<_>>();
solana_program::program::invoke_signed(&cpi_ix, &cpi_ais, &signers_ref)?;
let mut account = ctx.accounts.account.load_mut()?;
let inactive_tokens =
adjust_for_post_cpi_vault_amounts(health_ais, cpi_ais, used_vaults, &mut account)?;
// compute post cpi health
// todo: this is not working, the health is computed on old bank state and not taking into account
// withdraws done in adjust_for_post_cpi_token_amounts
let post_cpi_health =
compute_health_from_fixed_accounts(&account, HealthType::Init, health_ais)?;
require!(post_cpi_health > 0, MangoError::HealthMustBePositive);
require!(post_cpi_health >= 0, MangoError::HealthMustBePositive);
msg!("post_cpi_health {:?}", post_cpi_health);
// deactivate inactive token accounts after health check
for raw_token_index in inactive_tokens {
account.tokens.deactivate(raw_token_index);
}
Ok(())
}
fn get_pre_cpi_amounts(ctx: &Context<MarginTrade>, cpi_ais: &[AccountInfo]) -> Vec<u64> {
let mut amounts = vec![];
for token_account in cpi_ais
.iter()
.filter(|ai| ai.owner == &TokenAccount::owner())
{
let vault = Account::<TokenAccount>::try_from(token_account).unwrap();
if vault.owner == ctx.accounts.group.key() {
amounts.push(vault.amount)
}
}
amounts
}
fn adjust_for_post_cpi_amounts(
ctx: &Context<MarginTrade>,
fn adjust_for_post_cpi_vault_amounts(
health_ais: &[AccountInfo],
cpi_ais: &[AccountInfo],
pre_cpi_amounts: Vec<u64>,
banks: &mut [AccountInfo],
used_vaults: Vec<AllowedVault>,
account: &mut MangoAccount,
) -> Result<()> {
let token_accounts_iter = cpi_ais
.iter()
.filter(|ai| ai.owner == &TokenAccount::owner());
for (token_account, pre_cpi_amount) in
// token_accounts and pre_cpi_amounts are assumed to be in correct order
token_accounts_iter.zip(pre_cpi_amounts.iter())
{
let vault = Account::<TokenAccount>::try_from(token_account).unwrap();
if vault.owner == ctx.accounts.group.key() {
// find bank for token account
let bank_ai = banks
.iter()
.find(|bank_ai| {
let bank_loader = AccountLoader::<'_, Bank>::try_from(bank_ai).unwrap();
let bank = bank_loader.load().unwrap();
bank.mint == vault.mint
})
.ok_or(MangoError::SomeError)?; // todo: replace SomeError
let bank_loader = AccountLoader::<'_, Bank>::try_from(bank_ai)?;
let mut bank = bank_loader.load_mut()?;
let position = account.tokens.get_mut_or_create(bank.token_index)?.0;
let change = I80F48::from(vault.amount) - I80F48::from(*pre_cpi_amount);
bank.change_with_fee(position, change)?;
) -> Result<Vec<usize>> {
let mut inactive_token_raw_indexes = Vec::with_capacity(used_vaults.len());
for info in used_vaults {
let vault = Account::<TokenAccount>::try_from(&cpi_ais[info.vault_cpi_ai_index]).unwrap();
let mut bank = health_ais[info.bank_health_ai_index].load_mut::<Bank>()?;
let (position, raw_index) = account.tokens.get_mut_or_create(bank.token_index)?;
let is_active = bank.change_with_fee(
position,
I80F48::from(vault.amount) - I80F48::from(info.pre_amount),
)?;
if !is_active {
inactive_token_raw_indexes.push(raw_index);
}
}
Ok(())
Ok(inactive_token_raw_indexes)
}

View File

@ -1,12 +1,11 @@
use anchor_lang::prelude::*;
use anchor_spl::token::Mint;
use anchor_spl::token::Token;
use anchor_spl::token::TokenAccount;
use anchor_spl::token::{self, Mint, Token, TokenAccount};
use fixed::types::I80F48;
use fixed_macro::types::I80F48;
// TODO: ALTs are unavailable
//use crate::address_lookup_table;
use crate::error::*;
use crate::state::*;
use crate::util::fill16_from_str;
@ -75,6 +74,18 @@ pub struct RegisterToken<'info> {
pub rent: Sysvar<'info, Rent>,
}
impl<'info> RegisterToken<'info> {
pub fn approve_ctx(&self) -> CpiContext<'_, '_, '_, 'info, token::Approve<'info>> {
let program = self.token_program.to_account_info();
let accounts = token::Approve {
to: self.vault.to_account_info(),
delegate: self.bank.to_account_info(),
authority: self.group.to_account_info(),
};
CpiContext::new(program, accounts)
}
}
#[derive(AnchorSerialize, AnchorDeserialize, Default)]
pub struct InterestRateParams {
pub util0: f32,
@ -102,6 +113,17 @@ pub fn register_token(
) -> Result<()> {
// TODO: Error if mint is already configured (technically, init of vault will fail)
// Approve the bank account for withdraws from the vault. This allows us to later sign with a
// bank for foreign cpi calls in margin_trade and thereby give the foreign program the ability
// to withdraw - without the ability to set new delegates or close the token account.
// TODO: we need to refresh this approve occasionally?!
let group = ctx.accounts.group.load()?;
let group_seeds = group_seeds!(group);
token::approve(
ctx.accounts.approve_ctx().with_signer(&[group_seeds]),
u64::MAX,
)?;
let mut bank = ctx.accounts.bank.load_init()?;
*bank = Bank {
name: fill16_from_str(name)?,
@ -130,6 +152,7 @@ pub fn register_token(
liquidation_fee: I80F48::from_num(liquidation_fee),
dust: I80F48::ZERO,
token_index,
bump: *ctx.bumps.get("bank").ok_or(MangoError::SomeError)?,
reserved: Default::default(),
};

View File

@ -99,10 +99,10 @@ pub mod mango_v4 {
pub fn margin_trade<'key, 'accounts, 'remaining, 'info>(
ctx: Context<'key, 'accounts, 'remaining, 'info, MarginTrade<'info>>,
banks_len: usize,
num_health_accounts: usize,
cpi_data: Vec<u8>,
) -> Result<()> {
instructions::margin_trade(ctx, banks_len, cpi_data)
instructions::margin_trade(ctx, num_health_accounts, cpi_data)
}
///

View File

@ -59,9 +59,11 @@ pub struct Bank {
// Index into TokenInfo on the group
pub token_index: TokenIndex,
pub reserved: [u8; 6],
pub bump: u8,
pub reserved: [u8; 5],
}
const_assert_eq!(size_of::<Bank>(), 16 + 32 * 4 + 8 + 16 * 18 + 2 + 6);
const_assert_eq!(size_of::<Bank>(), 16 + 32 * 4 + 8 + 16 * 18 + 3 + 5);
const_assert_eq!(size_of::<Bank>() % 8, 0);
impl std::fmt::Debug for Bank {
@ -342,6 +344,20 @@ impl Bank {
}
}
#[macro_export]
macro_rules! bank_seeds {
( $bank:expr ) => {
&[
$bank.group.as_ref(),
b"Bank".as_ref(),
$bank.token_index.to_le_bytes(),
&[$bank.bump],
]
};
}
pub use bank_seeds;
#[cfg(test)]
mod tests {
use bytemuck::Zeroable;

View File

@ -247,6 +247,7 @@ pub async fn account_position(solana: &SolanaCookie, account: Pubkey, bank: Pubk
pub struct MarginTradeInstruction<'keypair> {
pub account: Pubkey,
pub owner: &'keypair Keypair,
pub mango_token_bank: Pubkey,
pub mango_token_vault: Pubkey,
pub margin_trade_program_id: Pubkey,
pub deposit_account: Pubkey,
@ -265,20 +266,24 @@ impl<'keypair> ClientInstruction for MarginTradeInstruction<'keypair> {
let account: MangoAccount = account_loader.load(&self.account).await.unwrap();
let instruction = Self::Instruction {
banks_len: account.tokens.iter_active().count(),
cpi_data: self.margin_trade_program_ix_cpi_data.clone(),
};
let accounts = Self::Accounts {
group: account.group,
account: self.account,
owner: self.owner.pubkey(),
};
let health_check_metas =
derive_health_check_remaining_account_metas(&account_loader, &account, None, true)
.await;
let health_check_metas = derive_health_check_remaining_account_metas(
&account_loader,
&account,
Some(self.mango_token_bank),
true,
)
.await;
let instruction = Self::Instruction {
num_health_accounts: health_check_metas.len(),
cpi_data: self.margin_trade_program_ix_cpi_data.clone(),
};
let mut instruction = make_instruction(program_id, &accounts, instruction);
instruction.accounts.extend(health_check_metas.into_iter());
@ -287,6 +292,11 @@ impl<'keypair> ClientInstruction for MarginTradeInstruction<'keypair> {
is_writable: false,
is_signer: false,
});
instruction.accounts.push(AccountMeta {
pubkey: self.mango_token_bank,
is_writable: false,
is_signer: false,
});
instruction.accounts.push(AccountMeta {
pubkey: self.mango_token_vault,
is_writable: true,

View File

@ -105,14 +105,15 @@ async fn test_margin_trade() -> Result<(), BanksClientError> {
MarginTradeInstruction {
account,
owner,
mango_token_bank: bank,
mango_token_vault: vault,
margin_trade_program_id: margin_trade.program,
deposit_account: margin_trade.token_account.pubkey(),
deposit_account_owner: margin_trade.token_account_owner,
margin_trade_program_ix_cpi_data: {
let ix = margin_trade::instruction::MarginTrade {
amount_from: 2,
amount_to: 1,
amount_from: withdraw_amount,
amount_to: deposit_amount,
deposit_account_owner_bump_seeds: margin_trade.token_account_bump,
};
ix.data()
@ -133,5 +134,92 @@ async fn test_margin_trade() -> Result<(), BanksClientError> {
withdraw_amount - deposit_amount
);
//
// TEST: Bringing the balance to 0 deactivates the token
//
let deposit_amount_initial = solana.token_account_balance(vault).await;
let margin_account_initial = solana
.token_account_balance(margin_trade.token_account.pubkey())
.await;
let withdraw_amount = deposit_amount_initial;
let deposit_amount = 0;
{
send_tx(
solana,
MarginTradeInstruction {
account,
owner,
mango_token_bank: bank,
mango_token_vault: vault,
margin_trade_program_id: margin_trade.program,
deposit_account: margin_trade.token_account.pubkey(),
deposit_account_owner: margin_trade.token_account_owner,
margin_trade_program_ix_cpi_data: {
let ix = margin_trade::instruction::MarginTrade {
amount_from: withdraw_amount,
amount_to: deposit_amount,
deposit_account_owner_bump_seeds: margin_trade.token_account_bump,
};
ix.data()
},
},
)
.await
.unwrap();
}
assert_eq!(solana.token_account_balance(vault).await, 0);
assert_eq!(
solana
.token_account_balance(margin_trade.token_account.pubkey())
.await,
margin_account_initial + withdraw_amount
);
// Check that position is fully deactivated
let account_data: MangoAccount = solana.get_account(account).await;
assert_eq!(account_data.tokens.iter_active().count(), 0);
//
// TEST: Activating a token via margin trade
//
let margin_account_initial = solana
.token_account_balance(margin_trade.token_account.pubkey())
.await;
let withdraw_amount = 0;
let deposit_amount = margin_account_initial;
{
send_tx(
solana,
MarginTradeInstruction {
account,
owner,
mango_token_bank: bank,
mango_token_vault: vault,
margin_trade_program_id: margin_trade.program,
deposit_account: margin_trade.token_account.pubkey(),
deposit_account_owner: margin_trade.token_account_owner,
margin_trade_program_ix_cpi_data: {
let ix = margin_trade::instruction::MarginTrade {
amount_from: withdraw_amount,
amount_to: deposit_amount,
deposit_account_owner_bump_seeds: margin_trade.token_account_bump,
};
ix.data()
},
},
)
.await
.unwrap();
}
assert_eq!(solana.token_account_balance(vault).await, deposit_amount);
assert_eq!(
solana
.token_account_balance(margin_trade.token_account.pubkey())
.await,
0
);
// Check that position is active
let account_data: MangoAccount = solana.get_account(account).await;
assert_eq!(account_data.tokens.iter_active().count(), 1);
Ok(())
}