diff --git a/core/src/banking_stage/scheduler_messages.rs b/core/src/banking_stage/scheduler_messages.rs index f5b6ae9258..172087e2cf 100644 --- a/core/src/banking_stage/scheduler_messages.rs +++ b/core/src/banking_stage/scheduler_messages.rs @@ -1,7 +1,7 @@ use { super::immutable_deserialized_packet::ImmutableDeserializedPacket, solana_sdk::{clock::Slot, transaction::SanitizedTransaction}, - std::sync::Arc, + std::{fmt::Display, sync::Arc}, }; /// A unique identifier for a transaction batch. @@ -14,6 +14,12 @@ impl TransactionBatchId { } } +impl Display for TransactionBatchId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + /// A unique identifier for a transaction. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] pub struct TransactionId(u64); @@ -24,6 +30,12 @@ impl TransactionId { } } +impl Display for TransactionId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + /// Message: [Scheduler -> Worker] /// Transactions to be consumed (i.e. executed, recorded, and committed) pub struct ConsumeWork { diff --git a/core/src/banking_stage/transaction_scheduler/in_flight_tracker.rs b/core/src/banking_stage/transaction_scheduler/in_flight_tracker.rs new file mode 100644 index 0000000000..243f14c669 --- /dev/null +++ b/core/src/banking_stage/transaction_scheduler/in_flight_tracker.rs @@ -0,0 +1,123 @@ +use { + super::{batch_id_generator::BatchIdGenerator, thread_aware_account_locks::ThreadId}, + crate::banking_stage::scheduler_messages::TransactionBatchId, + std::collections::HashMap, +}; + +/// Tracks the number of transactions that are in flight for each thread. +pub struct InFlightTracker { + num_in_flight_per_thread: Vec, + cus_in_flight_per_thread: Vec, + batches: HashMap, + batch_id_generator: BatchIdGenerator, +} + +struct BatchEntry { + thread_id: ThreadId, + num_transactions: usize, + total_cus: u64, +} + +impl InFlightTracker { + pub fn new(num_threads: usize) -> Self { + Self { + num_in_flight_per_thread: vec![0; num_threads], + cus_in_flight_per_thread: vec![0; num_threads], + batches: HashMap::new(), + batch_id_generator: BatchIdGenerator::default(), + } + } + + /// Returns the number of transactions that are in flight for each thread. + pub fn num_in_flight_per_thread(&self) -> &[usize] { + &self.num_in_flight_per_thread + } + + /// Returns the number of cus that are in flight for each thread. + pub fn cus_in_flight_per_thread(&self) -> &[u64] { + &self.cus_in_flight_per_thread + } + + /// Tracks number of transactions and CUs in-flight for the `thread_id`. + /// Returns a `TransactionBatchId` that can be used to stop tracking the batch + /// when it is complete. + pub fn track_batch( + &mut self, + num_transactions: usize, + total_cus: u64, + thread_id: ThreadId, + ) -> TransactionBatchId { + let batch_id = self.batch_id_generator.next(); + self.num_in_flight_per_thread[thread_id] += num_transactions; + self.cus_in_flight_per_thread[thread_id] += total_cus; + self.batches.insert( + batch_id, + BatchEntry { + thread_id, + num_transactions, + total_cus, + }, + ); + + batch_id + } + + /// Stop tracking the batch with given `batch_id`. + /// Removes the number of transactions for the scheduled thread. + /// Returns the thread id that the batch was scheduled on. + /// + /// # Panics + /// Panics if the batch id does not exist in the tracker. + pub fn complete_batch(&mut self, batch_id: TransactionBatchId) -> ThreadId { + let Some(BatchEntry { + thread_id, + num_transactions, + total_cus, + }) = self.batches.remove(&batch_id) + else { + panic!("batch id {batch_id} is not being tracked"); + }; + self.num_in_flight_per_thread[thread_id] -= num_transactions; + self.cus_in_flight_per_thread[thread_id] -= total_cus; + + thread_id + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic(expected = "is not being tracked")] + fn test_in_flight_tracker_untracked_batch() { + let mut in_flight_tracker = InFlightTracker::new(2); + in_flight_tracker.complete_batch(TransactionBatchId::new(5)); + } + + #[test] + fn test_in_flight_tracker() { + let mut in_flight_tracker = InFlightTracker::new(2); + + // Add a batch with 2 transactions, 10 kCUs to thread 0. + let batch_id_0 = in_flight_tracker.track_batch(2, 10_000, 0); + assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 0]); + assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[10_000, 0]); + + // Add a batch with 1 transaction, 15 kCUs to thread 1. + let batch_id_1 = in_flight_tracker.track_batch(1, 15_000, 1); + assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 1]); + assert_eq!( + in_flight_tracker.cus_in_flight_per_thread(), + &[10_000, 15_000] + ); + + in_flight_tracker.complete_batch(batch_id_0); + assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 1]); + assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 15_000]); + + in_flight_tracker.complete_batch(batch_id_1); + assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 0]); + assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 0]); + } +} diff --git a/core/src/banking_stage/transaction_scheduler/mod.rs b/core/src/banking_stage/transaction_scheduler/mod.rs index 065efba0d8..96639f4ad4 100644 --- a/core/src/banking_stage/transaction_scheduler/mod.rs +++ b/core/src/banking_stage/transaction_scheduler/mod.rs @@ -7,8 +7,8 @@ mod transaction_state; #[allow(dead_code)] mod transaction_state_container; +mod batch_id_generator; +#[allow(dead_code)] +mod in_flight_tracker; #[allow(dead_code)] mod transaction_id_generator; - -#[allow(dead_code)] -mod batch_id_generator;