diff --git a/gossip/src/cluster_info.rs b/gossip/src/cluster_info.rs index c93a10ce7..e1414afde 100644 --- a/gossip/src/cluster_info.rs +++ b/gossip/src/cluster_info.rs @@ -1690,7 +1690,7 @@ impl ClusterInfo { Some(root_bank.feature_set.clone()), ) } - None => (HashMap::new(), None), + None => (Arc::default(), None), }; let require_stake_for_gossip = self.require_stake_for_gossip(feature_set.as_deref(), &stakes); @@ -2485,7 +2485,7 @@ impl ClusterInfo { // feature does not roll back (if the feature happens to get enabled in // a minority fork). let (feature_set, stakes) = match bank_forks { - None => (None, HashMap::default()), + None => (None, Arc::default()), Some(bank_forks) => { let bank = bank_forks.read().unwrap().root_bank(); let feature_set = bank.feature_set.clone(); diff --git a/ledger/src/leader_schedule_utils.rs b/ledger/src/leader_schedule_utils.rs index c9d801bd3..3f17dcf5c 100644 --- a/ledger/src/leader_schedule_utils.rs +++ b/ledger/src/leader_schedule_utils.rs @@ -12,7 +12,10 @@ pub fn leader_schedule(epoch: Epoch, bank: &Bank) -> Option { bank.epoch_staked_nodes(epoch).map(|stakes| { let mut seed = [0u8; 32]; seed[0..8].copy_from_slice(&epoch.to_le_bytes()); - let mut stakes: Vec<_> = stakes.into_iter().collect(); + let mut stakes: Vec<_> = stakes + .iter() + .map(|(pubkey, stake)| (*pubkey, *stake)) + .collect(); sort_stakes(&mut stakes); LeaderSchedule::new( &stakes, @@ -88,7 +91,11 @@ mod tests { .genesis_config; let bank = Bank::new_for_tests(&genesis_config); - let pubkeys_and_stakes: Vec<_> = bank.staked_nodes().into_iter().collect(); + let pubkeys_and_stakes: Vec<_> = bank + .staked_nodes() + .iter() + .map(|(pubkey, stake)| (*pubkey, *stake)) + .collect(); let seed = [0u8; 32]; let leader_schedule = LeaderSchedule::new( &pubkeys_and_stakes, diff --git a/runtime/src/bank.rs b/runtime/src/bank.rs index fba999770..22881ff40 100644 --- a/runtime/src/bank.rs +++ b/runtime/src/bank.rs @@ -5137,7 +5137,7 @@ impl Bank { self.stakes.read().unwrap().stake_delegations().clone() } - pub fn staked_nodes(&self) -> HashMap { + pub fn staked_nodes(&self) -> Arc> { self.stakes.read().unwrap().staked_nodes() } @@ -5177,7 +5177,7 @@ impl Bank { &self.epoch_stakes } - pub fn epoch_staked_nodes(&self, epoch: Epoch) -> Option> { + pub fn epoch_staked_nodes(&self, epoch: Epoch) -> Option>> { Some(self.epoch_stakes.get(&epoch)?.stakes().staked_nodes()) } diff --git a/runtime/src/stakes.rs b/runtime/src/stakes.rs index 67f0fdfab..974a6997d 100644 --- a/runtime/src/stakes.rs +++ b/runtime/src/stakes.rs @@ -1,19 +1,21 @@ //! Stakes serve as a cache of stake and vote accounts to derive //! node stakes -use crate::vote_account::{ArcVoteAccount, VoteAccounts}; -use solana_sdk::{ - account::{AccountSharedData, ReadableAccount}, - clock::Epoch, - pubkey::Pubkey, - stake::{ - self, - state::{Delegation, StakeState}, +use { + crate::vote_account::{ArcVoteAccount, VoteAccounts}, + solana_sdk::{ + account::{AccountSharedData, ReadableAccount}, + clock::Epoch, + pubkey::Pubkey, + stake::{ + self, + state::{Delegation, StakeState}, + }, + sysvar::stake_history::StakeHistory, }, - sysvar::stake_history::StakeHistory, + solana_stake_program::stake_state, + solana_vote_program::vote_state::VoteState, + std::{borrow::Borrow, collections::HashMap, sync::Arc}, }; -use solana_stake_program::stake_state; -use solana_vote_program::vote_state::VoteState; -use std::{borrow::Borrow, collections::HashMap}; #[derive(Default, Clone, PartialEq, Debug, Deserialize, Serialize, AbiExample)] pub struct Stakes { @@ -217,7 +219,7 @@ impl Stakes { &self.stake_delegations } - pub fn staked_nodes(&self) -> HashMap { + pub fn staked_nodes(&self) -> Arc> { self.vote_accounts.staked_nodes() } diff --git a/runtime/src/vote_account.rs b/runtime/src/vote_account.rs index 38a79a6a2..deebc99d2 100644 --- a/runtime/src/vote_account.rs +++ b/runtime/src/vote_account.rs @@ -1,16 +1,19 @@ -use serde::de::{Deserialize, Deserializer}; -use serde::ser::{Serialize, Serializer}; -use solana_sdk::{ - account::Account, account::AccountSharedData, instruction::InstructionError, pubkey::Pubkey, -}; -use solana_vote_program::vote_state::VoteState; -use std::{ - borrow::Borrow, - cmp::Ordering, - collections::{hash_map::Entry, HashMap}, - iter::FromIterator, - ops::Deref, - sync::{Arc, Once, RwLock, RwLockReadGuard}, +use { + itertools::Itertools, + serde::de::{Deserialize, Deserializer}, + serde::ser::{Serialize, Serializer}, + solana_sdk::{ + account::Account, account::AccountSharedData, instruction::InstructionError, pubkey::Pubkey, + }, + solana_vote_program::vote_state::VoteState, + std::{ + borrow::Borrow, + cmp::Ordering, + collections::{hash_map::Entry, HashMap}, + iter::FromIterator, + ops::Deref, + sync::{Arc, Once, RwLock, RwLockReadGuard}, + }, }; // The value here does not matter. It will be overwritten @@ -31,10 +34,14 @@ pub struct VoteAccount { #[derive(Debug, AbiExample)] pub struct VoteAccounts { vote_accounts: HashMap, + // Inner Arc is meant to implement copy-on-write semantics as opposed to + // sharing mutations (hence RwLock> instead of Arc>). staked_nodes: RwLock< - HashMap< - Pubkey, // VoteAccount.vote_state.node_pubkey. - u64, // Total stake across all vote-accounts. + Arc< + HashMap< + Pubkey, // VoteAccount.vote_state.node_pubkey. + u64, // Total stake across all vote-accounts. + >, >, >, staked_nodes_once: Once, @@ -59,20 +66,19 @@ impl VoteAccount { } impl VoteAccounts { - pub fn staked_nodes(&self) -> HashMap { + pub fn staked_nodes(&self) -> Arc> { self.staked_nodes_once.call_once(|| { - let mut staked_nodes = HashMap::new(); - for (stake, vote_account) in - self.vote_accounts.values().filter(|(stake, _)| *stake != 0) - { - if let Some(node_pubkey) = vote_account.node_pubkey() { - staked_nodes - .entry(node_pubkey) - .and_modify(|s| *s += *stake) - .or_insert(*stake); - } - } - *self.staked_nodes.write().unwrap() = staked_nodes + let staked_nodes = self + .vote_accounts + .values() + .filter(|(stake, _)| *stake != 0) + .filter_map(|(stake, vote_account)| { + let node_pubkey = vote_account.node_pubkey()?; + Some((node_pubkey, stake)) + }) + .into_grouping_map() + .aggregate(|acc, _node_pubkey, stake| Some(acc.unwrap_or_default() + stake)); + *self.staked_nodes.write().unwrap() = Arc::new(staked_nodes) }); self.staked_nodes.read().unwrap().clone() } @@ -119,9 +125,9 @@ impl VoteAccounts { fn add_node_stake(&mut self, stake: u64, vote_account: &ArcVoteAccount) { if stake != 0 && self.staked_nodes_once.is_completed() { if let Some(node_pubkey) = vote_account.node_pubkey() { - self.staked_nodes - .write() - .unwrap() + let mut staked_nodes = self.staked_nodes.write().unwrap(); + let staked_nodes = Arc::make_mut(&mut staked_nodes); + staked_nodes .entry(node_pubkey) .and_modify(|s| *s += stake) .or_insert(stake); @@ -132,7 +138,9 @@ impl VoteAccounts { fn sub_node_stake(&mut self, stake: u64, vote_account: &ArcVoteAccount) { if stake != 0 && self.staked_nodes_once.is_completed() { if let Some(node_pubkey) = vote_account.node_pubkey() { - match self.staked_nodes.write().unwrap().entry(node_pubkey) { + let mut staked_nodes = self.staked_nodes.write().unwrap(); + let staked_nodes = Arc::make_mut(&mut staked_nodes); + match staked_nodes.entry(node_pubkey) { Entry::Vacant(_) => panic!("this should not happen!"), Entry::Occupied(mut entry) => match entry.get().cmp(&stake) { Ordering::Less => panic!("subtraction value exceeds node's stake"), @@ -474,7 +482,7 @@ mod tests { if (k + 1) % 128 == 0 { assert_eq!( staked_nodes(&accounts[..k + 1]), - vote_accounts.staked_nodes() + *vote_accounts.staked_nodes() ); } } @@ -484,7 +492,7 @@ mod tests { let (pubkey, (_, _)) = accounts.swap_remove(index); vote_accounts.remove(&pubkey); if (k + 1) % 32 == 0 { - assert_eq!(staked_nodes(&accounts), vote_accounts.staked_nodes()); + assert_eq!(staked_nodes(&accounts), *vote_accounts.staked_nodes()); } } // Modify the stakes for some of the accounts. @@ -499,7 +507,7 @@ mod tests { } *stake = new_stake; if (k + 1) % 128 == 0 { - assert_eq!(staked_nodes(&accounts), vote_accounts.staked_nodes()); + assert_eq!(staked_nodes(&accounts), *vote_accounts.staked_nodes()); } } // Remove everything. @@ -508,9 +516,41 @@ mod tests { let (pubkey, (_, _)) = accounts.swap_remove(index); vote_accounts.remove(&pubkey); if accounts.len() % 32 == 0 { - assert_eq!(staked_nodes(&accounts), vote_accounts.staked_nodes()); + assert_eq!(staked_nodes(&accounts), *vote_accounts.staked_nodes()); } } assert!(vote_accounts.staked_nodes.read().unwrap().is_empty()); } + + // Asserts that returned staked-nodes are copy-on-write references. + #[test] + fn test_staked_nodes_cow() { + let mut rng = rand::thread_rng(); + let mut accounts = new_rand_vote_accounts(&mut rng, 64); + // Add vote accounts. + let mut vote_accounts = VoteAccounts::default(); + for (pubkey, (stake, vote_account)) in (&mut accounts).take(1024) { + vote_accounts.insert(pubkey, (stake, vote_account)); + } + let staked_nodes = vote_accounts.staked_nodes(); + let (pubkey, (more_stake, vote_account)) = + accounts.find(|(_, (stake, _))| *stake != 0).unwrap(); + let node_pubkey = vote_account.node_pubkey().unwrap(); + vote_accounts.insert(pubkey, (more_stake, vote_account)); + assert_ne!(staked_nodes, vote_accounts.staked_nodes()); + assert_eq!( + vote_accounts.staked_nodes()[&node_pubkey], + more_stake + staked_nodes.get(&node_pubkey).copied().unwrap_or_default() + ); + for (pubkey, stake) in vote_accounts.staked_nodes().iter() { + if *pubkey != node_pubkey { + assert_eq!(*stake, staked_nodes[pubkey]); + } else { + assert_eq!( + *stake, + more_stake + staked_nodes.get(pubkey).copied().unwrap_or_default() + ); + } + } + } }