diff --git a/src/poh_recorder.rs b/src/poh_recorder.rs index 18982ae8e0..b948cd942f 100644 --- a/src/poh_recorder.rs +++ b/src/poh_recorder.rs @@ -29,12 +29,12 @@ impl PohRecorder { // TODO: amortize the cost of this lock by doing the loop in here for // some min amount of hashes let mut poh = self.poh.lock().unwrap(); - if self.is_max_tick_height_reached(&poh) { - Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) - } else { - poh.hash(); - Ok(()) - } + + self.check_tick_height(&poh)?; + + poh.hash(); + + Ok(()) } pub fn tick(&mut self) -> Result<()> { @@ -42,24 +42,20 @@ impl PohRecorder { // hasn't been reached. // This guarantees PoH order and Entry production and banks LastId queue is the same let mut poh = self.poh.lock().unwrap(); - if self.is_max_tick_height_reached(&poh) { - Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) - } else { - self.register_and_send_tick(&mut *poh)?; - Ok(()) - } + + self.check_tick_height(&poh)?; + + self.register_and_send_tick(&mut *poh) } pub fn record(&self, mixin: Hash, txs: Vec) -> Result<()> { // Register and send the entry out while holding the lock. // This guarantees PoH order and Entry production and banks LastId queue is the same. let mut poh = self.poh.lock().unwrap(); - if self.is_max_tick_height_reached(&poh) { - Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) - } else { - self.record_and_send_txs(&mut *poh, mixin, txs)?; - Ok(()) - } + + self.check_tick_height(&poh)?; + + self.record_and_send_txs(&mut *poh, mixin, txs) } /// A recorder to synchronize PoH with the following data structures @@ -80,11 +76,12 @@ impl PohRecorder { } } - fn is_max_tick_height_reached(&self, poh: &Poh) -> bool { - if let Some(max_tick_height) = self.max_tick_height { - poh.tick_height >= max_tick_height - } else { - false + fn check_tick_height(&self, poh: &Poh) -> Result<()> { + match self.max_tick_height { + Some(max_tick_height) if poh.tick_height >= max_tick_height => { + Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) + } + _ => Ok(()), } } @@ -132,12 +129,12 @@ mod tests { let bank = Arc::new(Bank::new(&mint)); let prev_id = bank.last_id(); let (entry_sender, entry_receiver) = channel(); - let mut poh_recorder = PohRecorder::new(bank, entry_sender, prev_id, None); + let mut poh_recorder = PohRecorder::new(bank, entry_sender, prev_id, Some(3)); //send some data let h1 = hash(b"hello world!"); let tx = test_tx(); - assert!(poh_recorder.record(h1, vec![tx]).is_ok()); + assert!(poh_recorder.record(h1, vec![tx.clone()]).is_ok()); //get some events let e = entry_receiver.recv().unwrap(); assert_eq!(e[0].tick_height, 1); @@ -150,6 +147,10 @@ mod tests { let e = entry_receiver.recv().unwrap(); assert_eq!(e[0].tick_height, 2); + // max tick height reached + assert!(poh_recorder.tick().is_err()); + assert!(poh_recorder.record(h1, vec![tx]).is_err()); + //make sure it handles channel close correctly drop(entry_receiver); assert!(poh_recorder.tick().is_err());