diff --git a/core/src/consensus.rs b/core/src/consensus.rs index 465680b2b..3b1622274 100644 --- a/core/src/consensus.rs +++ b/core/src/consensus.rs @@ -1404,6 +1404,7 @@ pub mod test { let mut account = AccountSharedData::from(Account { data: vec![0; VoteState::size_of()], lamports: *lamports, + owner: solana_vote_program::id(), ..Account::default() }); let mut vote_state = VoteState::default(); @@ -1417,7 +1418,7 @@ pub mod test { .expect("serialize state"); ( solana_sdk::pubkey::new_rand(), - (*lamports, VoteAccount::from(account)), + (*lamports, VoteAccount::try_from(account).unwrap()), ) }) .collect() diff --git a/core/src/progress_map.rs b/core/src/progress_map.rs index 69a22a9a1..950f98627 100644 --- a/core/src/progress_map.rs +++ b/core/src/progress_map.rs @@ -695,7 +695,15 @@ impl ProgressMap { #[cfg(test)] mod test { - use {super::*, solana_runtime::vote_account::VoteAccount}; + use {super::*, solana_runtime::vote_account::VoteAccount, solana_sdk::account::Account}; + + fn new_test_vote_account() -> VoteAccount { + let account = Account { + owner: solana_vote_program::id(), + ..Account::default() + }; + VoteAccount::try_from(account).unwrap() + } #[test] fn test_add_vote_pubkey() { @@ -730,7 +738,7 @@ mod test { let epoch_vote_accounts: HashMap<_, _> = vote_account_pubkeys .iter() .skip(num_vote_accounts - staked_vote_accounts) - .map(|pubkey| (*pubkey, (1, VoteAccount::default()))) + .map(|pubkey| (*pubkey, (1, new_test_vote_account()))) .collect(); let mut stats = PropagatedStats::default(); @@ -772,7 +780,7 @@ mod test { let epoch_vote_accounts: HashMap<_, _> = vote_account_pubkeys .iter() .skip(num_vote_accounts - staked_vote_accounts) - .map(|pubkey| (*pubkey, (1, VoteAccount::default()))) + .map(|pubkey| (*pubkey, (1, new_test_vote_account()))) .collect(); stats.add_node_pubkey_internal(&node_pubkey, &vote_account_pubkeys, &epoch_vote_accounts); assert!(stats.propagated_node_ids.contains(&node_pubkey)); diff --git a/ledger/src/blockstore_processor.rs b/ledger/src/blockstore_processor.rs index 38a9b7571..5bb94773a 100644 --- a/ledger/src/blockstore_processor.rs +++ b/ledger/src/blockstore_processor.rs @@ -3780,7 +3780,7 @@ pub mod tests { VoteState::serialize(&versioned, vote_account.data_as_mut_slice()).unwrap(); ( solana_sdk::pubkey::new_rand(), - (stake, VoteAccount::from(vote_account)), + (stake, VoteAccount::try_from(vote_account).unwrap()), ) }) .collect() diff --git a/ledger/src/lib.rs b/ledger/src/lib.rs index 588d57fdd..b557e75e3 100644 --- a/ledger/src/lib.rs +++ b/ledger/src/lib.rs @@ -28,7 +28,7 @@ pub mod shred_stats; mod shredder; pub mod sigverify_shreds; pub mod slot_stats; -pub mod staking_utils; +mod staking_utils; #[macro_use] extern crate solana_metrics; diff --git a/ledger/src/staking_utils.rs b/ledger/src/staking_utils.rs index b2d29a88a..375630e31 100644 --- a/ledger/src/staking_utils.rs +++ b/ledger/src/staking_utils.rs @@ -113,11 +113,12 @@ pub(crate) mod tests { let account = AccountSharedData::new_data( rng.gen(), // lamports &VoteStateVersions::new_current(vote_state), - &Pubkey::new_unique(), // owner + &solana_vote_program::id(), // owner ) .unwrap(); let vote_pubkey = Pubkey::new_unique(); - (vote_pubkey, (stake, VoteAccount::from(account))) + let vote_account = VoteAccount::try_from(account).unwrap(); + (vote_pubkey, (stake, vote_account)) }); let result = vote_accounts.collect::().staked_nodes(); assert_eq!(result.len(), 2); diff --git a/runtime/src/bank.rs b/runtime/src/bank.rs index b6c21cc0f..502d91624 100644 --- a/runtime/src/bank.rs +++ b/runtime/src/bank.rs @@ -2910,7 +2910,7 @@ impl Bank { { vote_accounts_cache_miss_count.fetch_add(1, Relaxed); } - Some(VoteAccount::from(account)) + VoteAccount::try_from(account).ok() }; let invalid_vote_keys = DashMap::::new(); let make_vote_delegations_entry = |vote_pubkey| { diff --git a/runtime/src/epoch_stakes.rs b/runtime/src/epoch_stakes.rs index 6214005da..11b01165a 100644 --- a/runtime/src/epoch_stakes.rs +++ b/runtime/src/epoch_stakes.rs @@ -190,10 +190,8 @@ pub(crate) mod tests { .iter() .flat_map(|(_, vote_accounts)| { vote_accounts.iter().map(|v| { - ( - v.vote_account, - (stake_per_account, VoteAccount::from(v.account.clone())), - ) + let vote_account = VoteAccount::try_from(v.account.clone()).unwrap(); + (v.vote_account, (stake_per_account, vote_account)) }) }) .collect(); diff --git a/runtime/src/stake_account.rs b/runtime/src/stake_account.rs index 75f73c4b8..9ef7e8302 100644 --- a/runtime/src/stake_account.rs +++ b/runtime/src/stake_account.rs @@ -30,8 +30,8 @@ pub enum Error { InstructionError(#[from] InstructionError), #[error("Invalid delegation: {0:?}")] InvalidDelegation(StakeState), - #[error("Invalid stake account owner: {owner:?}")] - InvalidOwner { owner: Pubkey }, + #[error("Invalid stake account owner: {0}")] + InvalidOwner(/*owner:*/ Pubkey), } impl StakeAccount { @@ -59,9 +59,7 @@ impl TryFrom for StakeAccount<()> { type Error = Error; fn try_from(account: AccountSharedData) -> Result { if account.owner() != &solana_stake_program::id() { - return Err(Error::InvalidOwner { - owner: *account.owner(), - }); + return Err(Error::InvalidOwner(*account.owner())); } let stake_state = account.state()?; Ok(Self { diff --git a/runtime/src/stakes.rs b/runtime/src/stakes.rs index a57fc8258..e8b9ea051 100644 --- a/runtime/src/stakes.rs +++ b/runtime/src/stakes.rs @@ -78,13 +78,20 @@ impl StakesCache { debug_assert_ne!(account.lamports(), 0u64); if solana_vote_program::check_id(owner) { if VoteState::is_correct_size_and_initialized(account.data()) { - let vote_account = VoteAccount::from(account.clone()); - { - // Called to eagerly deserialize vote state - let _res = vote_account.vote_state(); + match VoteAccount::try_from(account.clone()) { + Ok(vote_account) => { + { + // Called to eagerly deserialize vote state + let _res = vote_account.vote_state(); + } + let mut stakes = self.0.write().unwrap(); + stakes.upsert_vote_account(pubkey, vote_account); + } + Err(_) => { + let mut stakes = self.0.write().unwrap(); + stakes.remove_vote_account(pubkey) + } } - let mut stakes = self.0.write().unwrap(); - stakes.upsert_vote_account(pubkey, vote_account); } else { let mut stakes = self.0.write().unwrap(); stakes.remove_vote_account(pubkey) diff --git a/runtime/src/vote_account.rs b/runtime/src/vote_account.rs index 35568e563..195abe508 100644 --- a/runtime/src/vote_account.rs +++ b/runtime/src/vote_account.rs @@ -13,21 +13,31 @@ use { iter::FromIterator, sync::{Arc, Once, RwLock, RwLockReadGuard}, }, + thiserror::Error, }; // The value here does not matter. It will be overwritten // at the first call to VoteAccount::vote_state(). -const INVALID_VOTE_STATE: Result = - Err(InstructionError::InvalidAccountData); +const INVALID_VOTE_STATE: Result = Err(Error::InstructionError( + InstructionError::InvalidAccountData, +)); -#[derive(Clone, Debug, Default, PartialEq, AbiExample, Deserialize)] -#[serde(from = "Account")] +#[derive(Clone, Debug, PartialEq, AbiExample, Deserialize)] +#[serde(try_from = "Account")] pub struct VoteAccount(Arc); +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + InstructionError(#[from] InstructionError), + #[error("Invalid vote account owner: {0}")] + InvalidOwner(/*owner:*/ Pubkey), +} + #[derive(Debug, AbiExample)] struct VoteAccountInner { account: Account, - vote_state: RwLock>, + vote_state: RwLock>, vote_state_once: Once, } @@ -59,11 +69,11 @@ impl VoteAccount { self.0.account.owner() } - pub fn vote_state(&self) -> RwLockReadGuard> { + pub fn vote_state(&self) -> RwLockReadGuard> { let inner = &self.0; inner.vote_state_once.call_once(|| { let vote_state = VoteState::deserialize(inner.account.data()); - *inner.vote_state.write().unwrap() = vote_state; + *inner.vote_state.write().unwrap() = vote_state.map_err(Error::from); }); inner.vote_state.read().unwrap() } @@ -187,9 +197,10 @@ impl Serialize for VoteAccount { } } -impl From for VoteAccount { - fn from(account: AccountSharedData) -> Self { - Self::from(Account::from(account)) +impl TryFrom for VoteAccount { + type Error = Error; + fn try_from(account: AccountSharedData) -> Result { + Self::try_from(Account::from(account)) } } @@ -199,29 +210,25 @@ impl From for AccountSharedData { } } -impl From for VoteAccount { - fn from(account: Account) -> Self { - Self(Arc::new(VoteAccountInner::from(account))) +impl TryFrom for VoteAccount { + type Error = Error; + fn try_from(account: Account) -> Result { + let vote_account = VoteAccountInner::try_from(account)?; + Ok(Self(Arc::new(vote_account))) } } -impl From for VoteAccountInner { - fn from(account: Account) -> Self { - Self { +impl TryFrom for VoteAccountInner { + type Error = Error; + fn try_from(account: Account) -> Result { + if !solana_vote_program::check_id(account.owner()) { + return Err(Error::InvalidOwner(*account.owner())); + } + Ok(Self { account, vote_state: RwLock::new(INVALID_VOTE_STATE), vote_state_once: Once::new(), - } - } -} - -impl Default for VoteAccountInner { - fn default() -> Self { - Self { - account: Account::default(), - vote_state: RwLock::new(INVALID_VOTE_STATE), - vote_state_once: Once::new(), - } + }) } } @@ -356,7 +363,7 @@ mod tests { let account = Account::new_data( rng.gen(), // lamports &VoteStateVersions::new_current(vote_state.clone()), - &Pubkey::new_unique(), // owner + &solana_vote_program::id(), // owner ) .unwrap(); (account, vote_state) @@ -371,7 +378,8 @@ mod tests { let node = nodes[rng.gen_range(0, nodes.len())]; let (account, _) = new_rand_vote_account(rng, Some(node)); let stake = rng.gen_range(0, 997); - (Pubkey::new_unique(), (stake, VoteAccount::from(account))) + let vote_account = VoteAccount::try_from(account).unwrap(); + (Pubkey::new_unique(), (stake, vote_account)) }) } @@ -399,7 +407,7 @@ mod tests { let mut rng = rand::thread_rng(); let (account, vote_state) = new_rand_vote_account(&mut rng, None); let lamports = account.lamports(); - let vote_account = VoteAccount::from(account); + let vote_account = VoteAccount::try_from(account).unwrap(); assert_eq!(lamports, vote_account.lamports()); assert_eq!(vote_state, *vote_account.vote_state().as_ref().unwrap()); // 2nd call to .vote_state() should return the cached value. @@ -410,7 +418,7 @@ mod tests { fn test_vote_account_serialize() { let mut rng = rand::thread_rng(); let (account, vote_state) = new_rand_vote_account(&mut rng, None); - let vote_account = VoteAccount::from(account.clone()); + let vote_account = VoteAccount::try_from(account.clone()).unwrap(); assert_eq!(vote_state, *vote_account.vote_state().as_ref().unwrap()); // Assert than VoteAccount has the same wire format as Account. assert_eq!( @@ -424,7 +432,7 @@ mod tests { let mut rng = rand::thread_rng(); let (account, vote_state) = new_rand_vote_account(&mut rng, None); let data = bincode::serialize(&account).unwrap(); - let vote_account = VoteAccount::from(account); + let vote_account = VoteAccount::try_from(account).unwrap(); assert_eq!(vote_state, *vote_account.vote_state().as_ref().unwrap()); let other_vote_account: VoteAccount = bincode::deserialize(&data).unwrap(); assert_eq!(vote_account, other_vote_account); @@ -438,7 +446,7 @@ mod tests { fn test_vote_account_round_trip() { let mut rng = rand::thread_rng(); let (account, vote_state) = new_rand_vote_account(&mut rng, None); - let vote_account = VoteAccount::from(account); + let vote_account = VoteAccount::try_from(account).unwrap(); assert_eq!(vote_state, *vote_account.vote_state().as_ref().unwrap()); let data = bincode::serialize(&vote_account).unwrap(); let other_vote_account: VoteAccount = bincode::deserialize(&data).unwrap();