diff --git a/core/src/banking_stage/thread_aware_account_locks.rs b/core/src/banking_stage/thread_aware_account_locks.rs index 1a740a0a5f..d17cc25556 100644 --- a/core/src/banking_stage/thread_aware_account_locks.rs +++ b/core/src/banking_stage/thread_aware_account_locks.rs @@ -18,6 +18,16 @@ type LockCount = u32; #[derive(Copy, Clone, PartialEq, Eq)] pub(crate) struct ThreadSet(u64); +struct AccountWriteLocks { + thread_id: ThreadId, + lock_count: LockCount, +} + +struct AccountReadLocks { + thread_set: ThreadSet, + lock_counts: [LockCount; MAX_THREADS], +} + /// Thread-aware account locks which allows for scheduling on threads /// that already hold locks on the account. This is useful for allowing /// queued transactions to be scheduled on a thread while the transaction @@ -27,11 +37,11 @@ pub(crate) struct ThreadAwareAccountLocks { num_threads: usize, // 0..MAX_THREADS /// Write locks - only one thread can hold a write lock at a time. /// Contains how many write locks are held by the thread. - write_locks: HashMap, + write_locks: HashMap, /// Read locks - multiple threads can hold a read lock at a time. /// Contains thread-set for easily checking which threads are scheduled. /// Contains how many read locks are held by each thread. - read_locks: HashMap, + read_locks: HashMap, } impl ThreadAwareAccountLocks { @@ -134,9 +144,10 @@ impl ThreadAwareAccountLocks { fn schedulable_threads(&self, account: &Pubkey) -> ThreadSet { match (self.write_locks.get(account), self.read_locks.get(account)) { (None, None) => ThreadSet::any(self.num_threads), - (None, Some((thread_set, _))) => { + (None, Some(read_locks)) => { if WRITE { - thread_set + read_locks + .thread_set .only_one_contained() .map(ThreadSet::only) .unwrap_or_else(ThreadSet::none) @@ -144,10 +155,13 @@ impl ThreadAwareAccountLocks { ThreadSet::any(self.num_threads) } } - (Some((thread_id, _)), None) => ThreadSet::only(*thread_id), - (Some((thread_id, _)), Some((thread_set, _))) => { - assert_eq!(thread_set.only_one_contained(), Some(*thread_id)); - *thread_set + (Some(write_locks), None) => ThreadSet::only(write_locks.thread_id), + (Some(write_locks), Some(read_locks)) => { + assert_eq!( + read_locks.thread_set.only_one_contained(), + Some(write_locks.thread_id) + ); + read_locks.thread_set } } } @@ -177,7 +191,10 @@ impl ThreadAwareAccountLocks { fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { match self.write_locks.entry(*account) { Entry::Occupied(mut entry) => { - let (lock_thread_id, lock_count) = entry.get_mut(); + let AccountWriteLocks { + thread_id: lock_thread_id, + lock_count, + } = entry.get_mut(); assert_eq!( *lock_thread_id, thread_id, "outstanding write lock must be on same thread" @@ -186,14 +203,17 @@ impl ThreadAwareAccountLocks { *lock_count += 1; } Entry::Vacant(entry) => { - entry.insert((thread_id, 1)); + entry.insert(AccountWriteLocks { + thread_id, + lock_count: 1, + }); } } // Check for outstanding read-locks - if let Some(&(read_thread_set, _)) = self.read_locks.get(account) { + if let Some(read_locks) = self.read_locks.get(account) { assert_eq!( - read_thread_set, + read_locks.thread_set, ThreadSet::only(thread_id), "outstanding read lock must be on same thread" ); @@ -205,7 +225,10 @@ impl ThreadAwareAccountLocks { fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { match self.write_locks.entry(*account) { Entry::Occupied(mut entry) => { - let (lock_thread_id, lock_count) = entry.get_mut(); + let AccountWriteLocks { + thread_id: lock_thread_id, + lock_count, + } = entry.get_mut(); assert_eq!( *lock_thread_id, thread_id, "outstanding write lock must be on same thread" @@ -226,21 +249,27 @@ impl ThreadAwareAccountLocks { fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { match self.read_locks.entry(*account) { Entry::Occupied(mut entry) => { - let (thread_set, lock_counts) = entry.get_mut(); + let AccountReadLocks { + thread_set, + lock_counts, + } = entry.get_mut(); thread_set.insert(thread_id); lock_counts[thread_id] += 1; } Entry::Vacant(entry) => { let mut lock_counts = [0; MAX_THREADS]; lock_counts[thread_id] = 1; - entry.insert((ThreadSet::only(thread_id), lock_counts)); + entry.insert(AccountReadLocks { + thread_set: ThreadSet::only(thread_id), + lock_counts, + }); } } // Check for outstanding write-locks - if let Some((write_thread_id, _)) = self.write_locks.get(account) { + if let Some(write_locks) = self.write_locks.get(account) { assert_eq!( - write_thread_id, &thread_id, + write_locks.thread_id, thread_id, "outstanding write lock must be on same thread" ); } @@ -251,7 +280,10 @@ impl ThreadAwareAccountLocks { fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { match self.read_locks.entry(*account) { Entry::Occupied(mut entry) => { - let (thread_set, lock_counts) = entry.get_mut(); + let AccountReadLocks { + thread_set, + lock_counts, + } = entry.get_mut(); assert!( thread_set.contains(thread_id), "outstanding read lock must be on same thread"