diff --git a/perf/src/sigverify.rs b/perf/src/sigverify.rs index 846504198..7473a0e1d 100644 --- a/perf/src/sigverify.rs +++ b/perf/src/sigverify.rs @@ -8,11 +8,10 @@ use crate::cuda_runtime::PinnedVec; use crate::packet::{Packet, Packets}; use crate::perf_libs; use crate::recycler::Recycler; -use bincode::serialized_size; use rayon::ThreadPool; use solana_metrics::inc_new_counter_debug; use solana_rayon_threadlimit::get_thread_count; -use solana_sdk::message::MessageHeader; +use solana_sdk::message::MESSAGE_HEADER_LENGTH; use solana_sdk::pubkey::Pubkey; use solana_sdk::short_vec::decode_len; use solana_sdk::signature::Signature; @@ -75,6 +74,12 @@ impl std::convert::From> for PacketError { } } +impl std::convert::From for PacketError { + fn from(_e: std::num::TryFromIntError) -> Self { + Self::InvalidLen + } +} + pub fn init() { if let Some(api) = perf_libs::api() { unsafe { @@ -147,17 +152,14 @@ pub fn batch_size(batches: &[Packets]) -> usize { // internal function to be unit-tested; should be used only by get_packet_offsets fn do_get_packet_offsets( packet: &Packet, - current_offset: u32, + current_offset: usize, ) -> Result { - let message_header_size = serialized_size(&MessageHeader::default()).unwrap() as usize; // should have at least 1 signature, sig lengths and the message header - let min_packet_size = 1usize + let _ = 1usize .checked_add(size_of::()) - .and_then(|v| v.checked_add(message_header_size)) + .and_then(|v| v.checked_add(MESSAGE_HEADER_LENGTH)) + .filter(|v| *v <= packet.meta.size) .ok_or(PacketError::InvalidLen)?; - if min_packet_size > packet.meta.size { - return Err(PacketError::InvalidLen); - } // read the length of Transaction.signatures (serialized with short_vec) let (sig_len_untrusted, sig_size) = @@ -170,76 +172,69 @@ fn do_get_packet_offsets( .and_then(|v| v.checked_add(sig_size)) .ok_or(PacketError::InvalidLen)?; + let msg_start_offset_plus_one = msg_start_offset + .checked_add(1) + .ok_or(PacketError::InvalidLen)?; + // Packet should have data at least for signatures, MessageHeader, 1 byte for Message.account_keys.len - let min_message_end_offset = msg_start_offset - .checked_add(message_header_size) - .and_then(|v| v.checked_add(1)) + let _ = msg_start_offset_plus_one + .checked_add(MESSAGE_HEADER_LENGTH) + .filter(|v| *v <= packet.meta.size) .ok_or(PacketError::InvalidSignatureLen)?; - if min_message_end_offset > packet.meta.size { - return Err(PacketError::InvalidSignatureLen); - } // read MessageHeader.num_required_signatures (serialized with u8) - let sig_len_maybe_trusted = packet.data[msg_start_offset] as usize; + let sig_len_maybe_trusted = packet.data[msg_start_offset]; let message_account_keys_len_offset = msg_start_offset - .checked_add(message_header_size) + .checked_add(MESSAGE_HEADER_LENGTH) .ok_or(PacketError::InvalidLen)?; // This reads and compares the MessageHeader num_required_signatures and // num_readonly_signed_accounts bytes. If num_required_signatures is not larger than // num_readonly_signed_accounts, the first account is not debitable, and cannot be charged // required transaction fees. - let readonly_signer_offset = msg_start_offset - .checked_add(1) - .ok_or(PacketError::InvalidLen)?; - if packet.data[msg_start_offset] <= packet.data[readonly_signer_offset] { + let readonly_signer_offset = msg_start_offset_plus_one; + if sig_len_maybe_trusted <= packet.data[readonly_signer_offset] { return Err(PacketError::PayerNotWritable); } + if usize::from(sig_len_maybe_trusted) != sig_len_untrusted { + return Err(PacketError::MismatchSignatureLen); + } + // read the length of Message.account_keys (serialized with short_vec) let (pubkey_len, pubkey_len_size) = decode_len(&packet.data[message_account_keys_len_offset..]) .map_err(|_| PacketError::InvalidShortVec)?; - let min_account_keys_end_offset = pubkey_len - .checked_mul(size_of::()) - .and_then(|v| v.checked_add(message_account_keys_len_offset)) - .and_then(|v| v.checked_add(pubkey_len_size)) + let pubkey_start = message_account_keys_len_offset + .checked_add(pubkey_len_size) .ok_or(PacketError::InvalidPubkeyLen)?; - if min_account_keys_end_offset > packet.meta.size { - return Err(PacketError::InvalidPubkeyLen); - } - let sig_start = usize::try_from(current_offset) - .ok() - .and_then(|v| v.checked_add(sig_size)) - .ok_or(PacketError::InvalidLen)?; - let msg_start = usize::try_from(current_offset) - .ok() - .and_then(|v| v.checked_add(msg_start_offset)) - .ok_or(PacketError::InvalidLen)?; - let pubkey_start = msg_start - .checked_add(message_header_size) - .and_then(|v| v.checked_add(pubkey_len_size)) - .ok_or(PacketError::InvalidLen)?; + let _ = pubkey_len + .checked_mul(size_of::()) + .and_then(|v| v.checked_add(pubkey_start)) + .filter(|v| *v <= packet.meta.size) + .ok_or(PacketError::InvalidPubkeyLen)?; - if sig_len_maybe_trusted != sig_len_untrusted { - return Err(PacketError::MismatchSignatureLen); - } - - fn to_u32(value: usize) -> Result { - u32::try_from(value).map_err(|_| PacketError::InvalidLen) - } + let sig_start = current_offset + .checked_add(sig_size) + .ok_or(PacketError::InvalidLen)?; + let msg_start = current_offset + .checked_add(msg_start_offset) + .ok_or(PacketError::InvalidLen)?; + let pubkey_start = current_offset + .checked_add(pubkey_start) + .ok_or(PacketError::InvalidLen)?; Ok(PacketOffsets::new( - to_u32(sig_len_untrusted)?, - to_u32(sig_start)?, - to_u32(msg_start)?, - to_u32(pubkey_start)?, + u32::try_from(sig_len_untrusted)?, + u32::try_from(sig_start)?, + u32::try_from(msg_start)?, + u32::try_from(pubkey_start)?, )) } -fn get_packet_offsets(packet: &Packet, current_offset: u32) -> PacketOffsets { +fn get_packet_offsets(packet: &Packet, current_offset: usize) -> PacketOffsets { let unsanitized_packet_offsets = do_get_packet_offsets(packet, current_offset); if let Ok(offsets) = unsanitized_packet_offsets { offsets @@ -259,13 +254,11 @@ pub fn generate_offsets(batches: &[Packets], recycler: &Recycler) -> T msg_start_offsets.set_pinnable(); let mut msg_sizes: PinnedVec<_> = recycler.allocate().unwrap(); msg_sizes.set_pinnable(); - let mut current_packet: u32 = 0; + let mut current_offset: usize = 0; let mut v_sig_lens = Vec::new(); batches.iter().for_each(|p| { let mut sig_lens = Vec::new(); p.packets.iter().for_each(|packet| { - let current_offset = current_packet.saturating_mul(size_of::() as u32); - let packet_offsets = get_packet_offsets(packet, current_offset); sig_lens.push(packet_offsets.sig_len); @@ -274,6 +267,7 @@ pub fn generate_offsets(batches: &[Packets], recycler: &Recycler) -> T let mut pubkey_offset = packet_offsets.pubkey_start; let mut sig_offset = packet_offsets.sig_start; + let msg_size = current_offset.saturating_add(packet.meta.size) as u32; for _ in 0..packet_offsets.sig_len { signature_offsets.push(sig_offset); sig_offset = sig_offset.saturating_add(size_of::() as u32); @@ -283,12 +277,10 @@ pub fn generate_offsets(batches: &[Packets], recycler: &Recycler) -> T msg_start_offsets.push(packet_offsets.msg_start); - let msg_size = current_offset - .saturating_add(packet.meta.size as u32) - .saturating_sub(packet_offsets.msg_start); + let msg_size = msg_size.saturating_sub(packet_offsets.msg_start); msg_sizes.push(msg_size); } - current_packet = current_packet.saturating_add(1); + current_offset = current_offset.saturating_add(size_of::()); }); v_sig_lens.push(sig_lens); }); @@ -689,7 +681,7 @@ mod tests { // Just like get_packet_offsets, but not returning redundant information. fn get_packet_offsets_from_tx(tx: Transaction, current_offset: u32) -> PacketOffsets { let packet = sigverify::make_packet_from_transaction(tx); - let packet_offsets = sigverify::get_packet_offsets(&packet, current_offset); + let packet_offsets = sigverify::get_packet_offsets(&packet, current_offset as usize); PacketOffsets::new( packet_offsets.sig_len, packet_offsets.sig_start - current_offset, diff --git a/sdk/program/src/message.rs b/sdk/program/src/message.rs index f09bf7420..2011c33e2 100644 --- a/sdk/program/src/message.rs +++ b/sdk/program/src/message.rs @@ -142,6 +142,8 @@ fn get_program_ids(instructions: &[Instruction]) -> Vec { .collect() } +pub const MESSAGE_HEADER_LENGTH: usize = 3; + #[frozen_abi(digest = "BVC5RhetsNpheGipt5rUrkR6RDDUHtD5sCLK1UjymL4S")] #[derive(Serialize, Deserialize, Default, Debug, PartialEq, Eq, Clone, AbiExample)] #[serde(rename_all = "camelCase")] @@ -941,4 +943,12 @@ mod tests { assert!(message.is_non_loader_key(&key1, 1)); assert!(!message.is_non_loader_key(&loader2, 2)); } + + #[test] + fn test_message_header_len_constant() { + assert_eq!( + bincode::serialized_size(&MessageHeader::default()).unwrap() as usize, + MESSAGE_HEADER_LENGTH + ); + } }