diff --git a/src/banking_stage.rs b/src/banking_stage.rs index 55a40a24e7..f094b1745b 100644 --- a/src/banking_stage.rs +++ b/src/banking_stage.rs @@ -6,7 +6,7 @@ use crate::entry::Entry; use crate::leader_confirmation_service::LeaderConfirmationService; use crate::packet::Packets; use crate::packet::SharedPackets; -use crate::poh_recorder::{PohRecorder, PohRecorderError}; +use crate::poh_recorder::{PohRecorder, PohRecorderError, WorkingBank}; use crate::poh_service::{PohService, PohServiceConfig}; use crate::result::{Error, Result}; use crate::service::Service; @@ -20,7 +20,7 @@ use solana_sdk::pubkey::Pubkey; use solana_sdk::timing::{self, duration_as_us, MAX_ENTRY_IDS}; use solana_sdk::transaction::Transaction; use std::sync::atomic::AtomicBool; -use std::sync::mpsc::{channel, Receiver, RecvTimeoutError, Sender}; +use std::sync::mpsc::{channel, Receiver, RecvTimeoutError}; use std::sync::{Arc, Mutex}; use std::thread::{self, Builder, JoinHandle}; use std::time::Duration; @@ -52,19 +52,26 @@ impl BankingStage { ) -> (Self, Receiver>) { let (entry_sender, entry_receiver) = channel(); let shared_verified_receiver = Arc::new(Mutex::new(verified_receiver)); - let poh_recorder = PohRecorder::new(bank.tick_height(), *last_entry_id, max_tick_height); + let working_bank = WorkingBank { + bank: bank.clone(), + sender: entry_sender, + min_tick_height: bank.tick_height(), + max_tick_height, + }; + + let poh_recorder = PohRecorder::new(bank.tick_height(), *last_entry_id); // Single thread to generate entries from many banks. // This thread talks to poh_service and broadcasts the entries once they have been recorded. // Once an entry has been recorded, its last_id is registered with the bank. let poh_exit = Arc::new(AtomicBool::new(false)); - let poh_service = PohService::new( - bank.clone(), - entry_sender.clone(), - poh_recorder.clone(), - config, - poh_exit.clone(), - ); + + let (poh_service, leader_sender) = + PohService::new(poh_recorder.clone(), config, poh_exit.clone()); + + leader_sender + .send(working_bank.clone()) + .expect("failed to send leader to poh_service"); // Single thread to compute confirmation let leader_confirmation_service = @@ -76,7 +83,7 @@ impl BankingStage { let thread_bank = bank.clone(); let thread_verified_receiver = shared_verified_receiver.clone(); let thread_poh_recorder = poh_recorder.clone(); - let thread_sender = entry_sender.clone(); + let thread_leader = working_bank.clone(); Builder::new() .name("solana-banking-stage-tx".to_string()) .spawn(move || { @@ -86,7 +93,7 @@ impl BankingStage { &thread_bank, &thread_verified_receiver, &thread_poh_recorder, - &thread_sender, + &thread_leader, ) { Err(Error::RecvTimeoutError(RecvTimeoutError::Timeout)) => (), Ok(more_unprocessed_packets) => { @@ -129,7 +136,7 @@ impl BankingStage { txs: &[Transaction], results: &[bank::Result<()>], poh: &PohRecorder, - entry_sender: &Sender>, + working_bank: &WorkingBank, ) -> Result<()> { let processed_transactions: Vec<_> = results .iter() @@ -151,7 +158,7 @@ impl BankingStage { if !processed_transactions.is_empty() { let hash = Transaction::hash(&processed_transactions); // record and unlock will unlock all the successfull transactions - poh.record(hash, processed_transactions, entry_sender)?; + poh.record(hash, processed_transactions, working_bank)?; } Ok(()) } @@ -160,7 +167,7 @@ impl BankingStage { bank: &Bank, txs: &[Transaction], poh: &PohRecorder, - entry_sender: &Sender>, + working_bank: &WorkingBank, ) -> Result<()> { let now = Instant::now(); // Once accounts are locked, other threads cannot encode transactions that will modify the @@ -179,7 +186,7 @@ impl BankingStage { let record_time = { let now = Instant::now(); - Self::record_transactions(txs, &results, poh, entry_sender)?; + Self::record_transactions(txs, &results, poh, working_bank)?; now.elapsed() }; @@ -213,7 +220,7 @@ impl BankingStage { bank: &Arc, transactions: &[Transaction], poh: &PohRecorder, - entry_sender: &Sender>, + working_bank: &WorkingBank, ) -> Result<(usize)> { let mut chunk_start = 0; while chunk_start != transactions.len() { @@ -223,7 +230,7 @@ impl BankingStage { bank, &transactions[chunk_start..chunk_end], poh, - entry_sender, + working_bank, ); if let Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) = result { break; @@ -239,7 +246,7 @@ impl BankingStage { bank: &Arc, verified_receiver: &Arc>>, poh: &PohRecorder, - entry_sender: &Sender>, + working_bank: &WorkingBank, ) -> Result { let recv_start = Instant::now(); let mms = verified_receiver @@ -291,7 +298,7 @@ impl BankingStage { debug!("verified transactions {}", verified_transactions.len()); let processed = - Self::process_transactions(bank, &verified_transactions, poh, entry_sender)?; + Self::process_transactions(bank, &verified_transactions, poh, working_bank)?; if processed < verified_transactions.len() { bank_shutdown = true; // Collect any unprocessed transactions in this batch for forwarding @@ -596,7 +603,14 @@ mod tests { let (genesis_block, mint_keypair) = GenesisBlock::new(10_000); let bank = Arc::new(Bank::new(&genesis_block)); let (entry_sender, entry_receiver) = channel(); - let poh_recorder = PohRecorder::new(bank.tick_height(), bank.last_id(), std::u64::MAX); + let working_bank = WorkingBank { + bank: bank.clone(), + sender: entry_sender, + min_tick_height: bank.tick_height(), + max_tick_height: std::u64::MAX, + }; + + let poh_recorder = PohRecorder::new(bank.tick_height(), bank.last_id()); let pubkey = Keypair::new().pubkey(); let transactions = vec![ @@ -605,7 +619,7 @@ mod tests { ]; let mut results = vec![Ok(()), Ok(())]; - BankingStage::record_transactions(&transactions, &results, &poh_recorder, &entry_sender) + BankingStage::record_transactions(&transactions, &results, &poh_recorder, &working_bank) .unwrap(); let entries = entry_receiver.recv().unwrap(); assert_eq!(entries[0].transactions.len(), transactions.len()); @@ -615,14 +629,14 @@ mod tests { 1, ProgramError::ResultWithNegativeTokens, )); - BankingStage::record_transactions(&transactions, &results, &poh_recorder, &entry_sender) + BankingStage::record_transactions(&transactions, &results, &poh_recorder, &working_bank) .unwrap(); let entries = entry_receiver.recv().unwrap(); assert_eq!(entries[0].transactions.len(), transactions.len()); // Other BankErrors should not be recorded results[0] = Err(BankError::AccountNotFound); - BankingStage::record_transactions(&transactions, &results, &poh_recorder, &entry_sender) + BankingStage::record_transactions(&transactions, &results, &poh_recorder, &working_bank) .unwrap(); let entries = entry_receiver.recv().unwrap(); assert_eq!(entries[0].transactions.len(), transactions.len() - 1); @@ -643,17 +657,22 @@ mod tests { )]; let (entry_sender, entry_receiver) = channel(); - let mut poh_recorder = - PohRecorder::new(bank.tick_height(), bank.last_id(), bank.tick_height() + 1); + let working_bank = WorkingBank { + bank: bank.clone(), + sender: entry_sender, + min_tick_height: bank.tick_height(), + max_tick_height: bank.tick_height() + 1, + }; + let poh_recorder = PohRecorder::new(bank.tick_height(), bank.last_id()); BankingStage::process_and_record_transactions( &bank, &transactions, &poh_recorder, - &entry_sender, + &working_bank, ) .unwrap(); - poh_recorder.tick(&bank, &entry_sender).unwrap(); + poh_recorder.tick(&working_bank).unwrap(); let mut need_tick = true; // read entries until I find mine, might be ticks... @@ -682,7 +701,7 @@ mod tests { &bank, &transactions, &poh_recorder, - &entry_sender + &working_bank ), Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) ); diff --git a/src/poh_recorder.rs b/src/poh_recorder.rs index 7f4a2cfc6f..7af64dab90 100644 --- a/src/poh_recorder.rs +++ b/src/poh_recorder.rs @@ -14,70 +14,92 @@ use std::sync::{Arc, Mutex}; pub enum PohRecorderError { InvalidCallingObject, MaxHeightReached, + MinHeightNotReached, +} + +#[derive(Clone)] +pub struct WorkingBank { + pub bank: Arc, + pub sender: Sender>, + pub min_tick_height: u64, + pub max_tick_height: u64, } #[derive(Clone)] pub struct PohRecorder { poh: Arc>, - max_tick_height: u64, + tick_cache: Arc>>, } impl PohRecorder { - pub fn max_tick_height(&self) -> u64 { - self.max_tick_height - } - - pub fn hash(&self) -> Result<()> { + pub fn hash(&self) { // TODO: amortize the cost of this lock by doing the loop in here for // some min amount of hashes let mut poh = self.poh.lock().unwrap(); - self.check_tick_height(&poh)?; - poh.hash(); + } + fn flush_cache(&self, working_bank: &WorkingBank) -> Result<()> { + let mut cache = vec![]; + std::mem::swap(&mut cache, &mut self.tick_cache.lock().unwrap()); + if !cache.is_empty() { + for t in &cache { + working_bank.bank.register_tick(&t.id); + } + working_bank.sender.send(cache)?; + } Ok(()) } - pub fn tick(&mut self, bank: &Arc, sender: &Sender>) -> Result<()> { + pub fn tick(&self, working_bank: &WorkingBank) -> Result<()> { // Register and send the entry out while holding the lock if the max PoH height // hasn't been reached. // This guarantees PoH order and Entry production and banks LastId queue is the same let mut poh = self.poh.lock().unwrap(); - self.check_tick_height(&poh)?; + Self::check_tick_height(&poh, working_bank).map_err(|e| { + let tick = Self::generate_tick(&mut poh); + self.tick_cache.lock().unwrap().push(tick); + e + })?; + ; + self.flush_cache(working_bank)?; - self.register_and_send_tick(&mut *poh, bank, sender) + Self::register_and_send_tick(&mut *poh, working_bank) } pub fn record( &self, mixin: Hash, txs: Vec, - sender: &Sender>, + working_bank: &WorkingBank, ) -> Result<()> { // Register and send the entry out while holding the lock. // This guarantees PoH order and Entry production and banks LastId queue is the same. let mut poh = self.poh.lock().unwrap(); - self.check_tick_height(&poh)?; + Self::check_tick_height(&poh, working_bank)?; + self.flush_cache(working_bank)?; - self.record_and_send_txs(&mut *poh, mixin, txs, sender) + Self::record_and_send_txs(&mut *poh, mixin, txs, working_bank) } /// A recorder to synchronize PoH with the following data structures /// * bank - the LastId's queue is updated on `tick` and `record` events /// * sender - the Entry channel that outputs to the ledger - pub fn new(tick_height: u64, last_entry_id: Hash, max_tick_height: u64) -> Self { + pub fn new(tick_height: u64, last_entry_id: Hash) -> Self { let poh = Arc::new(Mutex::new(Poh::new(last_entry_id, tick_height))); - PohRecorder { - poh, - max_tick_height, - } + let tick_cache = Arc::new(Mutex::new(vec![])); + PohRecorder { poh, tick_cache } } - fn check_tick_height(&self, poh: &Poh) -> Result<()> { - if poh.tick_height >= self.max_tick_height { + fn check_tick_height(poh: &Poh, working_bank: &WorkingBank) -> Result<()> { + if poh.tick_height < working_bank.min_tick_height { + Err(Error::PohRecorderError( + PohRecorderError::MinHeightNotReached, + )) + } else if poh.tick_height >= working_bank.max_tick_height { Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) } else { Ok(()) @@ -85,11 +107,10 @@ impl PohRecorder { } fn record_and_send_txs( - &self, poh: &mut Poh, mixin: Hash, txs: Vec, - sender: &Sender>, + working_bank: &WorkingBank, ) -> Result<()> { let entry = poh.record(mixin); assert!(!txs.is_empty(), "Entries without transactions are used to track real-time passing in the ledger and can only be generated with PohRecorder::tick function"); @@ -98,24 +119,23 @@ impl PohRecorder { id: entry.id, transactions: txs, }; - sender.send(vec![entry])?; + working_bank.sender.send(vec![entry])?; Ok(()) } - fn register_and_send_tick( - &self, - poh: &mut Poh, - bank: &Arc, - sender: &Sender>, - ) -> Result<()> { + fn generate_tick(poh: &mut Poh) -> Entry { let tick = poh.tick(); - let tick = Entry { + Entry { num_hashes: tick.num_hashes, id: tick.id, transactions: vec![], - }; - bank.register_tick(&tick.id); - sender.send(vec![tick])?; + } + } + + fn register_and_send_tick(poh: &mut Poh, working_bank: &WorkingBank) -> Result<()> { + let tick = Self::generate_tick(poh); + working_bank.bank.register_tick(&tick.id); + working_bank.sender.send(vec![tick])?; Ok(()) } } @@ -135,29 +155,97 @@ mod tests { let bank = Arc::new(Bank::new(&genesis_block)); let prev_id = bank.last_id(); let (entry_sender, entry_receiver) = channel(); - let mut poh_recorder = PohRecorder::new(0, prev_id, 2); + let poh_recorder = PohRecorder::new(0, prev_id); + + let working_bank = WorkingBank { + bank, + sender: entry_sender, + min_tick_height: 0, + max_tick_height: 2, + }; //send some data let h1 = hash(b"hello world!"); let tx = test_tx(); poh_recorder - .record(h1, vec![tx.clone()], &entry_sender) + .record(h1, vec![tx.clone()], &working_bank) .unwrap(); //get some events let _e = entry_receiver.recv().unwrap(); - poh_recorder.tick(&bank, &entry_sender).unwrap(); + poh_recorder.tick(&working_bank).unwrap(); let _e = entry_receiver.recv().unwrap(); - poh_recorder.tick(&bank, &entry_sender).unwrap(); + poh_recorder.tick(&working_bank).unwrap(); let _e = entry_receiver.recv().unwrap(); // max tick height reached - assert!(poh_recorder.tick(&bank, &entry_sender).is_err()); - assert!(poh_recorder.record(h1, vec![tx], &entry_sender).is_err()); + assert!(poh_recorder.tick(&working_bank).is_err()); + assert!(poh_recorder.record(h1, vec![tx], &working_bank).is_err()); //make sure it handles channel close correctly drop(entry_receiver); - assert!(poh_recorder.tick(&bank, &entry_sender).is_err()); + assert!(poh_recorder.tick(&working_bank).is_err()); + } + + #[test] + fn test_poh_recorder_tick_cache() { + let (genesis_block, _mint_keypair) = GenesisBlock::new(2); + let bank = Arc::new(Bank::new(&genesis_block)); + let prev_id = bank.last_id(); + let (entry_sender, entry_receiver) = channel(); + let poh_recorder = PohRecorder::new(0, prev_id); + + let working_bank = WorkingBank { + bank, + sender: entry_sender, + min_tick_height: 1, + max_tick_height: 2, + }; + + // tick should be cached + assert!(poh_recorder.tick(&working_bank).is_err()); + assert!(entry_receiver.try_recv().is_err()); + + // working_bank should be at the right height + poh_recorder.tick(&working_bank).unwrap(); + + let entries = entry_receiver.recv().unwrap(); + assert_eq!(entries.len(), 1); + let entries = entry_receiver.recv().unwrap(); + assert_eq!(entries.len(), 1); + } + + #[test] + fn test_poh_recorder_tick_cache_old_working_bank() { + let (genesis_block, _mint_keypair) = GenesisBlock::new(2); + let bank = Arc::new(Bank::new(&genesis_block)); + let prev_id = bank.last_id(); + let (entry_sender, entry_receiver) = channel(); + let poh_recorder = PohRecorder::new(0, prev_id); + + let working_bank = WorkingBank { + bank, + sender: entry_sender, + min_tick_height: 1, + max_tick_height: 1, + }; + + // tick should be cached + assert_matches!( + poh_recorder.tick(&working_bank), + Err(Error::PohRecorderError( + PohRecorderError::MinHeightNotReached + )) + ); + + // working_bank should be past MaxHeight + assert_matches!( + poh_recorder.tick(&working_bank), + Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) + ); + assert_eq!(poh_recorder.tick_cache.lock().unwrap().len(), 2); + + assert!(entry_receiver.try_recv().is_err()); } } diff --git a/src/poh_service.rs b/src/poh_service.rs index d09031c155..5b3178cbda 100644 --- a/src/poh_service.rs +++ b/src/poh_service.rs @@ -1,14 +1,12 @@ //! The `poh_service` module implements a service that records the passing of //! "ticks", a measure of time in the PoH stream -use crate::entry::Entry; -use crate::poh_recorder::PohRecorder; -use crate::result::Result; +use crate::poh_recorder::{PohRecorder, PohRecorderError, WorkingBank}; +use crate::result::{Error, Result}; use crate::service::Service; -use solana_runtime::bank::Bank; use solana_sdk::timing::NUM_TICKS_PER_SECOND; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc::Sender; +use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError}; use std::sync::Arc; use std::thread::{self, sleep, Builder, JoinHandle}; use std::time::Duration; @@ -46,55 +44,78 @@ impl PohService { } pub fn new( - bank: Arc, - sender: Sender>, poh_recorder: PohRecorder, config: PohServiceConfig, poh_exit: Arc, - ) -> Self { + ) -> (Self, Sender) { // PohService is a headless producer, so when it exits it should notify the banking stage. // Since channel are not used to talk between these threads an AtomicBool is used as a // signal. let poh_exit_ = poh_exit.clone(); + let (working_bank_sender, working_bank_receiver) = channel(); // Single thread to generate ticks let tick_producer = Builder::new() .name("solana-poh-service-tick_producer".to_string()) .spawn(move || { - let mut poh_recorder_ = poh_recorder; - let sender = sender.clone(); - let bank = bank.clone(); - let return_value = - Self::tick_producer(&bank, &sender, &mut poh_recorder_, config, &poh_exit_); + let mut poh_recorder = poh_recorder; + let working_bank_receiver = working_bank_receiver; + let return_value = Self::tick_producer( + &working_bank_receiver, + &mut poh_recorder, + config, + &poh_exit_, + ); poh_exit_.store(true, Ordering::Relaxed); return_value }) .unwrap(); - Self { - tick_producer, - poh_exit, - } + ( + Self { + tick_producer, + poh_exit, + }, + working_bank_sender, + ) } fn tick_producer( - bank: &Arc, - sender: &Sender>, + working_bank_receiver: &Receiver, poh: &mut PohRecorder, config: PohServiceConfig, poh_exit: &AtomicBool, ) -> Result<()> { + let mut working_bank = None; loop { + if working_bank.is_none() { + let result = working_bank_receiver.try_recv(); + working_bank = match result { + Err(TryRecvError::Empty) => None, + _ => Some(result?), + }; + } match config { PohServiceConfig::Tick(num) => { for _ in 1..num { - poh.hash()?; + poh.hash(); } } PohServiceConfig::Sleep(duration) => { sleep(duration); } } - poh.tick(&bank, sender)?; + let result = if let Some(ref current_leader) = working_bank { + poh.tick(current_leader) + } else { + Ok(()) + }; + match result { + Err(Error::PohRecorderError(PohRecorderError::MinHeightNotReached)) => (), + Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) => { + working_bank = None; + } + e => e?, + }; if poh_exit.load(Ordering::Relaxed) { return Ok(()); } @@ -114,9 +135,11 @@ impl Service for PohService { mod tests { use super::*; use crate::test_tx::test_tx; + use solana_runtime::bank::Bank; use solana_sdk::genesis_block::GenesisBlock; use solana_sdk::hash::hash; use std::sync::mpsc::channel; + use std::sync::mpsc::RecvError; #[test] fn test_poh_service() { @@ -124,12 +147,18 @@ mod tests { let bank = Arc::new(Bank::new(&genesis_block)); let prev_id = bank.last_id(); let (entry_sender, entry_receiver) = channel(); - let poh_recorder = PohRecorder::new(bank.tick_height(), prev_id, std::u64::MAX); + let poh_recorder = PohRecorder::new(bank.tick_height(), prev_id); let exit = Arc::new(AtomicBool::new(false)); + let working_bank = WorkingBank { + bank: bank.clone(), + sender: entry_sender, + min_tick_height: bank.tick_height(), + max_tick_height: std::u64::MAX, + }; let entry_producer: JoinHandle> = { let poh_recorder = poh_recorder.clone(); - let entry_sender = entry_sender.clone(); + let working_bank = working_bank.clone(); let exit = exit.clone(); Builder::new() @@ -139,7 +168,7 @@ mod tests { // send some data let h1 = hash(b"hello world!"); let tx = test_tx(); - poh_recorder.record(h1, vec![tx], &entry_sender).unwrap(); + poh_recorder.record(h1, vec![tx], &working_bank).unwrap(); if exit.load(Ordering::Relaxed) { break Ok(()); @@ -150,14 +179,16 @@ mod tests { }; const HASHES_PER_TICK: u64 = 2; - let poh_service = PohService::new( - bank, - entry_sender, - poh_recorder, + let (poh_service, working_bank_sender) = PohService::new( + poh_recorder.clone(), PohServiceConfig::Tick(HASHES_PER_TICK as usize), Arc::new(AtomicBool::new(false)), ); + working_bank_sender + .send(working_bank.clone()) + .expect("send"); + // get some events let mut hashes = 0; let mut need_tick = true; @@ -193,4 +224,44 @@ mod tests { let _ = entry_producer.join().unwrap(); } + #[test] + fn test_poh_service_drops_working_bank() { + let (genesis_block, _mint_keypair) = GenesisBlock::new(2); + let bank = Arc::new(Bank::new(&genesis_block)); + let prev_id = bank.last_id(); + let (entry_sender, entry_receiver) = channel(); + let poh_recorder = PohRecorder::new(bank.tick_height(), prev_id); + let exit = Arc::new(AtomicBool::new(false)); + let working_bank = WorkingBank { + bank: bank.clone(), + sender: entry_sender, + min_tick_height: bank.tick_height() + 3, + max_tick_height: bank.tick_height() + 5, + }; + + let (poh_service, working_bank_sender) = PohService::new( + poh_recorder.clone(), + PohServiceConfig::default(), + Arc::new(AtomicBool::new(false)), + ); + + working_bank_sender.send(working_bank).expect("send"); + + // all 5 ticks are expected + // First 3 ticks must be sent all at once, since bank shouldn't see them until + // the bank's min_tick_height(3) is reached. + let entries = entry_receiver.recv().unwrap(); + assert_eq!(entries.len(), 3); + let entries = entry_receiver.recv().unwrap(); + assert_eq!(entries.len(), 1); + let entries = entry_receiver.recv().unwrap(); + assert_eq!(entries.len(), 1); + + //WorkingBank should be dropped by the PohService thread as well + assert_eq!(entry_receiver.recv(), Err(RecvError)); + + exit.store(true, Ordering::Relaxed); + poh_service.exit(); + let _ = poh_service.join().unwrap(); + } } diff --git a/src/result.rs b/src/result.rs index cc60a01f65..4e817a3836 100644 --- a/src/result.rs +++ b/src/result.rs @@ -20,6 +20,7 @@ pub enum Error { JoinError(Box), RecvError(std::sync::mpsc::RecvError), RecvTimeoutError(std::sync::mpsc::RecvTimeoutError), + TryRecvError(std::sync::mpsc::TryRecvError), Serialize(std::boxed::Box), BankError(bank::BankError), ClusterInfoError(cluster_info::ClusterInfoError), @@ -46,6 +47,11 @@ impl std::convert::From for Error { Error::RecvError(e) } } +impl std::convert::From for Error { + fn from(e: std::sync::mpsc::TryRecvError) -> Error { + Error::TryRecvError(e) + } +} impl std::convert::From for Error { fn from(e: std::sync::mpsc::RecvTimeoutError) -> Error { Error::RecvTimeoutError(e)