Checked math: support cm!(a += b)

This commit is contained in:
Christian Kamm 2022-09-05 14:14:42 +02:00
parent 3c51449040
commit 79a7bdc299
15 changed files with 134 additions and 65 deletions

View File

@ -24,3 +24,13 @@ pub fn checked_math(input: TokenStream) -> TokenStream {
TokenStream::from(expanded)
}
/// Like checked_math(), but panics with "math error" on None results
#[proc_macro]
#[proc_macro_error]
pub fn checked_math_or_panic(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::Expr);
let expanded = transform::checked::transform_expr_or_panic(input);
TokenStream::from(expanded)
}

View File

@ -2,6 +2,39 @@ use proc_macro_error::abort;
use quote::quote;
use syn::{spanned::Spanned, BinOp, Expr, ExprBinary, ExprUnary, Ident, Lit, UnOp};
pub fn transform_expr_or_panic(expr: Expr) -> proc_macro2::TokenStream {
match expr {
Expr::Group(g) => {
let expr = transform_expr_or_panic(*g.expr);
quote! {(#expr)}
}
Expr::AssignOp(assign_op) => {
// Rewrite `left += right` into `left = checked!(left + right).unwrap()`
let bin_op = Expr::Binary(ExprBinary {
attrs: vec![],
left: assign_op.left.clone(),
right: assign_op.right.clone(),
op: match assign_op.op {
BinOp::AddEq(t) => BinOp::Add(syn::token::Add(t.spans[0])),
BinOp::SubEq(t) => BinOp::Sub(syn::token::Sub(t.spans[0])),
BinOp::MulEq(t) => BinOp::Mul(syn::token::Star(t.spans[0])),
BinOp::DivEq(t) => BinOp::Div(syn::token::Div(t.spans[0])),
_ => panic!("unsupported AssignOp.op: {:#?}", assign_op.op),
},
});
let left = assign_op.left;
let bin_op_tokens = transform_expr(bin_op);
quote! {
#left = (#bin_op_tokens).unwrap_or_else(|| panic!("math error"))
}
}
_ => {
let toks = transform_expr(expr);
quote! { (#toks).unwrap_or_else(|| panic!("math error")) }
}
}
}
pub fn transform_expr(mut expr: Expr) -> proc_macro2::TokenStream {
match expr {
Expr::Unary(unary) => transform_unary(unary),

View File

@ -1,4 +1,6 @@
use checked_math::checked_math;
use checked_math::{checked_math, checked_math_or_panic};
use std::cell::RefCell;
use std::rc::Rc;
fn f() -> u8 {
3u8
@ -35,4 +37,32 @@ fn main() {
let s = S{};
let result = checked_math!{ 12u8 + s.m() };
assert!(result == Some(14));
let r = checked_math_or_panic!(num + 4u8);
assert_eq!(r, 6);
let mut m = 2u8;
checked_math_or_panic!(m += 4);
assert_eq!(m, 6);
let g = Rc::new(RefCell::new(0u8));
let single_eval_test = || -> Rc<RefCell<u8>> {
*g.borrow_mut() += 1;
g.clone()
};
*single_eval_test().borrow_mut() += 10;
assert_eq!(*g.borrow(), 11);
// I don't get why this passes:
// The macro should call the left hand side expression multiple times?
checked_math_or_panic!(*single_eval_test().borrow_mut() += 10);
assert_eq!(*g.borrow(), 22);
eprintln!("Ignore STDERR messages if the test passes: the panics were captured");
assert!(std::panic::catch_unwind(|| {
let mut m = 2u8;
checked_math_or_panic!(m /= 0);
assert_eq!(m, 0); // unreached
}).is_err());
}

View File

@ -283,7 +283,7 @@ pub fn flash_loan_end<'key, 'accounts, 'remaining, 'info>(
token::transfer(transfer_ctx, repay)?;
let repay = I80F48::from(repay);
change = cm!(change + repay);
cm!(change += repay);
}
changes.push(TokenVaultChange {
@ -349,7 +349,7 @@ pub fn flash_loan_end<'key, 'accounts, 'remaining, 'info>(
};
let loan_origination_fee = cm!(loan * bank.loan_origination_fee_rate);
bank.collected_fees_native = cm!(bank.collected_fees_native + loan_origination_fee);
cm!(bank.collected_fees_native += loan_origination_fee);
let is_active =
bank.change_without_fee(position, cm!(change.amount - loan_origination_fee))?;

View File

@ -232,7 +232,7 @@ pub fn liq_token_bankruptcy(
let mut indexed_total_deposits = I80F48::ZERO;
for bank_ai in bank_ais.iter() {
let bank = bank_ai.load::<Bank>()?;
indexed_total_deposits = cm!(indexed_total_deposits + bank.indexed_deposits);
cm!(indexed_total_deposits += bank.indexed_deposits);
}
// This is the solution to:
@ -255,7 +255,7 @@ pub fn liq_token_bankruptcy(
// enable dusting, because each deposit() is allowed to round up. thus multiple deposit
// could bring the total position slightly above zero otherwise
liqee_liab_active = bank.deposit_with_dusting(liqee_liab, amount_for_bank)?;
amount_to_credit = cm!(amount_to_credit - amount_for_bank);
cm!(amount_to_credit -= amount_for_bank);
if amount_to_credit <= 0 {
break;
}

View File

@ -87,13 +87,13 @@ pub fn perp_settle_pnl(ctx: Context<PerpSettlePnl>, max_settle_amount: I80F48) -
// Settle for the maximum possible capped to max_settle_amount
let settlement = a_pnl.abs().min(b_pnl.abs()).min(max_settle_amount);
a_perp_position.quote_position_native = cm!(a_perp_position.quote_position_native - settlement);
b_perp_position.quote_position_native = cm!(b_perp_position.quote_position_native + settlement);
cm!(a_perp_position.quote_position_native -= settlement);
cm!(b_perp_position.quote_position_native += settlement);
// Update the account's net_settled with the new PnL
let settlement_i64 = settlement.checked_to_num::<i64>().unwrap();
account_a.fixed.net_settled = cm!(account_a.fixed.net_settled + settlement_i64);
account_b.fixed.net_settled = cm!(account_b.fixed.net_settled - settlement_i64);
cm!(account_a.fixed.net_settled += settlement_i64);
cm!(account_b.fixed.net_settled -= settlement_i64);
// Transfer token balances
// TODO: Need to guarantee that QUOTE_TOKEN_INDEX token exists at this point. I.E. create it when placing perp order.

View File

@ -79,7 +79,7 @@ pub fn token_deposit(ctx: Context<TokenDeposit>, amount: u64) -> Result<()> {
// Update the net deposits - adjust by price so different tokens are on the same basis (in USD terms)
let amount_usd = cm!(amount_i80f48 * oracle_price).to_num::<i64>();
account.fixed.net_deposits = cm!(account.fixed.net_deposits + amount_usd);
cm!(account.fixed.net_deposits += amount_usd);
emit!(TokenBalanceLog {
mango_group: ctx.accounts.group.key(),

View File

@ -85,8 +85,8 @@ pub fn token_update_index_and_rate(ctx: Context<TokenUpdateIndexAndRate>) -> Res
let mut indexed_total_borrows = I80F48::ZERO;
for ai in ctx.remaining_accounts.iter() {
let bank = ai.load::<Bank>()?;
indexed_total_deposits = cm!(indexed_total_deposits + bank.indexed_deposits);
indexed_total_borrows = cm!(indexed_total_borrows + bank.indexed_borrows);
cm!(indexed_total_deposits += bank.indexed_deposits);
cm!(indexed_total_borrows += bank.indexed_borrows);
}
// compute and set latest index and average utilization on each bank
@ -98,7 +98,7 @@ pub fn token_update_index_and_rate(ctx: Context<TokenUpdateIndexAndRate>) -> Res
let (deposit_index, borrow_index, borrow_fees) =
some_bank.compute_index(indexed_total_deposits, indexed_total_borrows, diff_ts)?;
some_bank.collected_fees_native = cm!(some_bank.collected_fees_native + borrow_fees);
cm!(some_bank.collected_fees_native += borrow_fees);
let new_avg_utilization = some_bank.compute_new_avg_utilization(
indexed_total_deposits,

View File

@ -136,7 +136,7 @@ pub fn token_withdraw(ctx: Context<TokenWithdraw>, amount: u64, allow_borrow: bo
// Update the net deposits - adjust by price so different tokens are on the same basis (in USD terms)
let amount_usd = cm!(amount_i80f48 * oracle_price).to_num::<i64>();
account.fixed.net_deposits = cm!(account.fixed.net_deposits - amount_usd);
cm!(account.fixed.net_deposits -= amount_usd);
//
// Health check

View File

@ -233,19 +233,19 @@ impl Bank {
let new_indexed_value = cm!(position.indexed_position + indexed_change);
if new_indexed_value.is_negative() {
// pay back borrows only, leaving a negative position
self.indexed_borrows = cm!(self.indexed_borrows - indexed_change);
cm!(self.indexed_borrows -= indexed_change);
position.indexed_position = new_indexed_value;
return Ok(true);
} else if new_native_position < I80F48::ONE && allow_dusting {
// if there's less than one token deposited, zero the position
self.dust = cm!(self.dust + new_native_position);
self.indexed_borrows = cm!(self.indexed_borrows + position.indexed_position);
cm!(self.dust += new_native_position);
cm!(self.indexed_borrows += position.indexed_position);
position.indexed_position = I80F48::ZERO;
return Ok(false);
}
// pay back all borrows
self.indexed_borrows = cm!(self.indexed_borrows + position.indexed_position); // position.value is negative
cm!(self.indexed_borrows += position.indexed_position); // position.value is negative
position.indexed_position = I80F48::ZERO;
// deposit the rest
// note: .max(0) because there's a scenario where new_indexed_value == 0 and new_native_position < 0
@ -254,8 +254,8 @@ impl Bank {
// add to deposits
let indexed_change = div_rounding_up(native_amount, self.deposit_index);
self.indexed_deposits = cm!(self.indexed_deposits + indexed_change);
position.indexed_position = cm!(position.indexed_position + indexed_change);
cm!(self.indexed_deposits += indexed_change);
cm!(position.indexed_position += indexed_change);
Ok(true)
}
@ -324,21 +324,21 @@ impl Bank {
// withdraw deposits only
if new_native_position < I80F48::ONE && allow_dusting {
// zero the account collecting the leftovers in `dust`
self.dust = cm!(self.dust + new_native_position);
self.indexed_deposits = cm!(self.indexed_deposits - position.indexed_position);
cm!(self.dust += new_native_position);
cm!(self.indexed_deposits -= position.indexed_position);
position.indexed_position = I80F48::ZERO;
return Ok((false, I80F48::ZERO));
} else {
// withdraw some deposits leaving a positive balance
let indexed_change = cm!(native_amount / self.deposit_index);
self.indexed_deposits = cm!(self.indexed_deposits - indexed_change);
position.indexed_position = cm!(position.indexed_position - indexed_change);
cm!(self.indexed_deposits -= indexed_change);
cm!(position.indexed_position -= indexed_change);
return Ok((true, I80F48::ZERO));
}
}
// withdraw all deposits
self.indexed_deposits = cm!(self.indexed_deposits - position.indexed_position);
cm!(self.indexed_deposits -= position.indexed_position);
position.indexed_position = I80F48::ZERO;
// borrow the rest
native_amount = -new_native_position;
@ -347,14 +347,14 @@ impl Bank {
let mut loan_origination_fee = I80F48::ZERO;
if with_loan_origination_fee {
loan_origination_fee = cm!(self.loan_origination_fee_rate * native_amount);
self.collected_fees_native = cm!(self.collected_fees_native + loan_origination_fee);
native_amount = cm!(native_amount + loan_origination_fee);
cm!(self.collected_fees_native += loan_origination_fee);
cm!(native_amount += loan_origination_fee);
}
// add to borrows
let indexed_change = cm!(native_amount / self.borrow_index);
self.indexed_borrows = cm!(self.indexed_borrows + indexed_change);
position.indexed_position = cm!(position.indexed_position - indexed_change);
cm!(self.indexed_borrows += indexed_change);
cm!(position.indexed_position -= indexed_change);
Ok((true, loan_origination_fee))
}
@ -367,7 +367,7 @@ impl Bank {
) -> Result<(bool, I80F48)> {
let loan_origination_fee =
cm!(self.loan_origination_fee_rate * already_borrowed_native_amount);
self.collected_fees_native = cm!(self.collected_fees_native + loan_origination_fee);
cm!(self.collected_fees_native += loan_origination_fee);
let (position_is_active, _) =
self.withdraw_internal(position, loan_origination_fee, false, !position.is_in_use())?;

View File

@ -613,7 +613,7 @@ impl HealthCache {
pub fn health(&self, health_type: HealthType) -> I80F48 {
let mut health = I80F48::ZERO;
let sum = |contrib| {
health = cm!(health + contrib);
cm!(health += contrib);
};
self.health_sum(health_type, sum);
health
@ -657,7 +657,7 @@ impl HealthCache {
// We need to make sure that if balance is before * price, then change = -before
// brings it to exactly zero.
let removed_contribution = (-change) * entry.oracle_price;
entry.balance = cm!(entry.balance - removed_contribution);
cm!(entry.balance -= removed_contribution);
Ok(())
}
@ -682,23 +682,19 @@ impl HealthCache {
}
{
let quote_entry = &mut self.token_infos[quote_entry_index];
reserved_amount =
cm!(reserved_amount + reserved_quote_change * quote_entry.oracle_price);
cm!(reserved_amount += reserved_quote_change * quote_entry.oracle_price);
}
// Apply it to the tokens
{
let base_entry = &mut self.token_infos[base_entry_index];
base_entry.serum3_max_reserved = cm!(base_entry.serum3_max_reserved + reserved_amount);
base_entry.balance =
cm!(base_entry.balance + free_base_change * base_entry.oracle_price);
cm!(base_entry.serum3_max_reserved += reserved_amount);
cm!(base_entry.balance += free_base_change * base_entry.oracle_price);
}
{
let quote_entry = &mut self.token_infos[quote_entry_index];
quote_entry.serum3_max_reserved =
cm!(quote_entry.serum3_max_reserved + reserved_amount);
quote_entry.balance =
cm!(quote_entry.balance + free_quote_change * quote_entry.oracle_price);
cm!(quote_entry.serum3_max_reserved += reserved_amount);
cm!(quote_entry.balance += free_quote_change * quote_entry.oracle_price);
}
// Apply it to the serum3 info
@ -707,7 +703,7 @@ impl HealthCache {
.iter_mut()
.find(|m| m.market_index == market_index)
.ok_or_else(|| error_msg!("serum3 market {} not found", market_index))?;
market_entry.reserved = cm!(market_entry.reserved + reserved_amount);
cm!(market_entry.reserved += reserved_amount);
Ok(())
}
@ -782,9 +778,9 @@ impl HealthCache {
let mut liabs = I80F48::ZERO;
let sum = |contrib| {
if contrib > 0 {
assets = cm!(assets + contrib);
cm!(assets += contrib);
} else {
liabs = cm!(liabs - contrib);
cm!(liabs -= contrib);
}
};
self.health_sum(health_type, sum);
@ -1006,16 +1002,16 @@ pub fn new_health_cache(
// add the amounts that are freely settleable
let base_free = I80F48::from_num(oo.native_coin_free);
let quote_free = I80F48::from_num(cm!(oo.native_pc_free + oo.referrer_rebates_accrued));
base_info.balance = cm!(base_info.balance + base_free * base_info.oracle_price);
quote_info.balance = cm!(quote_info.balance + quote_free * quote_info.oracle_price);
cm!(base_info.balance += base_free * base_info.oracle_price);
cm!(quote_info.balance += quote_free * quote_info.oracle_price);
// add the reserved amount to both sides, to have the worst-case covered
let reserved_base = I80F48::from_num(cm!(oo.native_coin_total - oo.native_coin_free));
let reserved_quote = I80F48::from_num(cm!(oo.native_pc_total - oo.native_pc_free));
let reserved_balance =
cm!(reserved_base * base_info.oracle_price + reserved_quote * quote_info.oracle_price);
base_info.serum3_max_reserved = cm!(base_info.serum3_max_reserved + reserved_balance);
quote_info.serum3_max_reserved = cm!(quote_info.serum3_max_reserved + reserved_balance);
cm!(base_info.serum3_max_reserved += reserved_balance);
cm!(quote_info.serum3_max_reserved += reserved_balance);
serum3_infos.push(Serum3Info {
reserved: reserved_balance,

View File

@ -739,10 +739,10 @@ impl<
let mut perp_account = self.ensure_perp_position(perp_market_index).unwrap().0;
match side {
Side::Bid => {
perp_account.bids_base_lots = cm!(perp_account.bids_base_lots + order.quantity);
cm!(perp_account.bids_base_lots += order.quantity);
}
Side::Ask => {
perp_account.asks_base_lots = cm!(perp_account.asks_base_lots + order.quantity);
cm!(perp_account.asks_base_lots += order.quantity);
}
};
let slot = order.owner_slot as usize;
@ -767,10 +767,10 @@ impl<
// accounting
match order_side {
Side::Bid => {
perp_account.bids_base_lots = cm!(perp_account.bids_base_lots - quantity);
cm!(perp_account.bids_base_lots -= quantity);
}
Side::Ask => {
perp_account.asks_base_lots = cm!(perp_account.asks_base_lots - quantity);
cm!(perp_account.asks_base_lots -= quantity);
}
}
}
@ -814,10 +814,10 @@ impl<
} else {
match side {
Side::Bid => {
pa.bids_base_lots = cm!(pa.bids_base_lots - base_change.abs());
cm!(pa.bids_base_lots -= base_change.abs());
}
Side::Ask => {
pa.asks_base_lots = cm!(pa.asks_base_lots - base_change.abs());
cm!(pa.asks_base_lots -= base_change.abs());
}
}
Ok(())

View File

@ -212,19 +212,19 @@ impl PerpPosition {
pub fn add_taker_trade(&mut self, side: Side, base_lots: i64, quote_lots: i64) {
match side {
Side::Bid => {
self.taker_base_lots = cm!(self.taker_base_lots + base_lots);
self.taker_quote_lots = cm!(self.taker_quote_lots - quote_lots);
cm!(self.taker_base_lots += base_lots);
cm!(self.taker_quote_lots -= quote_lots);
}
Side::Ask => {
self.taker_base_lots = cm!(self.taker_base_lots - base_lots);
self.taker_quote_lots = cm!(self.taker_quote_lots + quote_lots);
cm!(self.taker_base_lots -= base_lots);
cm!(self.taker_quote_lots += quote_lots);
}
}
}
/// Remove taker trade after it has been processed on EventQueue
pub fn remove_taker_trade(&mut self, base_change: i64, quote_change: i64) {
self.taker_base_lots = cm!(self.taker_base_lots - base_change);
self.taker_quote_lots = cm!(self.taker_quote_lots - quote_change);
cm!(self.taker_base_lots -= base_change);
cm!(self.taker_quote_lots -= quote_change);
}
pub fn is_active(&self) -> bool {
@ -273,10 +273,10 @@ impl PerpPosition {
}
let old_position = self.base_position_lots;
let is_increasing = old_position == 0 || old_position.signum() == base_change.signum();
self.quote_running_native = cm!(self.quote_running_native + quote_change);
cm!(self.quote_running_native += quote_change);
match is_increasing {
true => {
self.quote_entry_native = cm!(self.quote_entry_native + quote_change);
cm!(self.quote_entry_native += quote_change);
}
false => {
let new_position = cm!(old_position + base_change);

View File

@ -251,8 +251,8 @@ impl<'a> Book<'a> {
match_base_lots == max_match_by_quote || match_base_lots == remaining_base_lots;
let match_quote_lots = cm!(match_base_lots * best_opposing_price);
remaining_base_lots = cm!(remaining_base_lots - match_base_lots);
remaining_quote_lots = cm!(remaining_quote_lots - match_quote_lots);
cm!(remaining_base_lots -= match_base_lots);
cm!(remaining_quote_lots -= match_quote_lots);
let new_best_opposing_quantity = cm!(best_opposing.quantity - match_base_lots);
let maker_out = new_best_opposing_quantity == 0;

View File

@ -15,7 +15,7 @@ pub(crate) use zip;
#[macro_export]
macro_rules! checked_math {
($x: expr) => {
checked_math::checked_math!($x).unwrap_or_else(|| panic!("math error"))
checked_math::checked_math_or_panic!($x)
};
}
pub(crate) use checked_math;