diff --git a/src/accountant.rs b/src/accountant.rs index 5f3b4076d..b031d4389 100644 --- a/src/accountant.rs +++ b/src/accountant.rs @@ -21,6 +21,7 @@ const MAX_ENTRY_IDS: usize = 1024 * 4; #[derive(Debug, PartialEq, Eq)] pub enum AccountingError { + AccountNotFound, InsufficientFunds, InvalidTransferSignature, } @@ -107,7 +108,17 @@ impl Accountant { /// Deduct tokens from the 'from' address the account has sufficient /// funds and isn't a duplicate. pub fn process_verified_transaction_debits(&self, tr: &Transaction) -> Result<()> { - if self.get_balance(&tr.from).unwrap_or(0) < tr.data.tokens { + let bals = self.balances.read().unwrap(); + + // Hold a write lock before the condition check, so that a debit can't occur + // between checking the balance and the withdraw. + let option = bals.get(&tr.from); + if option.is_none() { + return Err(AccountingError::AccountNotFound); + } + let mut bal = option.unwrap().write().unwrap(); + + if *bal < tr.data.tokens { return Err(AccountingError::InsufficientFunds); } @@ -115,8 +126,6 @@ impl Accountant { return Err(AccountingError::InvalidTransferSignature); } - let bals = self.balances.read().unwrap(); - let mut bal = bals[&tr.from].write().unwrap(); *bal -= tr.data.tokens; Ok(()) @@ -257,6 +266,16 @@ mod tests { assert_eq!(acc.get_balance(&bob_pubkey).unwrap(), 1_500); } + #[test] + fn test_account_not_found() { + let mint = Mint::new(1); + let acc = Accountant::new(&mint); + assert_eq!( + acc.transfer(1, &KeyPair::new(), mint.pubkey(), mint.last_id()), + Err(AccountingError::AccountNotFound) + ); + } + #[test] fn test_invalid_transfer() { let alice = Mint::new(11_000);