diff --git a/sdk/program/src/serialize_utils/cursor.rs b/sdk/program/src/serialize_utils/cursor.rs index 006673738..9d33d1e48 100644 --- a/sdk/program/src/serialize_utils/cursor.rs +++ b/sdk/program/src/serialize_utils/cursor.rs @@ -61,6 +61,15 @@ pub(crate) fn read_pubkey>( Ok(Pubkey::from(buf)) } +pub(crate) fn read_bool>(cursor: &mut Cursor) -> Result { + let byte = read_u8(cursor)?; + match byte { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(InstructionError::InvalidAccountData), + } +} + #[cfg(test)] mod test { use {super::*, rand::Rng, std::fmt::Debug}; @@ -115,6 +124,12 @@ mod test { } } + #[test] + fn test_read_bool() { + test_read(read_bool, false); + test_read(read_bool, true); + } + fn test_read( reader: fn(&mut Cursor>) -> Result, test_value: T, diff --git a/sdk/program/src/vote/state/mod.rs b/sdk/program/src/vote/state/mod.rs index 9eddce4d9..8cfcd0ef1 100644 --- a/sdk/program/src/vote/state/mod.rs +++ b/sdk/program/src/vote/state/mod.rs @@ -18,7 +18,7 @@ use { sysvar::clock::Clock, vote::{authorized_voters::AuthorizedVoters, error::VoteError}, }, - bincode::{serialize_into, ErrorKind}, + bincode::{serialize_into, serialized_size, ErrorKind}, serde_derive::{Deserialize, Serialize}, std::{collections::VecDeque, fmt::Debug, io::Cursor}, }; @@ -399,6 +399,12 @@ impl VoteState { input: &[u8], vote_state: &mut VoteState, ) -> Result<(), InstructionError> { + let minimum_size = + serialized_size(vote_state).map_err(|_| InstructionError::InvalidAccountData)?; + if (input.len() as u64) < minimum_size { + return Err(InstructionError::InvalidAccountData); + } + let mut cursor = Cursor::new(input); let variant = read_u32(&mut cursor)?; @@ -410,7 +416,13 @@ impl VoteState { // Current. the only difference from V1_14_11 is the addition of a slot-latency to each vote 2 => deserialize_vote_state_into(&mut cursor, vote_state, true), _ => Err(InstructionError::InvalidAccountData), + }?; + + if cursor.position() > input.len() as u64 { + return Err(InstructionError::InvalidAccountData); } + + Ok(()) } pub fn serialize( @@ -886,7 +898,7 @@ mod tests { // variant // provide 4x the minimum struct size in bytes to ensure we typically touch every field - let struct_bytes_x4 = std::mem::size_of::() * 4; + let struct_bytes_x4 = std::mem::size_of::() * 4; for _ in 0..1000 { let raw_data: Vec = (0..struct_bytes_x4).map(|_| rand::random::()).collect(); let mut unstructured = Unstructured::new(&raw_data); @@ -911,7 +923,7 @@ mod tests { assert_eq!(e, InstructionError::InvalidAccountData); // variant - let serialized_len_x4 = bincode::serialized_size(&test_vote_state).unwrap() * 4; + let serialized_len_x4 = serialized_size(&test_vote_state).unwrap() * 4; let mut rng = rand::thread_rng(); for _ in 0..1000 { let raw_data_length = rng.gen_range(1..serialized_len_x4); @@ -1262,7 +1274,7 @@ mod tests { fn test_vote_state_size_of() { let vote_state = VoteState::get_max_sized_vote_state(); let vote_state = VoteStateVersions::new_current(vote_state); - let size = bincode::serialized_size(&vote_state).unwrap(); + let size = serialized_size(&vote_state).unwrap(); assert_eq!(VoteState::size_of() as u64, size); } diff --git a/sdk/program/src/vote/state/vote_state_deserialize.rs b/sdk/program/src/vote/state/vote_state_deserialize.rs index b93f1c744..b457395cc 100644 --- a/sdk/program/src/vote/state/vote_state_deserialize.rs +++ b/sdk/program/src/vote/state/vote_state_deserialize.rs @@ -5,6 +5,7 @@ use { serialize_utils::cursor::*, vote::state::{BlockTimestamp, LandedVote, Lockout, VoteState, MAX_ITEMS}, }, + bincode::serialized_size, std::io::Cursor, }; @@ -66,34 +67,46 @@ fn read_prior_voters_into>( cursor: &mut Cursor, vote_state: &mut VoteState, ) -> Result<(), InstructionError> { - let mut encountered_null_voter = false; - for i in 0..MAX_ITEMS { - let prior_voter = read_pubkey(cursor)?; - let from_epoch = read_u64(cursor)?; - let until_epoch = read_u64(cursor)?; - let item = (prior_voter, from_epoch, until_epoch); + // record our position at the start of the struct + let prior_voters_position = cursor.position(); - if item == (Pubkey::default(), 0, 0) { - encountered_null_voter = true; - } else if encountered_null_voter { - // `prior_voters` should never be sparse - return Err(InstructionError::InvalidAccountData); - } else { - vote_state.prior_voters.buf[i] = item; + // `serialized_size()` must be used over `mem::size_of()` because of alignment + let is_empty_position = serialized_size(&vote_state.prior_voters) + .ok() + .and_then(|v| v.checked_add(prior_voters_position)) + .and_then(|v| v.checked_sub(1)) + .ok_or(InstructionError::InvalidAccountData)?; + + // move to the end, to check if we need to parse the data + cursor.set_position(is_empty_position); + + // if empty, we already read past the end of this struct and need to do no further work + // otherwise we go back to the start and proceed to decode the data + let is_empty = read_bool(cursor)?; + if !is_empty { + cursor.set_position(prior_voters_position); + + let mut encountered_null_voter = false; + for i in 0..MAX_ITEMS { + let prior_voter = read_pubkey(cursor)?; + let from_epoch = read_u64(cursor)?; + let until_epoch = read_u64(cursor)?; + let item = (prior_voter, from_epoch, until_epoch); + + if item == (Pubkey::default(), 0, 0) { + encountered_null_voter = true; + } else if encountered_null_voter { + // `prior_voters` should never be sparse + return Err(InstructionError::InvalidAccountData); + } else { + vote_state.prior_voters.buf[i] = item; + } } + + vote_state.prior_voters.idx = read_u64(cursor)? as usize; + vote_state.prior_voters.is_empty = read_bool(cursor)?; } - let idx = read_u64(cursor)? as usize; - vote_state.prior_voters.idx = idx; - - let is_empty_byte = read_u8(cursor)?; - let is_empty = match is_empty_byte { - 0 => false, - 1 => true, - _ => return Err(InstructionError::InvalidAccountData), - }; - vote_state.prior_voters.is_empty = is_empty; - Ok(()) }