diff --git a/core/src/cluster_info_vote_listener.rs b/core/src/cluster_info_vote_listener.rs index b8045e22dd..1c3005d20f 100644 --- a/core/src/cluster_info_vote_listener.rs +++ b/core/src/cluster_info_vote_listener.rs @@ -29,7 +29,6 @@ use solana_runtime::{ bank_forks::BankForks, commitment::VOTE_THRESHOLD_SIZE, epoch_stakes::{EpochAuthorizedVoters, EpochStakes}, - stakes::Stakes, vote_sender_types::{ReplayVoteReceiver, ReplayedVote}, }; use solana_sdk::{ @@ -601,7 +600,7 @@ impl ClusterInfoVoteListener { // The last vote slot, which is the greatest slot in the stack // of votes in a vote transaction, qualifies for optimistic confirmation. if slot == last_vote_slot { - let vote_accounts = Stakes::vote_accounts(epoch_stakes.stakes()); + let vote_accounts = epoch_stakes.stakes().vote_accounts(); let stake = vote_accounts .get(vote_pubkey) .map(|(stake, _)| *stake) diff --git a/core/src/commitment_service.rs b/core/src/commitment_service.rs index dac4939401..f5225c34ea 100644 --- a/core/src/commitment_service.rs +++ b/core/src/commitment_service.rs @@ -183,8 +183,8 @@ impl AggregateCommitmentService { let mut commitment = HashMap::new(); let mut rooted_stake: Vec<(Slot, u64)> = Vec::new(); - for (_, (lamports, account)) in bank.vote_accounts().into_iter() { - if lamports == 0 { + for (lamports, account) in bank.vote_accounts().values() { + if *lamports == 0 { continue; } if let Ok(vote_state) = account.vote_state().as_ref() { @@ -193,7 +193,7 @@ impl AggregateCommitmentService { &mut rooted_stake, vote_state, ancestors, - lamports, + *lamports, ); } } diff --git a/core/src/consensus.rs b/core/src/consensus.rs index 000d58f5e6..b5868ffc92 100644 --- a/core/src/consensus.rs +++ b/core/src/consensus.rs @@ -198,17 +198,14 @@ impl Tower { Self::new(node_pubkey, vote_account, root, &heaviest_bank) } - pub(crate) fn collect_vote_lockouts( + pub(crate) fn collect_vote_lockouts( vote_account_pubkey: &Pubkey, bank_slot: Slot, - vote_accounts: F, + vote_accounts: &HashMap, ancestors: &HashMap>, get_frozen_hash: impl Fn(Slot) -> Option, latest_validator_votes_for_frozen_banks: &mut LatestValidatorVotesForFrozenBanks, - ) -> ComputedBankState - where - F: IntoIterator, - { + ) -> ComputedBankState { let mut vote_slots = HashSet::new(); let mut voted_stakes = HashMap::new(); let mut total_stake = 0; @@ -217,7 +214,8 @@ impl Tower { // keyed by end of the range let mut lockout_intervals = LockoutIntervals::new(); let mut my_latest_landed_vote = None; - for (key, (voted_stake, account)) in vote_accounts { + for (&key, (voted_stake, account)) in vote_accounts.iter() { + let voted_stake = *voted_stake; if voted_stake == 0 { continue; } @@ -1270,56 +1268,60 @@ pub fn reconcile_blockstore_roots_with_tower( #[cfg(test)] pub mod test { - use super::*; - use crate::{ - fork_choice::ForkChoice, heaviest_subtree_fork_choice::SlotHashKey, - replay_stage::HeaviestForkFailures, tower_storage::FileTowerStorage, - vote_simulator::VoteSimulator, + use { + super::*, + crate::{ + fork_choice::ForkChoice, heaviest_subtree_fork_choice::SlotHashKey, + replay_stage::HeaviestForkFailures, tower_storage::FileTowerStorage, + vote_simulator::VoteSimulator, + }, + itertools::Itertools, + solana_ledger::{blockstore::make_slot_entries, get_tmp_ledger_path}, + solana_runtime::bank::Bank, + solana_sdk::{ + account::{Account, AccountSharedData, ReadableAccount, WritableAccount}, + clock::Slot, + hash::Hash, + pubkey::Pubkey, + signature::Signer, + slot_history::SlotHistory, + }, + solana_vote_program::vote_state::{Vote, VoteStateVersions, MAX_LOCKOUT_HISTORY}, + std::{ + collections::HashMap, + fs::{remove_file, OpenOptions}, + io::{Read, Seek, SeekFrom, Write}, + path::PathBuf, + sync::Arc, + }, + tempfile::TempDir, + trees::tr, }; - use solana_ledger::{blockstore::make_slot_entries, get_tmp_ledger_path}; - use solana_runtime::bank::Bank; - use solana_sdk::{ - account::{Account, AccountSharedData, ReadableAccount, WritableAccount}, - clock::Slot, - hash::Hash, - pubkey::Pubkey, - signature::Signer, - slot_history::SlotHistory, - }; - use solana_vote_program::vote_state::{Vote, VoteStateVersions, MAX_LOCKOUT_HISTORY}; - use std::{ - collections::HashMap, - fs::{remove_file, OpenOptions}, - io::{Read, Seek, SeekFrom, Write}, - path::PathBuf, - sync::Arc, - }; - use tempfile::TempDir; - use trees::tr; - fn gen_stakes(stake_votes: &[(u64, &[u64])]) -> Vec<(Pubkey, (u64, VoteAccount))> { - let mut stakes = vec![]; - for (lamports, votes) in stake_votes { - let mut account = AccountSharedData::from(Account { - data: vec![0; VoteState::size_of()], - lamports: *lamports, - ..Account::default() - }); - let mut vote_state = VoteState::default(); - for slot in *votes { - vote_state.process_slot_vote_unchecked(*slot); - } - VoteState::serialize( - &VoteStateVersions::new_current(vote_state), - &mut account.data_as_mut_slice(), - ) - .expect("serialize state"); - stakes.push(( - solana_sdk::pubkey::new_rand(), - (*lamports, VoteAccount::from(account)), - )); - } - stakes + fn gen_stakes(stake_votes: &[(u64, &[u64])]) -> HashMap { + stake_votes + .iter() + .map(|(lamports, votes)| { + let mut account = AccountSharedData::from(Account { + data: vec![0; VoteState::size_of()], + lamports: *lamports, + ..Account::default() + }); + let mut vote_state = VoteState::default(); + for slot in *votes { + vote_state.process_slot_vote_unchecked(*slot); + } + VoteState::serialize( + &VoteStateVersions::new_current(vote_state), + &mut account.data_as_mut_slice(), + ) + .expect("serialize state"); + ( + solana_sdk::pubkey::new_rand(), + (*lamports, VoteAccount::from(account)), + ) + }) + .collect() } #[test] @@ -1964,10 +1966,10 @@ pub mod test { #[test] fn test_collect_vote_lockouts_sums() { //two accounts voting for slot 0 with 1 token staked - let mut accounts = gen_stakes(&[(1, &[0]), (1, &[0])]); - accounts.sort_by_key(|(pk, _)| *pk); + let accounts = gen_stakes(&[(1, &[0]), (1, &[0])]); let account_latest_votes: Vec<(Pubkey, SlotHashKey)> = accounts .iter() + .sorted_by_key(|(pk, _)| *pk) .map(|(pubkey, _)| (*pubkey, (0, Hash::default()))) .collect(); @@ -1984,7 +1986,7 @@ pub mod test { } = Tower::collect_vote_lockouts( &Pubkey::default(), 1, - accounts.into_iter(), + &accounts, &ancestors, |_| Some(Hash::default()), &mut latest_validator_votes_for_frozen_banks, @@ -2004,10 +2006,10 @@ pub mod test { fn test_collect_vote_lockouts_root() { let votes: Vec = (0..MAX_LOCKOUT_HISTORY as u64).collect(); //two accounts voting for slots 0..MAX_LOCKOUT_HISTORY with 1 token staked - let mut accounts = gen_stakes(&[(1, &votes), (1, &votes)]); - accounts.sort_by_key(|(pk, _)| *pk); + let accounts = gen_stakes(&[(1, &votes), (1, &votes)]); let account_latest_votes: Vec<(Pubkey, SlotHashKey)> = accounts .iter() + .sorted_by_key(|(pk, _)| *pk) .map(|(pubkey, _)| { ( *pubkey, @@ -2044,7 +2046,7 @@ pub mod test { } = Tower::collect_vote_lockouts( &Pubkey::default(), MAX_LOCKOUT_HISTORY as u64, - accounts.into_iter(), + &accounts, &ancestors, |_| Some(Hash::default()), &mut latest_validator_votes_for_frozen_banks, @@ -2340,7 +2342,7 @@ pub mod test { } = Tower::collect_vote_lockouts( &Pubkey::default(), vote_to_evaluate, - accounts.clone().into_iter(), + &accounts, &ancestors, |_| None, &mut LatestValidatorVotesForFrozenBanks::default(), @@ -2358,7 +2360,7 @@ pub mod test { } = Tower::collect_vote_lockouts( &Pubkey::default(), vote_to_evaluate, - accounts.into_iter(), + &accounts, &ancestors, |_| None, &mut LatestValidatorVotesForFrozenBanks::default(), diff --git a/core/src/replay_stage.rs b/core/src/replay_stage.rs index d90352b5ae..72b3a08eea 100644 --- a/core/src/replay_stage.rs +++ b/core/src/replay_stage.rs @@ -2169,7 +2169,7 @@ impl ReplayStage { let computed_bank_state = Tower::collect_vote_lockouts( my_vote_pubkey, bank_slot, - bank.vote_accounts().into_iter(), + &bank.vote_accounts(), ancestors, |slot| progress.get_hash(slot), latest_validator_votes_for_frozen_banks, diff --git a/core/src/validator.rs b/core/src/validator.rs index 217d040fb1..4e2ed84656 100644 --- a/core/src/validator.rs +++ b/core/src/validator.rs @@ -1544,7 +1544,8 @@ fn get_stake_percent_in_gossip(bank: &Bank, cluster_info: &ClusterInfo, log: boo let my_shred_version = cluster_info.my_shred_version(); let my_id = cluster_info.id(); - for (_, (activated_stake, vote_account)) in bank.vote_accounts() { + for (activated_stake, vote_account) in bank.vote_accounts().values() { + let activated_stake = *activated_stake; total_activated_stake += activated_stake; if activated_stake == 0 { diff --git a/ledger-tool/src/main.rs b/ledger-tool/src/main.rs index cfb75a0847..4f491709e6 100644 --- a/ledger-tool/src/main.rs +++ b/ledger-tool/src/main.rs @@ -353,18 +353,18 @@ fn graph_forks(bank_forks: &BankForks, include_all_votes: bool) -> String { .iter() .map(|(_, (stake, _))| stake) .sum(); - for (_, (stake, vote_account)) in bank.vote_accounts() { + for (stake, vote_account) in bank.vote_accounts().values() { let vote_state = vote_account.vote_state(); let vote_state = vote_state.as_ref().unwrap_or(&default_vote_state); if let Some(last_vote) = vote_state.votes.iter().last() { let entry = last_votes.entry(vote_state.node_pubkey).or_insert(( last_vote.slot, vote_state.clone(), - stake, + *stake, total_stake, )); if entry.0 < last_vote.slot { - *entry = (last_vote.slot, vote_state.clone(), stake, total_stake); + *entry = (last_vote.slot, vote_state.clone(), *stake, total_stake); } } } @@ -394,7 +394,7 @@ fn graph_forks(bank_forks: &BankForks, include_all_votes: bool) -> String { let mut first = true; loop { - for (_, (_, vote_account)) in bank.vote_accounts() { + for (_, vote_account) in bank.vote_accounts().values() { let vote_state = vote_account.vote_state(); let vote_state = vote_state.as_ref().unwrap_or(&default_vote_state); if let Some(last_vote) = vote_state.votes.iter().last() { diff --git a/ledger/src/blockstore_processor.rs b/ledger/src/blockstore_processor.rs index bc634f9763..05b174761a 100644 --- a/ledger/src/blockstore_processor.rs +++ b/ledger/src/blockstore_processor.rs @@ -1114,7 +1114,7 @@ fn load_frozen_forks( supermajority_root_from_vote_accounts( bank.slot(), bank.total_epoch_stake(), - bank.vote_accounts(), + &bank.vote_accounts(), ).and_then(|supermajority_root| { if supermajority_root > *root { // If there's a cluster confirmed root greater than our last @@ -1223,18 +1223,15 @@ fn supermajority_root(roots: &[(Slot, u64)], total_epoch_stake: u64) -> Option( +fn supermajority_root_from_vote_accounts( bank_slot: Slot, total_epoch_stake: u64, - vote_accounts: I, -) -> Option -where - I: IntoIterator, -{ + vote_accounts: &HashMap, +) -> Option { let mut roots_stakes: Vec<(Slot, u64)> = vote_accounts - .into_iter() + .iter() .filter_map(|(key, (stake, account))| { - if stake == 0 { + if *stake == 0 { return None; } @@ -1246,7 +1243,7 @@ where ); None } - Ok(vote_state) => vote_state.root_slot.map(|root_slot| (root_slot, stake)), + Ok(vote_state) => Some((vote_state.root_slot?, *stake)), } }) .collect(); @@ -3591,7 +3588,7 @@ pub mod tests { #[allow(clippy::field_reassign_with_default)] fn test_supermajority_root_from_vote_accounts() { let convert_to_vote_accounts = - |roots_stakes: Vec<(Slot, u64)>| -> Vec<(Pubkey, (u64, VoteAccount))> { + |roots_stakes: Vec<(Slot, u64)>| -> HashMap { roots_stakes .into_iter() .map(|(root, stake)| { @@ -3609,7 +3606,7 @@ pub mod tests { (stake, VoteAccount::from(vote_account)), ) }) - .collect_vec() + .collect() }; let total_stake = 10; @@ -3617,22 +3614,19 @@ pub mod tests { // Supermajority root should be None assert!( - supermajority_root_from_vote_accounts(slot, total_stake, std::iter::empty()).is_none() + supermajority_root_from_vote_accounts(slot, total_stake, &HashMap::default()).is_none() ); // Supermajority root should be None let roots_stakes = vec![(8, 1), (3, 1), (4, 1), (8, 1)]; let accounts = convert_to_vote_accounts(roots_stakes); - assert!( - supermajority_root_from_vote_accounts(slot, total_stake, accounts.into_iter()) - .is_none() - ); + assert!(supermajority_root_from_vote_accounts(slot, total_stake, &accounts).is_none()); // Supermajority root should be 4, has 7/10 of the stake let roots_stakes = vec![(8, 1), (3, 1), (4, 1), (8, 5)]; let accounts = convert_to_vote_accounts(roots_stakes); assert_eq!( - supermajority_root_from_vote_accounts(slot, total_stake, accounts.into_iter()).unwrap(), + supermajority_root_from_vote_accounts(slot, total_stake, &accounts).unwrap(), 4 ); @@ -3640,7 +3634,7 @@ pub mod tests { let roots_stakes = vec![(8, 1), (3, 1), (4, 1), (8, 6)]; let accounts = convert_to_vote_accounts(roots_stakes); assert_eq!( - supermajority_root_from_vote_accounts(slot, total_stake, accounts.into_iter()).unwrap(), + supermajority_root_from_vote_accounts(slot, total_stake, &accounts).unwrap(), 8 ); } diff --git a/runtime/src/bank.rs b/runtime/src/bank.rs index 68cb6ae05e..d93703fe65 100644 --- a/runtime/src/bank.rs +++ b/runtime/src/bank.rs @@ -2252,25 +2252,21 @@ impl Bank { ) -> Option { let mut get_timestamp_estimate_time = Measure::start("get_timestamp_estimate"); let slots_per_epoch = self.epoch_schedule().slots_per_epoch; - let recent_timestamps = - self.vote_accounts() - .into_iter() - .filter_map(|(pubkey, (_, account))| { - let vote_state = account.vote_state(); - let vote_state = vote_state.as_ref().ok()?; - let slot_delta = self.slot().checked_sub(vote_state.last_timestamp.slot)?; - if slot_delta <= slots_per_epoch { - Some(( - pubkey, - ( - vote_state.last_timestamp.slot, - vote_state.last_timestamp.timestamp, - ), - )) - } else { - None - } - }); + let vote_accounts = self.vote_accounts(); + let recent_timestamps = vote_accounts.iter().filter_map(|(pubkey, (_, account))| { + let vote_state = account.vote_state(); + let vote_state = vote_state.as_ref().ok()?; + let slot_delta = self.slot().checked_sub(vote_state.last_timestamp.slot)?; + (slot_delta <= slots_per_epoch).then(|| { + ( + *pubkey, + ( + vote_state.last_timestamp.slot, + vote_state.last_timestamp.timestamp, + ), + ) + }) + }); let slot_duration = Duration::from_nanos(self.ns_per_slot as u64); let epoch = self.epoch_schedule().get_epoch(self.slot()); let stakes = self.epoch_vote_accounts(epoch)?; @@ -3716,24 +3712,25 @@ impl Bank { // // Ref: collect_fees #[allow(clippy::needless_collect)] - fn distribute_rent_to_validators(&self, vote_accounts: I, rent_to_be_distributed: u64) - where - I: IntoIterator, - { + fn distribute_rent_to_validators( + &self, + vote_accounts: &HashMap, + rent_to_be_distributed: u64, + ) { let mut total_staked = 0; // Collect the stake associated with each validator. // Note that a validator may be present in this vector multiple times if it happens to have // more than one staked vote account somehow let mut validator_stakes = vote_accounts - .into_iter() + .iter() .filter_map(|(_vote_pubkey, (staked, account))| { - if staked == 0 { + if *staked == 0 { None } else { - total_staked += staked; + total_staked += *staked; let node_pubkey = account.vote_state().as_ref().ok()?.node_pubkey; - Some((node_pubkey, staked)) + Some((node_pubkey, *staked)) } }) .collect::>(); @@ -3848,7 +3845,7 @@ impl Bank { return; } - self.distribute_rent_to_validators(self.vote_accounts(), rent_to_be_distributed); + self.distribute_rent_to_validators(&self.vote_accounts(), rent_to_be_distributed); } fn collect_rent( @@ -5198,26 +5195,15 @@ impl Bank { /// current vote accounts for this bank along with the stake /// attributed to each account - /// Note: This clones the entire vote-accounts hashmap. For a single - /// account lookup use get_vote_account instead. - pub fn vote_accounts(&self) -> Vec<(Pubkey, (/*stake:*/ u64, VoteAccount))> { - self.stakes - .read() - .unwrap() - .vote_accounts() - .iter() - .map(|(k, v)| (*k, v.clone())) - .collect() + pub fn vote_accounts(&self) -> Arc> { + let stakes = self.stakes.read().unwrap(); + Arc::from(stakes.vote_accounts()) } /// Vote account for the given vote account pubkey along with the stake. pub fn get_vote_account(&self, vote_account: &Pubkey) -> Option<(/*stake:*/ u64, VoteAccount)> { - self.stakes - .read() - .unwrap() - .vote_accounts() - .get(vote_account) - .cloned() + let stakes = self.stakes.read().unwrap(); + stakes.vote_accounts().get(vote_account).cloned() } /// Get the EpochStakes for a given epoch @@ -5239,9 +5225,8 @@ impl Bank { &self, epoch: Epoch, ) -> Option<&HashMap> { - self.epoch_stakes - .get(&epoch) - .map(|epoch_stakes| Stakes::vote_accounts(epoch_stakes.stakes())) + let epoch_stakes = self.epoch_stakes.get(&epoch)?.stakes(); + Some(epoch_stakes.vote_accounts().as_ref()) } /// Get the fixed authorized voter for the given vote account for the @@ -6679,7 +6664,7 @@ pub(crate) mod tests { let bank = Bank::new_for_tests(&genesis_config); let old_validator_lamports = bank.get_balance(&validator_pubkey); - bank.distribute_rent_to_validators(bank.vote_accounts(), RENT_TO_BE_DISTRIBUTED); + bank.distribute_rent_to_validators(&bank.vote_accounts(), RENT_TO_BE_DISTRIBUTED); let new_validator_lamports = bank.get_balance(&validator_pubkey); assert_eq!( new_validator_lamports, @@ -6693,7 +6678,7 @@ pub(crate) mod tests { let bank = std::panic::AssertUnwindSafe(Bank::new_for_tests(&genesis_config)); let old_validator_lamports = bank.get_balance(&validator_pubkey); let new_validator_lamports = std::panic::catch_unwind(|| { - bank.distribute_rent_to_validators(bank.vote_accounts(), RENT_TO_BE_DISTRIBUTED); + bank.distribute_rent_to_validators(&bank.vote_accounts(), RENT_TO_BE_DISTRIBUTED); bank.get_balance(&validator_pubkey) }); @@ -9531,7 +9516,7 @@ pub(crate) mod tests { bank.process_transaction(&transaction).unwrap(); - let vote_accounts = bank.vote_accounts().into_iter().collect::>(); + let vote_accounts = bank.vote_accounts(); assert_eq!(vote_accounts.len(), 2); @@ -9906,7 +9891,10 @@ pub(crate) mod tests { } // Non-native loader accounts can not be used for instruction processing - assert!(bank.stakes.read().unwrap().vote_accounts().is_empty()); + { + let stakes = bank.stakes.read().unwrap(); + assert!(stakes.vote_accounts().as_ref().is_empty()); + } assert!(bank.stakes.read().unwrap().stake_delegations().is_empty()); assert_eq!(bank.calculate_capitalization(true), bank.capitalization()); @@ -9916,13 +9904,19 @@ pub(crate) mod tests { .fetch_add(vote_account.lamports() + stake_account.lamports(), Relaxed); bank.store_account(&vote_id, &vote_account); bank.store_account(&stake_id, &stake_account); - assert!(!bank.stakes.read().unwrap().vote_accounts().is_empty()); + { + let stakes = bank.stakes.read().unwrap(); + assert!(!stakes.vote_accounts().as_ref().is_empty()); + } assert!(!bank.stakes.read().unwrap().stake_delegations().is_empty()); assert_eq!(bank.calculate_capitalization(true), bank.capitalization()); bank.add_builtin("mock_program1", vote_id, mock_ix_processor); bank.add_builtin("mock_program2", stake_id, mock_ix_processor); - assert!(bank.stakes.read().unwrap().vote_accounts().is_empty()); + { + let stakes = bank.stakes.read().unwrap(); + assert!(stakes.vote_accounts().as_ref().is_empty()); + } assert!(bank.stakes.read().unwrap().stake_delegations().is_empty()); assert_eq!(bank.calculate_capitalization(true), bank.capitalization()); assert_eq!( @@ -9942,7 +9936,10 @@ pub(crate) mod tests { bank.update_accounts_hash(); let new_hash = bank.get_accounts_hash(); assert_eq!(old_hash, new_hash); - assert!(bank.stakes.read().unwrap().vote_accounts().is_empty()); + { + let stakes = bank.stakes.read().unwrap(); + assert!(stakes.vote_accounts().as_ref().is_empty()); + } assert!(bank.stakes.read().unwrap().stake_delegations().is_empty()); assert_eq!(bank.calculate_capitalization(true), bank.capitalization()); assert_eq!( diff --git a/runtime/src/epoch_stakes.rs b/runtime/src/epoch_stakes.rs index ee488ecd50..e9f8151471 100644 --- a/runtime/src/epoch_stakes.rs +++ b/runtime/src/epoch_stakes.rs @@ -24,9 +24,9 @@ pub struct EpochStakes { impl EpochStakes { pub fn new(stakes: &Stakes, leader_schedule_epoch: Epoch) -> Self { - let epoch_vote_accounts = Stakes::vote_accounts(stakes); + let epoch_vote_accounts = stakes.vote_accounts(); let (total_stake, node_id_to_vote_accounts, epoch_authorized_voters) = - Self::parse_epoch_vote_accounts(epoch_vote_accounts, leader_schedule_epoch); + Self::parse_epoch_vote_accounts(epoch_vote_accounts.as_ref(), leader_schedule_epoch); Self { stakes: Arc::new(stakes.clone()), total_stake, @@ -52,7 +52,8 @@ impl EpochStakes { } pub fn vote_account_stake(&self, vote_account: &Pubkey) -> u64 { - Stakes::vote_accounts(&self.stakes) + self.stakes + .vote_accounts() .get(vote_account) .map(|(stake, _)| *stake) .unwrap_or(0) diff --git a/runtime/src/stakes.rs b/runtime/src/stakes.rs index 3df518bf03..9c2c5e28e3 100644 --- a/runtime/src/stakes.rs +++ b/runtime/src/stakes.rs @@ -196,8 +196,8 @@ impl Stakes { } } - pub fn vote_accounts(&self) -> &HashMap { - self.vote_accounts.as_ref() + pub fn vote_accounts(&self) -> &VoteAccounts { + &self.vote_accounts } pub fn stake_delegations(&self) -> &HashMap { diff --git a/runtime/src/vote_account.rs b/runtime/src/vote_account.rs index 91539bc9d4..7e7e54a7fb 100644 --- a/runtime/src/vote_account.rs +++ b/runtime/src/vote_account.rs @@ -31,7 +31,7 @@ struct VoteAccountInner { #[derive(Debug, AbiExample)] pub struct VoteAccounts { - vote_accounts: HashMap, + vote_accounts: Arc>, // Inner Arc is meant to implement copy-on-write semantics as opposed to // sharing mutations (hence RwLock> instead of Arc>). staked_nodes: RwLock< @@ -46,7 +46,7 @@ pub struct VoteAccounts { } impl VoteAccount { - pub fn lamports(&self) -> u64 { + pub(crate) fn lamports(&self) -> u64 { self.0.account.lamports } @@ -83,37 +83,43 @@ impl VoteAccounts { self.staked_nodes.read().unwrap().clone() } - pub fn iter(&self) -> impl Iterator { + pub fn get(&self, pubkey: &Pubkey) -> Option<&(/*stake:*/ u64, VoteAccount)> { + self.vote_accounts.get(pubkey) + } + + pub(crate) fn iter(&self) -> impl Iterator { self.vote_accounts.iter() } - pub fn insert(&mut self, pubkey: Pubkey, (stake, vote_account): (u64, VoteAccount)) { + pub(crate) fn insert(&mut self, pubkey: Pubkey, (stake, vote_account): (u64, VoteAccount)) { self.add_node_stake(stake, &vote_account); - if let Some((stake, vote_account)) = - self.vote_accounts.insert(pubkey, (stake, vote_account)) - { + let vote_accounts = Arc::make_mut(&mut self.vote_accounts); + if let Some((stake, vote_account)) = vote_accounts.insert(pubkey, (stake, vote_account)) { self.sub_node_stake(stake, &vote_account); } } - pub fn remove(&mut self, pubkey: &Pubkey) -> Option<(u64, VoteAccount)> { - let value = self.vote_accounts.remove(pubkey); - if let Some((stake, ref vote_account)) = value { + pub(crate) fn remove(&mut self, pubkey: &Pubkey) -> Option<(u64, VoteAccount)> { + let vote_accounts = Arc::make_mut(&mut self.vote_accounts); + let entry = vote_accounts.remove(pubkey); + if let Some((stake, ref vote_account)) = entry { self.sub_node_stake(stake, vote_account); } - value + entry } - pub fn add_stake(&mut self, pubkey: &Pubkey, delta: u64) { - if let Some((stake, vote_account)) = self.vote_accounts.get_mut(pubkey) { + pub(crate) fn add_stake(&mut self, pubkey: &Pubkey, delta: u64) { + let vote_accounts = Arc::make_mut(&mut self.vote_accounts); + if let Some((stake, vote_account)) = vote_accounts.get_mut(pubkey) { *stake += delta; let vote_account = vote_account.clone(); self.add_node_stake(delta, &vote_account); } } - pub fn sub_stake(&mut self, pubkey: &Pubkey, delta: u64) { - if let Some((stake, vote_account)) = self.vote_accounts.get_mut(pubkey) { + pub(crate) fn sub_stake(&mut self, pubkey: &Pubkey, delta: u64) { + let vote_accounts = Arc::make_mut(&mut self.vote_accounts); + if let Some((stake, vote_account)) = vote_accounts.get_mut(pubkey) { *stake = stake .checked_sub(delta) .expect("subtraction value exceeds account's stake"); @@ -221,7 +227,7 @@ impl PartialEq for VoteAccountInner { impl Default for VoteAccounts { fn default() -> Self { Self { - vote_accounts: HashMap::default(), + vote_accounts: Arc::default(), staked_nodes: RwLock::default(), staked_nodes_once: Once::new(), } @@ -257,8 +263,8 @@ impl PartialEq for VoteAccounts { type VoteAccountsHashMap = HashMap; -impl From for VoteAccounts { - fn from(vote_accounts: VoteAccountsHashMap) -> Self { +impl From> for VoteAccounts { + fn from(vote_accounts: Arc) -> Self { Self { vote_accounts, staked_nodes: RwLock::default(), @@ -273,12 +279,18 @@ impl AsRef for VoteAccounts { } } +impl From<&VoteAccounts> for Arc { + fn from(vote_accounts: &VoteAccounts) -> Self { + Arc::clone(&vote_accounts.vote_accounts) + } +} + impl FromIterator<(Pubkey, (/*stake:*/ u64, VoteAccount))> for VoteAccounts { fn from_iter(iter: I) -> Self where I: IntoIterator, { - Self::from(HashMap::from_iter(iter)) + Self::from(Arc::new(HashMap::from_iter(iter))) } } @@ -297,7 +309,7 @@ impl<'de> Deserialize<'de> for VoteAccounts { D: Deserializer<'de>, { let vote_accounts = VoteAccountsHashMap::deserialize(deserializer)?; - Ok(Self::from(vote_accounts)) + Ok(Self::from(Arc::new(vote_accounts))) } } @@ -430,7 +442,7 @@ mod tests { let mut rng = rand::thread_rng(); let vote_accounts_hash_map: HashMap = new_rand_vote_accounts(&mut rng, 64).take(1024).collect(); - let vote_accounts = VoteAccounts::from(vote_accounts_hash_map.clone()); + let vote_accounts = VoteAccounts::from(Arc::new(vote_accounts_hash_map.clone())); assert!(vote_accounts.staked_nodes().len() > 32); assert_eq!( bincode::serialize(&vote_accounts).unwrap(), @@ -452,12 +464,12 @@ mod tests { let data = bincode::serialize(&vote_accounts_hash_map).unwrap(); let vote_accounts: VoteAccounts = bincode::deserialize(&data).unwrap(); assert!(vote_accounts.staked_nodes().len() > 32); - assert_eq!(vote_accounts.vote_accounts, vote_accounts_hash_map); + assert_eq!(*vote_accounts.vote_accounts, vote_accounts_hash_map); let data = bincode::options() .serialize(&vote_accounts_hash_map) .unwrap(); let vote_accounts: VoteAccounts = bincode::options().deserialize(&data).unwrap(); - assert_eq!(vote_accounts.vote_accounts, vote_accounts_hash_map); + assert_eq!(*vote_accounts.vote_accounts, vote_accounts_hash_map); } #[test] @@ -542,4 +554,38 @@ mod tests { } } } + + // Asserts that returned vote-accounts are copy-on-write references. + #[test] + fn test_vote_accounts_cow() { + let mut rng = rand::thread_rng(); + let mut accounts = new_rand_vote_accounts(&mut rng, 64); + // Add vote accounts. + let mut vote_accounts = VoteAccounts::default(); + for (pubkey, (stake, vote_account)) in (&mut accounts).take(1024) { + vote_accounts.insert(pubkey, (stake, vote_account)); + } + let vote_accounts_hashmap = Arc::::from(&vote_accounts); + assert_eq!(vote_accounts_hashmap, vote_accounts.vote_accounts); + assert!(Arc::ptr_eq( + &vote_accounts_hashmap, + &vote_accounts.vote_accounts + )); + let (pubkey, (more_stake, vote_account)) = + accounts.find(|(_, (stake, _))| *stake != 0).unwrap(); + vote_accounts.insert(pubkey, (more_stake, vote_account.clone())); + assert!(!Arc::ptr_eq( + &vote_accounts_hashmap, + &vote_accounts.vote_accounts + )); + assert_ne!(vote_accounts_hashmap, vote_accounts.vote_accounts); + let other = (more_stake, vote_account); + for (pk, value) in vote_accounts.iter() { + if *pk != pubkey { + assert_eq!(value, &vote_accounts_hashmap[pk]); + } else { + assert_eq!(value, &other); + } + } + } }