diff --git a/src/accountant.rs b/src/accountant.rs index 9ed88571c..8b4b13931 100644 --- a/src/accountant.rs +++ b/src/accountant.rs @@ -12,6 +12,7 @@ use signature::{KeyPair, PublicKey, Signature}; use std::collections::hash_map::Entry::Occupied; use std::collections::{HashMap, HashSet}; use std::result; +use std::sync::RwLock; use transaction::Transaction; #[derive(Debug, PartialEq, Eq)] @@ -28,11 +29,11 @@ fn apply_payment(balances: &mut HashMap, payment: &Payment) { } pub struct Accountant { - balances: HashMap, - pending: HashMap, - signatures: HashSet, - time_sources: HashSet, - last_time: DateTime, + balances: RwLock>, + pending: RwLock>, + signatures: RwLock>, + time_sources: RwLock>, + last_time: RwLock>, } impl Accountant { @@ -41,11 +42,11 @@ impl Accountant { let mut balances = HashMap::new(); apply_payment(&mut balances, deposit); Accountant { - balances, - pending: HashMap::new(), - signatures: HashSet::new(), - time_sources: HashSet::new(), - last_time: Utc.timestamp(0, 0), + balances: RwLock::new(balances), + pending: RwLock::new(HashMap::new()), + signatures: RwLock::new(HashSet::new()), + time_sources: RwLock::new(HashSet::new()), + last_time: RwLock::new(Utc.timestamp(0, 0)), } } @@ -58,16 +59,16 @@ impl Accountant { Self::new_from_deposit(&deposit) } - fn reserve_signature(&mut self, sig: &Signature) -> bool { - if self.signatures.contains(sig) { + fn reserve_signature(&self, sig: &Signature) -> bool { + if self.signatures.read().unwrap().contains(sig) { return false; } - self.signatures.insert(*sig); + self.signatures.write().unwrap().insert(*sig); true } /// Process a Transaction that has already been verified. - pub fn process_verified_transaction(&mut self, tr: &Transaction) -> Result<()> { + pub fn process_verified_transaction(&self, tr: &Transaction) -> Result<()> { if self.get_balance(&tr.from).unwrap_or(0) < tr.tokens { return Err(AccountingError::InsufficientFunds); } @@ -76,28 +77,28 @@ impl Accountant { return Err(AccountingError::InvalidTransferSignature); } - if let Some(x) = self.balances.get_mut(&tr.from) { + if let Some(x) = self.balances.write().unwrap().get_mut(&tr.from) { *x -= tr.tokens; } let mut plan = tr.plan.clone(); - plan.apply_witness(&Witness::Timestamp(self.last_time)); + plan.apply_witness(&Witness::Timestamp(*self.last_time.read().unwrap())); if let Some(ref payment) = plan.final_payment() { - apply_payment(&mut self.balances, payment); + apply_payment(&mut self.balances.write().unwrap(), payment); } else { - self.pending.insert(tr.sig, plan); + self.pending.write().unwrap().insert(tr.sig, plan); } Ok(()) } /// Process a Witness Signature that has already been verified. - fn process_verified_sig(&mut self, from: PublicKey, tx_sig: Signature) -> Result<()> { - if let Occupied(mut e) = self.pending.entry(tx_sig) { + fn process_verified_sig(&self, from: PublicKey, tx_sig: Signature) -> Result<()> { + if let Occupied(mut e) = self.pending.write().unwrap().entry(tx_sig) { e.get_mut().apply_witness(&Witness::Signature(from)); if let Some(ref payment) = e.get().final_payment() { - apply_payment(&mut self.balances, payment); + apply_payment(&mut self.balances.write().unwrap(), payment); e.remove_entry(); } }; @@ -106,16 +107,16 @@ impl Accountant { } /// Process a Witness Timestamp that has already been verified. - fn process_verified_timestamp(&mut self, from: PublicKey, dt: DateTime) -> Result<()> { + fn process_verified_timestamp(&self, from: PublicKey, dt: DateTime) -> Result<()> { // If this is the first timestamp we've seen, it probably came from the genesis block, // so we'll trust it. - if self.last_time == Utc.timestamp(0, 0) { - self.time_sources.insert(from); + if *self.last_time.read().unwrap() == Utc.timestamp(0, 0) { + self.time_sources.write().unwrap().insert(from); } - if self.time_sources.contains(&from) { - if dt > self.last_time { - self.last_time = dt; + if self.time_sources.read().unwrap().contains(&from) { + if dt > *self.last_time.read().unwrap() { + *self.last_time.write().unwrap() = dt; } } else { return Ok(()); @@ -123,23 +124,27 @@ impl Accountant { // Check to see if any timelocked transactions can be completed. let mut completed = vec![]; - for (key, plan) in &mut self.pending { - plan.apply_witness(&Witness::Timestamp(self.last_time)); + + // Hold 'pending' write lock until the end of this function. Otherwise another thread can + // double-spend if it enters before the modified plan is removed from 'pending'. + let mut pending = self.pending.write().unwrap(); + for (key, plan) in pending.iter_mut() { + plan.apply_witness(&Witness::Timestamp(*self.last_time.read().unwrap())); if let Some(ref payment) = plan.final_payment() { - apply_payment(&mut self.balances, payment); + apply_payment(&mut self.balances.write().unwrap(), payment); completed.push(key.clone()); } } for key in completed { - self.pending.remove(&key); + pending.remove(&key); } Ok(()) } /// Process an Transaction or Witness that has already been verified. - pub fn process_verified_event(&mut self, event: &Event) -> Result<()> { + pub fn process_verified_event(&self, event: &Event) -> Result<()> { match *event { Event::Transaction(ref tr) => self.process_verified_transaction(tr), Event::Signature { from, tx_sig, .. } => self.process_verified_sig(from, tx_sig), @@ -150,7 +155,7 @@ impl Accountant { /// Create, sign, and process a Transaction from `keypair` to `to` of /// `n` tokens where `last_id` is the last Entry ID observed by the client. pub fn transfer( - &mut self, + &self, n: i64, keypair: &KeyPair, to: PublicKey, @@ -165,7 +170,7 @@ impl Accountant { /// to `to` of `n` tokens on `dt` where `last_id` is the last Entry ID /// observed by the client. pub fn transfer_on_date( - &mut self, + &self, n: i64, keypair: &KeyPair, to: PublicKey, @@ -178,7 +183,7 @@ impl Accountant { } pub fn get_balance(&self, pubkey: &PublicKey) -> Option { - self.balances.get(pubkey).cloned() + self.balances.read().unwrap().get(pubkey).cloned() } } @@ -191,7 +196,7 @@ mod tests { fn test_accountant() { let alice = Mint::new(10_000); let bob_pubkey = KeyPair::new().pubkey(); - let mut acc = Accountant::new(&alice); + let acc = Accountant::new(&alice); acc.transfer(1_000, &alice.keypair(), bob_pubkey, alice.last_id()) .unwrap(); assert_eq!(acc.get_balance(&bob_pubkey).unwrap(), 1_000); @@ -204,7 +209,7 @@ mod tests { #[test] fn test_invalid_transfer() { let alice = Mint::new(11_000); - let mut acc = Accountant::new(&alice); + let acc = Accountant::new(&alice); let bob_pubkey = KeyPair::new().pubkey(); acc.transfer(1_000, &alice.keypair(), bob_pubkey, alice.last_id()) .unwrap(); @@ -221,7 +226,7 @@ mod tests { #[test] fn test_transfer_to_newb() { let alice = Mint::new(10_000); - let mut acc = Accountant::new(&alice); + let acc = Accountant::new(&alice); let alice_keypair = alice.keypair(); let bob_pubkey = KeyPair::new().pubkey(); acc.transfer(500, &alice_keypair, bob_pubkey, alice.last_id()) @@ -232,7 +237,7 @@ mod tests { #[test] fn test_transfer_on_date() { let alice = Mint::new(1); - let mut acc = Accountant::new(&alice); + let acc = Accountant::new(&alice); let alice_keypair = alice.keypair(); let bob_pubkey = KeyPair::new().pubkey(); let dt = Utc::now(); @@ -258,7 +263,7 @@ mod tests { #[test] fn test_transfer_after_date() { let alice = Mint::new(1); - let mut acc = Accountant::new(&alice); + let acc = Accountant::new(&alice); let alice_keypair = alice.keypair(); let bob_pubkey = KeyPair::new().pubkey(); let dt = Utc::now(); @@ -275,7 +280,7 @@ mod tests { #[test] fn test_cancel_transfer() { let alice = Mint::new(1); - let mut acc = Accountant::new(&alice); + let acc = Accountant::new(&alice); let alice_keypair = alice.keypair(); let bob_pubkey = KeyPair::new().pubkey(); let dt = Utc::now(); @@ -301,7 +306,7 @@ mod tests { #[test] fn test_duplicate_event_signature() { let alice = Mint::new(1); - let mut acc = Accountant::new(&alice); + let acc = Accountant::new(&alice); let sig = Signature::default(); assert!(acc.reserve_signature(&sig)); assert!(!acc.reserve_signature(&sig)); diff --git a/src/bin/testnode.rs b/src/bin/testnode.rs index 40047a9ff..bbce2d91e 100644 --- a/src/bin/testnode.rs +++ b/src/bin/testnode.rs @@ -32,7 +32,7 @@ fn main() { None }; - let mut acc = Accountant::new_from_deposit(&deposit.unwrap()); + let acc = Accountant::new_from_deposit(&deposit.unwrap()); let mut last_id = entry1.id; for entry in entries {