From 449fe4dc6d36ab15e2f88211b9d1aef149ebcda3 Mon Sep 17 00:00:00 2001 From: Christian Kamm Date: Fri, 11 Mar 2022 09:57:30 +0100 Subject: [PATCH] Add checked_math library for convenient overflow checking Instead of x.checked_add(y).ok_or(error!(MangoError::MathError))? we can write cm!(x + y) --- Cargo.lock | 32 ++++++ Cargo.toml | 3 +- lib/checked_math/Cargo.toml | 24 ++++ lib/checked_math/LICENSE | 21 ++++ lib/checked_math/README.md | 5 + lib/checked_math/src/lib.rs | 26 +++++ lib/checked_math/src/transform/checked.rs | 108 ++++++++++++++++++ lib/checked_math/src/transform/mod.rs | 1 + lib/checked_math/tests/01-success.rs | 38 ++++++ lib/checked_math/tests/progress.rs | 5 + programs/mango-v4/Cargo.toml | 1 + programs/mango-v4/src/error.rs | 2 + programs/mango-v4/src/instructions/deposit.rs | 2 +- .../mango-v4/src/instructions/margin_trade.rs | 4 +- .../mango-v4/src/instructions/withdraw.rs | 2 +- programs/mango-v4/src/state/bank.rs | 61 +++++----- programs/mango-v4/src/state/health.rs | 10 +- programs/mango-v4/src/util.rs | 9 +- 18 files changed, 315 insertions(+), 39 deletions(-) create mode 100644 lib/checked_math/Cargo.toml create mode 100644 lib/checked_math/LICENSE create mode 100644 lib/checked_math/README.md create mode 100644 lib/checked_math/src/lib.rs create mode 100644 lib/checked_math/src/transform/checked.rs create mode 100644 lib/checked_math/src/transform/mod.rs create mode 100644 lib/checked_math/tests/01-success.rs create mode 100644 lib/checked_math/tests/progress.rs diff --git a/Cargo.lock b/Cargo.lock index ed750a6e9..9fcbee44b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -515,6 +515,17 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "checked_math" +version = "0.1.0" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn", + "trybuild", +] + [[package]] name = "chrono" version = "0.4.19" @@ -1158,6 +1169,12 @@ version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" +[[package]] +name = "glob" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" + [[package]] name = "goblin" version = "0.4.3" @@ -1536,6 +1553,7 @@ dependencies = [ "base64 0.13.0", "bincode", "bytemuck", + "checked_math", "env_logger", "fixed", "fixed-macro", @@ -3343,6 +3361,20 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "trybuild" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d60539445867cdd9680b2bfe2d0428f1814b7d5c9652f09d8d3eae9d19308db" +dependencies = [ + "glob", + "once_cell", + "serde", + "serde_json", + "termcolor", + "toml", +] + [[package]] name = "typenum" version = "1.15.0" diff --git a/Cargo.toml b/Cargo.toml index a60de986d..5a1ddc454 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] members = [ - "programs/*" + "programs/*", + "lib/*" ] diff --git a/lib/checked_math/Cargo.toml b/lib/checked_math/Cargo.toml new file mode 100644 index 000000000..64e595df0 --- /dev/null +++ b/lib/checked_math/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "checked_math" +description = "Proc macros for changing the checking behavior of math expressions" +version = "0.1.0" +authors = ["Ryan Levick", "Christian Kamm "] +license = "MIT" +edition = "2021" +autotests = false + +[lib] +proc-macro = true + +[[test]] +name = "tests" +path = "tests/progress.rs" + +[dev-dependencies] +trybuild = "1.0" + +[dependencies] +syn = { version = "1.0.86", features = ["full", "extra-traits"] } +quote = "1.0.15" +proc-macro2 = "1.0.36" +proc-macro-error = "1.0.4" diff --git a/lib/checked_math/LICENSE b/lib/checked_math/LICENSE new file mode 100644 index 000000000..3582bf043 --- /dev/null +++ b/lib/checked_math/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Ryan Levick + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/checked_math/README.md b/lib/checked_math/README.md new file mode 100644 index 000000000..efac2b274 --- /dev/null +++ b/lib/checked_math/README.md @@ -0,0 +1,5 @@ +# Source + +This is a modified version of `overflow` from https://github.com/rylev/overflow/ +originally by Ryan Levick. See LICENSE. + diff --git a/lib/checked_math/src/lib.rs b/lib/checked_math/src/lib.rs new file mode 100644 index 000000000..e5cd65b42 --- /dev/null +++ b/lib/checked_math/src/lib.rs @@ -0,0 +1,26 @@ +extern crate proc_macro; + +mod transform; + +use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; +use syn::parse_macro_input; + +/// Produces a semantically equivalent expression as the one provided +/// except that each math call is substituted with the equivalent version +/// of the `checked` API. +/// +/// Examples: +/// - `checked_math!{ 1 }` will become `Some(1)` +/// - `checked_math!{ a + b }` will become `a.checked_add(b)` +/// +/// The macro is intened to be used for arithmetic expressions only and +/// significantly restricts the available syntax. +#[proc_macro] +#[proc_macro_error] +pub fn checked_math(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as syn::Expr); + let expanded = transform::checked::transform_expr(input); + + TokenStream::from(expanded) +} diff --git a/lib/checked_math/src/transform/checked.rs b/lib/checked_math/src/transform/checked.rs new file mode 100644 index 000000000..d5d3cf4a5 --- /dev/null +++ b/lib/checked_math/src/transform/checked.rs @@ -0,0 +1,108 @@ +use proc_macro_error::abort; +use quote::quote; +use syn::{spanned::Spanned, BinOp, Expr, ExprBinary, ExprUnary, Ident, Lit, UnOp}; + +pub fn transform_expr(mut expr: Expr) -> proc_macro2::TokenStream { + match expr { + Expr::Unary(unary) => transform_unary(unary), + Expr::Binary(binary) => transform_binary(binary), + Expr::MethodCall(ref mut mc) => { + if mc.method == "pow" { + mc.method = syn::Ident::new("checked_pow", mc.method.span()); + quote! { #mc } + } else if mc.method == "abs" { + mc.method = syn::Ident::new("checked_abs", mc.method.span()); + quote! { #mc } + } else if mc.args.is_empty() { + quote! { Some(#mc) } + } else { + abort!(mc, "method calls with arguments are not supported"); + } + } + Expr::Call(ref mut c) => { + if c.args.is_empty() { + quote! { Some(#c) } + } else { + abort!(c, "calls with arguments are not supported"); + } + } + Expr::Paren(p) => { + let expr = transform_expr(*p.expr); + quote! { + (#expr) + } + } + Expr::Group(g) => { + let expr = transform_expr(*g.expr); + quote! { + (#expr) + } + } + Expr::Lit(lit) => match lit.lit { + Lit::Int(_) | Lit::Float(_) => quote! { Some(#lit) }, + _ => abort!(lit, "unsupported literal"), + }, + Expr::Path(_) | Expr::Field(_) => { + quote! { Some(#expr) } + } + _ => { + abort!(expr, "unsupported expr {:?}", expr); + } + } +} + +fn transform_unary(unary: ExprUnary) -> proc_macro2::TokenStream { + let expr = transform_expr(*unary.expr); + let op = unary.op; + match op { + UnOp::Neg(_) => { + quote! { + { + match #expr { + Some(e) => e.checked_neg(), + None => None + } + } + } + } + UnOp::Deref(_) => quote! { #expr }, + UnOp::Not(_) => abort!(expr, "unsupported unary expr"), + } +} + +fn transform_binary(binary: ExprBinary) -> proc_macro2::TokenStream { + let left = transform_expr(*binary.left); + let right = transform_expr(*binary.right); + let op = binary.op; + let method_name = match op { + BinOp::Add(_) => Some("checked_add"), + BinOp::Sub(_) => Some("checked_sub"), + BinOp::Mul(_) => Some("checked_mul"), + BinOp::Div(_) => Some("checked_div"), + BinOp::Rem(_) => Some("checked_rem"), + BinOp::Shl(_) => Some("checked_shl"), + BinOp::Shr(_) => Some("checked_shr"), + _ => abort!(op, "unsupported binary expr"), + }; + method_name + .map(|method_name| { + let method_name = Ident::new(method_name, op.span()); + quote! { + { + match (#left, #right) { + (Some(left), Some(right)) => left.#method_name(right), + _ => None + } + + } + } + }) + .unwrap_or_else(|| { + quote! { + match (#left, #right) { + (Some(left), Some(right)) => left #op right, + _ => None + } + } + }) +} diff --git a/lib/checked_math/src/transform/mod.rs b/lib/checked_math/src/transform/mod.rs new file mode 100644 index 000000000..561a0c291 --- /dev/null +++ b/lib/checked_math/src/transform/mod.rs @@ -0,0 +1 @@ +pub mod checked; diff --git a/lib/checked_math/tests/01-success.rs b/lib/checked_math/tests/01-success.rs new file mode 100644 index 000000000..815a59385 --- /dev/null +++ b/lib/checked_math/tests/01-success.rs @@ -0,0 +1,38 @@ +use checked_math::checked_math; + +fn f() -> u8 { + 3u8 +} + +struct S {} +impl S { + fn m(&self) -> u8 { + 2u8 + } +} + +fn main() { + let num = 2u8; + + let result = checked_math!{ (num + (2u8 / 10)) * 5 }; + assert!(result == Some(10)); + + let result = checked_math!{ ((num.pow(20) << 20) + 255) + 2u8 * 2u8 }; + assert!(result == None); + + let result = checked_math!{ -std::i8::MIN }; + assert!(result == None); + + let result = checked_math!{ 12u8 + 6u8 / 3 }; + assert!(result == Some(14)); + + let result = checked_math!{ 12u8 + 6u8 / f() }; + assert!(result == Some(14)); + + let result = checked_math!{ 12u8 + 6u8 / num }; + assert!(result == Some(15)); + + let s = S{}; + let result = checked_math!{ 12u8 + s.m() }; + assert!(result == Some(14)); +} diff --git a/lib/checked_math/tests/progress.rs b/lib/checked_math/tests/progress.rs new file mode 100644 index 000000000..2a89a9735 --- /dev/null +++ b/lib/checked_math/tests/progress.rs @@ -0,0 +1,5 @@ +#[test] +fn tests() { + let t = trybuild::TestCases::new(); + t.pass("tests/01-success.rs"); +} diff --git a/programs/mango-v4/Cargo.toml b/programs/mango-v4/Cargo.toml index 16e459765..16ae82550 100644 --- a/programs/mango-v4/Cargo.toml +++ b/programs/mango-v4/Cargo.toml @@ -31,6 +31,7 @@ serde = "^1.0" solana-program = "1.9.5" static_assertions = "1.1" #serum_dex = { version = "0.4.0", git = "https://github.com/blockworks-foundation/serum-dex.git", default-features=false, features = ["no-entrypoint", "program"] } +checked_math = { path = "../../lib/checked_math" } [dev-dependencies] diff --git a/programs/mango-v4/src/error.rs b/programs/mango-v4/src/error.rs index c8ebeaf6b..59af64ddd 100644 --- a/programs/mango-v4/src/error.rs +++ b/programs/mango-v4/src/error.rs @@ -7,6 +7,8 @@ pub enum MangoError { #[msg("")] SomeError, #[msg("")] + MathError, + #[msg("")] UnexpectedOracle, #[msg("")] UnknownOracleType, diff --git a/programs/mango-v4/src/instructions/deposit.rs b/programs/mango-v4/src/instructions/deposit.rs index 0e7214555..eb97fafd0 100644 --- a/programs/mango-v4/src/instructions/deposit.rs +++ b/programs/mango-v4/src/instructions/deposit.rs @@ -61,7 +61,7 @@ pub fn deposit(ctx: Context, amount: u64) -> Result<()> { // Update the bank and position let position_is_active = { let mut bank = ctx.accounts.bank.load_mut()?; - bank.deposit(position, amount) + bank.deposit(position, amount)? }; // Transfer the actual tokens diff --git a/programs/mango-v4/src/instructions/margin_trade.rs b/programs/mango-v4/src/instructions/margin_trade.rs index baea2ea89..0bad0d67a 100644 --- a/programs/mango-v4/src/instructions/margin_trade.rs +++ b/programs/mango-v4/src/instructions/margin_trade.rs @@ -176,9 +176,9 @@ fn adjust_for_post_cpi_amounts( // user has either withdrawn or deposited if *pre_cpi_amount > vault.amount { - bank.withdraw(&mut position, pre_cpi_amount - vault.amount); + bank.withdraw(&mut position, pre_cpi_amount - vault.amount)?; } else { - bank.deposit(&mut position, vault.amount - pre_cpi_amount); + bank.deposit(&mut position, vault.amount - pre_cpi_amount)?; } } } diff --git a/programs/mango-v4/src/instructions/withdraw.rs b/programs/mango-v4/src/instructions/withdraw.rs index 4e32059e1..dfcc46bc9 100644 --- a/programs/mango-v4/src/instructions/withdraw.rs +++ b/programs/mango-v4/src/instructions/withdraw.rs @@ -85,7 +85,7 @@ pub fn withdraw(ctx: Context, amount: u64, allow_borrow: bool) -> Resu ); // Update the bank and position - let position_is_active = bank.withdraw(position, amount); + let position_is_active = bank.withdraw(position, amount)?; // Transfer the actual tokens let group_seeds = group_seeds!(group); diff --git a/programs/mango-v4/src/state/bank.rs b/programs/mango-v4/src/state/bank.rs index 97052ed26..1b1880034 100644 --- a/programs/mango-v4/src/state/bank.rs +++ b/programs/mango-v4/src/state/bank.rs @@ -2,6 +2,7 @@ use anchor_lang::prelude::*; use fixed::types::I80F48; use super::{IndexedPosition, TokenIndex}; +use crate::util::checked_math as cm; #[account(zero_copy)] pub struct Bank { @@ -44,80 +45,82 @@ impl Bank { } /// Returns whether the position is active - pub fn deposit(&mut self, position: &mut IndexedPosition, native_amount: u64) -> bool { + pub fn deposit(&mut self, position: &mut IndexedPosition, native_amount: u64) -> Result { let mut native_amount = I80F48::from_num(native_amount); let native_position = position.native(self); if native_position.is_negative() { - let new_native_position = native_position + native_amount; + let new_native_position = cm!(native_position + native_amount); if new_native_position.is_negative() { // pay back borrows only, leaving a negative position - let indexed_change = native_amount / self.borrow_index + I80F48::DELTA; - self.indexed_total_borrows -= indexed_change; - position.indexed_value += indexed_change; - return true; + let indexed_change = cm!(native_amount / self.borrow_index + I80F48::DELTA); + self.indexed_total_borrows = cm!(self.indexed_total_borrows - indexed_change); + position.indexed_value = cm!(position.indexed_value + indexed_change); + return Ok(true); } else if new_native_position < I80F48::ONE { // if there's less than one token deposited, zero the position - self.dust += new_native_position; - self.indexed_total_borrows += position.indexed_value; + self.dust = cm!(self.dust + new_native_position); + self.indexed_total_borrows = + cm!(self.indexed_total_borrows + position.indexed_value); position.indexed_value = I80F48::ZERO; - return false; + return Ok(false); } // pay back all borrows - self.indexed_total_borrows += position.indexed_value; // position.value is negative + self.indexed_total_borrows = cm!(self.indexed_total_borrows + position.indexed_value); // position.value is negative position.indexed_value = I80F48::ZERO; // deposit the rest - native_amount += native_position; + native_amount = cm!(native_amount + native_position); } // add to deposits // Adding DELTA to amount/index helps because (amount/index)*index <= amount, but // we want to ensure that users can withdraw the same amount they have deposited, so // (amount/index + delta)*index >= amount is a better guarantee. - let indexed_change = native_amount / self.deposit_index + I80F48::DELTA; - self.indexed_total_deposits += indexed_change; - position.indexed_value += indexed_change; + let indexed_change = cm!(native_amount / self.deposit_index + I80F48::DELTA); + self.indexed_total_deposits = cm!(self.indexed_total_deposits + indexed_change); + position.indexed_value = cm!(position.indexed_value + indexed_change); - true + Ok(true) } /// Returns whether the position is active - pub fn withdraw(&mut self, position: &mut IndexedPosition, native_amount: u64) -> bool { + pub fn withdraw(&mut self, position: &mut IndexedPosition, native_amount: u64) -> Result { let mut native_amount = I80F48::from_num(native_amount); let native_position = position.native(self); if native_position.is_positive() { - let new_native_position = native_position - native_amount; + let new_native_position = cm!(native_position - native_amount); if !new_native_position.is_negative() { // withdraw deposits only if new_native_position < I80F48::ONE { // zero the account collecting the leftovers in `dust` - self.dust += new_native_position; - self.indexed_total_deposits -= position.indexed_value; + self.dust = cm!(self.dust + new_native_position); + self.indexed_total_deposits = + cm!(self.indexed_total_deposits - position.indexed_value); position.indexed_value = I80F48::ZERO; - return false; + return Ok(false); } else { // withdraw some deposits leaving >1 native token - let indexed_change = native_amount / self.deposit_index; - self.indexed_total_deposits -= indexed_change; - position.indexed_value -= indexed_change; - return true; + let indexed_change = cm!(native_amount / self.deposit_index); + self.indexed_total_deposits = cm!(self.indexed_total_deposits - indexed_change); + position.indexed_value = cm!(position.indexed_value - indexed_change); + return Ok(true); } } // withdraw all deposits - self.indexed_total_deposits -= position.indexed_value; + self.indexed_total_deposits = cm!(self.indexed_total_deposits - position.indexed_value); position.indexed_value = I80F48::ZERO; // borrow the rest native_amount = -new_native_position; } // add to borrows - let indexed_change = native_amount / self.borrow_index; - self.indexed_total_borrows += indexed_change; - position.indexed_value -= indexed_change; + let indexed_change = cm!(native_amount / self.borrow_index); + self.indexed_total_borrows = cm!(self.indexed_total_borrows + indexed_change); + position.indexed_value = cm!(position.indexed_value - indexed_change); - true + Ok(true) } } diff --git a/programs/mango-v4/src/state/health.rs b/programs/mango-v4/src/state/health.rs index 1dd5a51c1..266bc2ff8 100644 --- a/programs/mango-v4/src/state/health.rs +++ b/programs/mango-v4/src/state/health.rs @@ -5,6 +5,7 @@ use pyth_client::load_price; use crate::error::MangoError; use crate::state::{determine_oracle_type, Bank, MangoAccount, OracleType, StubOracle}; use crate::util; +use crate::util::checked_math as cm; pub fn compute_health(account: &MangoAccount, ais: &[AccountInfo]) -> Result { let active_len = account.indexed_positions.iter_active().count(); @@ -58,15 +59,16 @@ fn compute_health_detail( } }; - let native_basis = position.native(&bank) * price; + let native_position = position.native(&bank); + let native_basis = cm!(native_position * price); if native_basis.is_positive() { - assets += bank.init_asset_weight * native_basis; + assets = cm!(assets + bank.init_asset_weight * native_basis); } else { - liabilities -= bank.init_liab_weight * native_basis; + liabilities = cm!(liabilities - bank.init_liab_weight * native_basis); } } // TODO: Serum open orders - Ok(assets - liabilities) + Ok(cm!(assets - liabilities)) } diff --git a/programs/mango-v4/src/util.rs b/programs/mango-v4/src/util.rs index 3a61632b4..97890201c 100644 --- a/programs/mango-v4/src/util.rs +++ b/programs/mango-v4/src/util.rs @@ -6,5 +6,12 @@ macro_rules! zip { zip!($($y), +)) ) } - pub(crate) use zip; + +#[macro_export] +macro_rules! checked_math { + ($x: expr) => { + checked_math::checked_math!($x).ok_or(error!(crate::error::MangoError::MathError))? + }; +} +pub(crate) use checked_math;