Add checked_math library for convenient overflow checking

Instead of
    x.checked_add(y).ok_or(error!(MangoError::MathError))?
we can write
    cm!(x + y)
This commit is contained in:
Christian Kamm 2022-03-11 09:57:30 +01:00
parent fce2316b03
commit 449fe4dc6d
18 changed files with 315 additions and 39 deletions

32
Cargo.lock generated
View File

@ -515,6 +515,17 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "checked_math"
version = "0.1.0"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
"trybuild",
]
[[package]]
name = "chrono"
version = "0.4.19"
@ -1158,6 +1169,12 @@ version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4"
[[package]]
name = "glob"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
[[package]]
name = "goblin"
version = "0.4.3"
@ -1536,6 +1553,7 @@ dependencies = [
"base64 0.13.0",
"bincode",
"bytemuck",
"checked_math",
"env_logger",
"fixed",
"fixed-macro",
@ -3343,6 +3361,20 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "trybuild"
version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d60539445867cdd9680b2bfe2d0428f1814b7d5c9652f09d8d3eae9d19308db"
dependencies = [
"glob",
"once_cell",
"serde",
"serde_json",
"termcolor",
"toml",
]
[[package]]
name = "typenum"
version = "1.15.0"

View File

@ -1,4 +1,5 @@
[workspace]
members = [
"programs/*"
"programs/*",
"lib/*"
]

View File

@ -0,0 +1,24 @@
[package]
name = "checked_math"
description = "Proc macros for changing the checking behavior of math expressions"
version = "0.1.0"
authors = ["Ryan Levick<ryan.levick@gmail.com>", "Christian Kamm <mail@ckamm.de>"]
license = "MIT"
edition = "2021"
autotests = false
[lib]
proc-macro = true
[[test]]
name = "tests"
path = "tests/progress.rs"
[dev-dependencies]
trybuild = "1.0"
[dependencies]
syn = { version = "1.0.86", features = ["full", "extra-traits"] }
quote = "1.0.15"
proc-macro2 = "1.0.36"
proc-macro-error = "1.0.4"

21
lib/checked_math/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 Ryan Levick
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,5 @@
# Source
This is a modified version of `overflow` from https://github.com/rylev/overflow/
originally by Ryan Levick. See LICENSE.

View File

@ -0,0 +1,26 @@
extern crate proc_macro;
mod transform;
use proc_macro::TokenStream;
use proc_macro_error::proc_macro_error;
use syn::parse_macro_input;
/// Produces a semantically equivalent expression as the one provided
/// except that each math call is substituted with the equivalent version
/// of the `checked` API.
///
/// Examples:
/// - `checked_math!{ 1 }` will become `Some(1)`
/// - `checked_math!{ a + b }` will become `a.checked_add(b)`
///
/// The macro is intened to be used for arithmetic expressions only and
/// significantly restricts the available syntax.
#[proc_macro]
#[proc_macro_error]
pub fn checked_math(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::Expr);
let expanded = transform::checked::transform_expr(input);
TokenStream::from(expanded)
}

View File

@ -0,0 +1,108 @@
use proc_macro_error::abort;
use quote::quote;
use syn::{spanned::Spanned, BinOp, Expr, ExprBinary, ExprUnary, Ident, Lit, UnOp};
pub fn transform_expr(mut expr: Expr) -> proc_macro2::TokenStream {
match expr {
Expr::Unary(unary) => transform_unary(unary),
Expr::Binary(binary) => transform_binary(binary),
Expr::MethodCall(ref mut mc) => {
if mc.method == "pow" {
mc.method = syn::Ident::new("checked_pow", mc.method.span());
quote! { #mc }
} else if mc.method == "abs" {
mc.method = syn::Ident::new("checked_abs", mc.method.span());
quote! { #mc }
} else if mc.args.is_empty() {
quote! { Some(#mc) }
} else {
abort!(mc, "method calls with arguments are not supported");
}
}
Expr::Call(ref mut c) => {
if c.args.is_empty() {
quote! { Some(#c) }
} else {
abort!(c, "calls with arguments are not supported");
}
}
Expr::Paren(p) => {
let expr = transform_expr(*p.expr);
quote! {
(#expr)
}
}
Expr::Group(g) => {
let expr = transform_expr(*g.expr);
quote! {
(#expr)
}
}
Expr::Lit(lit) => match lit.lit {
Lit::Int(_) | Lit::Float(_) => quote! { Some(#lit) },
_ => abort!(lit, "unsupported literal"),
},
Expr::Path(_) | Expr::Field(_) => {
quote! { Some(#expr) }
}
_ => {
abort!(expr, "unsupported expr {:?}", expr);
}
}
}
fn transform_unary(unary: ExprUnary) -> proc_macro2::TokenStream {
let expr = transform_expr(*unary.expr);
let op = unary.op;
match op {
UnOp::Neg(_) => {
quote! {
{
match #expr {
Some(e) => e.checked_neg(),
None => None
}
}
}
}
UnOp::Deref(_) => quote! { #expr },
UnOp::Not(_) => abort!(expr, "unsupported unary expr"),
}
}
fn transform_binary(binary: ExprBinary) -> proc_macro2::TokenStream {
let left = transform_expr(*binary.left);
let right = transform_expr(*binary.right);
let op = binary.op;
let method_name = match op {
BinOp::Add(_) => Some("checked_add"),
BinOp::Sub(_) => Some("checked_sub"),
BinOp::Mul(_) => Some("checked_mul"),
BinOp::Div(_) => Some("checked_div"),
BinOp::Rem(_) => Some("checked_rem"),
BinOp::Shl(_) => Some("checked_shl"),
BinOp::Shr(_) => Some("checked_shr"),
_ => abort!(op, "unsupported binary expr"),
};
method_name
.map(|method_name| {
let method_name = Ident::new(method_name, op.span());
quote! {
{
match (#left, #right) {
(Some(left), Some(right)) => left.#method_name(right),
_ => None
}
}
}
})
.unwrap_or_else(|| {
quote! {
match (#left, #right) {
(Some(left), Some(right)) => left #op right,
_ => None
}
}
})
}

View File

@ -0,0 +1 @@
pub mod checked;

View File

@ -0,0 +1,38 @@
use checked_math::checked_math;
fn f() -> u8 {
3u8
}
struct S {}
impl S {
fn m(&self) -> u8 {
2u8
}
}
fn main() {
let num = 2u8;
let result = checked_math!{ (num + (2u8 / 10)) * 5 };
assert!(result == Some(10));
let result = checked_math!{ ((num.pow(20) << 20) + 255) + 2u8 * 2u8 };
assert!(result == None);
let result = checked_math!{ -std::i8::MIN };
assert!(result == None);
let result = checked_math!{ 12u8 + 6u8 / 3 };
assert!(result == Some(14));
let result = checked_math!{ 12u8 + 6u8 / f() };
assert!(result == Some(14));
let result = checked_math!{ 12u8 + 6u8 / num };
assert!(result == Some(15));
let s = S{};
let result = checked_math!{ 12u8 + s.m() };
assert!(result == Some(14));
}

View File

@ -0,0 +1,5 @@
#[test]
fn tests() {
let t = trybuild::TestCases::new();
t.pass("tests/01-success.rs");
}

View File

@ -31,6 +31,7 @@ serde = "^1.0"
solana-program = "1.9.5"
static_assertions = "1.1"
#serum_dex = { version = "0.4.0", git = "https://github.com/blockworks-foundation/serum-dex.git", default-features=false, features = ["no-entrypoint", "program"] }
checked_math = { path = "../../lib/checked_math" }
[dev-dependencies]

View File

@ -7,6 +7,8 @@ pub enum MangoError {
#[msg("")]
SomeError,
#[msg("")]
MathError,
#[msg("")]
UnexpectedOracle,
#[msg("")]
UnknownOracleType,

View File

@ -61,7 +61,7 @@ pub fn deposit(ctx: Context<Deposit>, amount: u64) -> Result<()> {
// Update the bank and position
let position_is_active = {
let mut bank = ctx.accounts.bank.load_mut()?;
bank.deposit(position, amount)
bank.deposit(position, amount)?
};
// Transfer the actual tokens

View File

@ -176,9 +176,9 @@ fn adjust_for_post_cpi_amounts(
// user has either withdrawn or deposited
if *pre_cpi_amount > vault.amount {
bank.withdraw(&mut position, pre_cpi_amount - vault.amount);
bank.withdraw(&mut position, pre_cpi_amount - vault.amount)?;
} else {
bank.deposit(&mut position, vault.amount - pre_cpi_amount);
bank.deposit(&mut position, vault.amount - pre_cpi_amount)?;
}
}
}

View File

@ -85,7 +85,7 @@ pub fn withdraw(ctx: Context<Withdraw>, amount: u64, allow_borrow: bool) -> Resu
);
// Update the bank and position
let position_is_active = bank.withdraw(position, amount);
let position_is_active = bank.withdraw(position, amount)?;
// Transfer the actual tokens
let group_seeds = group_seeds!(group);

View File

@ -2,6 +2,7 @@ use anchor_lang::prelude::*;
use fixed::types::I80F48;
use super::{IndexedPosition, TokenIndex};
use crate::util::checked_math as cm;
#[account(zero_copy)]
pub struct Bank {
@ -44,80 +45,82 @@ impl Bank {
}
/// Returns whether the position is active
pub fn deposit(&mut self, position: &mut IndexedPosition, native_amount: u64) -> bool {
pub fn deposit(&mut self, position: &mut IndexedPosition, native_amount: u64) -> Result<bool> {
let mut native_amount = I80F48::from_num(native_amount);
let native_position = position.native(self);
if native_position.is_negative() {
let new_native_position = native_position + native_amount;
let new_native_position = cm!(native_position + native_amount);
if new_native_position.is_negative() {
// pay back borrows only, leaving a negative position
let indexed_change = native_amount / self.borrow_index + I80F48::DELTA;
self.indexed_total_borrows -= indexed_change;
position.indexed_value += indexed_change;
return true;
let indexed_change = cm!(native_amount / self.borrow_index + I80F48::DELTA);
self.indexed_total_borrows = cm!(self.indexed_total_borrows - indexed_change);
position.indexed_value = cm!(position.indexed_value + indexed_change);
return Ok(true);
} else if new_native_position < I80F48::ONE {
// if there's less than one token deposited, zero the position
self.dust += new_native_position;
self.indexed_total_borrows += position.indexed_value;
self.dust = cm!(self.dust + new_native_position);
self.indexed_total_borrows =
cm!(self.indexed_total_borrows + position.indexed_value);
position.indexed_value = I80F48::ZERO;
return false;
return Ok(false);
}
// pay back all borrows
self.indexed_total_borrows += position.indexed_value; // position.value is negative
self.indexed_total_borrows = cm!(self.indexed_total_borrows + position.indexed_value); // position.value is negative
position.indexed_value = I80F48::ZERO;
// deposit the rest
native_amount += native_position;
native_amount = cm!(native_amount + native_position);
}
// add to deposits
// Adding DELTA to amount/index helps because (amount/index)*index <= amount, but
// we want to ensure that users can withdraw the same amount they have deposited, so
// (amount/index + delta)*index >= amount is a better guarantee.
let indexed_change = native_amount / self.deposit_index + I80F48::DELTA;
self.indexed_total_deposits += indexed_change;
position.indexed_value += indexed_change;
let indexed_change = cm!(native_amount / self.deposit_index + I80F48::DELTA);
self.indexed_total_deposits = cm!(self.indexed_total_deposits + indexed_change);
position.indexed_value = cm!(position.indexed_value + indexed_change);
true
Ok(true)
}
/// Returns whether the position is active
pub fn withdraw(&mut self, position: &mut IndexedPosition, native_amount: u64) -> bool {
pub fn withdraw(&mut self, position: &mut IndexedPosition, native_amount: u64) -> Result<bool> {
let mut native_amount = I80F48::from_num(native_amount);
let native_position = position.native(self);
if native_position.is_positive() {
let new_native_position = native_position - native_amount;
let new_native_position = cm!(native_position - native_amount);
if !new_native_position.is_negative() {
// withdraw deposits only
if new_native_position < I80F48::ONE {
// zero the account collecting the leftovers in `dust`
self.dust += new_native_position;
self.indexed_total_deposits -= position.indexed_value;
self.dust = cm!(self.dust + new_native_position);
self.indexed_total_deposits =
cm!(self.indexed_total_deposits - position.indexed_value);
position.indexed_value = I80F48::ZERO;
return false;
return Ok(false);
} else {
// withdraw some deposits leaving >1 native token
let indexed_change = native_amount / self.deposit_index;
self.indexed_total_deposits -= indexed_change;
position.indexed_value -= indexed_change;
return true;
let indexed_change = cm!(native_amount / self.deposit_index);
self.indexed_total_deposits = cm!(self.indexed_total_deposits - indexed_change);
position.indexed_value = cm!(position.indexed_value - indexed_change);
return Ok(true);
}
}
// withdraw all deposits
self.indexed_total_deposits -= position.indexed_value;
self.indexed_total_deposits = cm!(self.indexed_total_deposits - position.indexed_value);
position.indexed_value = I80F48::ZERO;
// borrow the rest
native_amount = -new_native_position;
}
// add to borrows
let indexed_change = native_amount / self.borrow_index;
self.indexed_total_borrows += indexed_change;
position.indexed_value -= indexed_change;
let indexed_change = cm!(native_amount / self.borrow_index);
self.indexed_total_borrows = cm!(self.indexed_total_borrows + indexed_change);
position.indexed_value = cm!(position.indexed_value - indexed_change);
true
Ok(true)
}
}

View File

@ -5,6 +5,7 @@ use pyth_client::load_price;
use crate::error::MangoError;
use crate::state::{determine_oracle_type, Bank, MangoAccount, OracleType, StubOracle};
use crate::util;
use crate::util::checked_math as cm;
pub fn compute_health(account: &MangoAccount, ais: &[AccountInfo]) -> Result<I80F48> {
let active_len = account.indexed_positions.iter_active().count();
@ -58,15 +59,16 @@ fn compute_health_detail(
}
};
let native_basis = position.native(&bank) * price;
let native_position = position.native(&bank);
let native_basis = cm!(native_position * price);
if native_basis.is_positive() {
assets += bank.init_asset_weight * native_basis;
assets = cm!(assets + bank.init_asset_weight * native_basis);
} else {
liabilities -= bank.init_liab_weight * native_basis;
liabilities = cm!(liabilities - bank.init_liab_weight * native_basis);
}
}
// TODO: Serum open orders
Ok(assets - liabilities)
Ok(cm!(assets - liabilities))
}

View File

@ -6,5 +6,12 @@ macro_rules! zip {
zip!($($y), +))
)
}
pub(crate) use zip;
#[macro_export]
macro_rules! checked_math {
($x: expr) => {
checked_math::checked_math!($x).ok_or(error!(crate::error::MangoError::MathError))?
};
}
pub(crate) use checked_math;