From 5dbf7d8f91d343ca996b30dc2981691a6c9df443 Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Fri, 3 Jun 2022 01:05:06 +0000 Subject: [PATCH] removes raw indexing into packet data (#25554) Packets are at the boundary of the system where, vast majority of the time, they are received from an untrusted source. Raw indexing into the data buffer can open attack vectors if the offsets are invalid. Validating offsets beforehand is verbose and error prone. The commit updates Packet::data() api to take a SliceIndex and always to return an Option. The call-sites are so forced to explicitly handle the case where the offsets are invalid. --- bench-streamer/src/main.rs | 3 +- core/src/banking_stage.rs | 2 +- core/src/packet_hasher.rs | 2 +- core/src/serve_repair.rs | 6 +- core/src/unprocessed_packet_batches.rs | 8 +- core/src/window_service.rs | 2 +- gossip/tests/gossip.rs | 2 +- ledger/src/shred.rs | 25 ++-- ledger/src/sigverify_shreds.rs | 2 +- perf/src/sigverify.rs | 173 +++++++++++++------------ sdk/src/packet.rs | 20 +-- streamer/src/nonblocking/sendmmsg.rs | 10 +- streamer/src/packet.rs | 4 +- streamer/src/sendmmsg.rs | 7 +- streamer/src/streamer.rs | 5 +- 15 files changed, 151 insertions(+), 120 deletions(-) diff --git a/bench-streamer/src/main.rs b/bench-streamer/src/main.rs index f6079bc5c1..f03d29bf74 100644 --- a/bench-streamer/src/main.rs +++ b/bench-streamer/src/main.rs @@ -37,7 +37,8 @@ fn producer(addr: &SocketAddr, exit: Arc) -> JoinHandle<()> { for p in packet_batch.iter() { let a = p.meta.socket_addr(); assert!(p.meta.size <= PACKET_DATA_SIZE); - send.send_to(p.data(), &a).unwrap(); + let data = p.data(..).unwrap_or_default(); + send.send_to(data, &a).unwrap(); num += 1; } assert_eq!(num, 10); diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs index b3e90b0c47..071c665fb2 100644 --- a/core/src/banking_stage.rs +++ b/core/src/banking_stage.rs @@ -541,7 +541,7 @@ impl BankingStage { .iter() .filter_map(|p| { if !p.meta.forwarded() && data_budget.take(p.meta.size) { - Some(p.data().to_vec()) + Some(p.data(..)?.to_vec()) } else { None } diff --git a/core/src/packet_hasher.rs b/core/src/packet_hasher.rs index 8fcae1d73b..31280015c3 100644 --- a/core/src/packet_hasher.rs +++ b/core/src/packet_hasher.rs @@ -26,7 +26,7 @@ impl Default for PacketHasher { impl PacketHasher { pub(crate) fn hash_packet(&self, packet: &Packet) -> u64 { - self.hash_data(packet.data()) + self.hash_data(packet.data(..).unwrap_or_default()) } pub(crate) fn hash_shred(&self, shred: &Shred) -> u64 { diff --git a/core/src/serve_repair.rs b/core/src/serve_repair.rs index f9289c65c1..d17c1978e1 100644 --- a/core/src/serve_repair.rs +++ b/core/src/serve_repair.rs @@ -814,7 +814,7 @@ mod tests { .into_iter() .filter_map(|p| { assert_eq!(repair_response::nonce(p).unwrap(), nonce); - Shred::new_from_serialized_shred(p.data().to_vec()).ok() + Shred::new_from_serialized_shred(p.data(..).unwrap().to_vec()).ok() }) .collect(); assert!(!rv.is_empty()); @@ -898,7 +898,7 @@ mod tests { .into_iter() .filter_map(|p| { assert_eq!(repair_response::nonce(p).unwrap(), nonce); - Shred::new_from_serialized_shred(p.data().to_vec()).ok() + Shred::new_from_serialized_shred(p.data(..).unwrap().to_vec()).ok() }) .collect(); assert_eq!(rv[0].index(), 1); @@ -1347,7 +1347,7 @@ mod tests { fn verify_responses<'a>(request: &ShredRepairType, packets: impl Iterator) { for packet in packets { - let shred_payload = packet.data().to_vec(); + let shred_payload = packet.data(..).unwrap().to_vec(); let shred = Shred::new_from_serialized_shred(shred_payload).unwrap(); request.verify_response(&shred); } diff --git a/core/src/unprocessed_packet_batches.rs b/core/src/unprocessed_packet_batches.rs index 277ba3adf7..bfc8852111 100644 --- a/core/src/unprocessed_packet_batches.rs +++ b/core/src/unprocessed_packet_batches.rs @@ -368,12 +368,14 @@ pub fn deserialize_packets<'a>( /// Read the transaction message from packet data pub fn packet_message(packet: &Packet) -> Result<&[u8], DeserializedPacketError> { - let (sig_len, sig_size) = - decode_shortu16_len(packet.data()).map_err(DeserializedPacketError::ShortVecError)?; + let (sig_len, sig_size) = packet + .data(..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(DeserializedPacketError::ShortVecError(()))?; sig_len .checked_mul(size_of::()) .and_then(|v| v.checked_add(sig_size)) - .and_then(|msg_start| packet.data().get(msg_start..)) + .and_then(|msg_start| packet.data(msg_start..)) .ok_or(DeserializedPacketError::SignatureOverflowed(sig_size)) } diff --git a/core/src/window_service.rs b/core/src/window_service.rs index 0bb9d50381..7e582fa4ea 100644 --- a/core/src/window_service.rs +++ b/core/src/window_service.rs @@ -363,7 +363,7 @@ where inc_new_counter_debug!("streamer-recv_window-invalid_or_unnecessary_packet", 1); return None; } - let serialized_shred = packet.data().to_vec(); + let serialized_shred = packet.data(..)?.to_vec(); let shred = Shred::new_from_serialized_shred(serialized_shred).ok()?; if !shred_filter(&shred, working_bank.clone(), last_root) { return None; diff --git a/gossip/tests/gossip.rs b/gossip/tests/gossip.rs index b57a03cb94..f3e136cdba 100644 --- a/gossip/tests/gossip.rs +++ b/gossip/tests/gossip.rs @@ -260,7 +260,7 @@ pub fn cluster_info_retransmit() { let retransmit_peers: Vec<_> = peers.iter().collect(); retransmit_to( &retransmit_peers, - p.data(), + p.data(..).unwrap(), &tn1, false, &SocketAddrSpace::Unspecified, diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs index f7b80f6157..5462797a05 100644 --- a/ledger/src/shred.rs +++ b/ledger/src/shred.rs @@ -509,7 +509,7 @@ pub mod layout { use {super::*, std::ops::Range}; fn get_shred_size(packet: &Packet) -> Option { - let size = packet.data().len(); + let size = packet.data(..)?.len(); if packet.meta.repair() { size.checked_sub(SIZE_OF_NONCE) } else { @@ -519,7 +519,7 @@ pub mod layout { pub fn get_shred(packet: &Packet) -> Option<&[u8]> { let size = get_shred_size(packet)?; - let shred = packet.data().get(..size)?; + let shred = packet.data(..size)?; // Should at least have a signature. (size >= SIZE_OF_SIGNATURE).then(|| shred) } @@ -826,7 +826,7 @@ mod tests { let shred = Shred::new_from_data(10, 0, 1000, &[1, 2, 3], ShredFlags::empty(), 0, 1, 0); let mut packet = Packet::default(); shred.copy_to_packet(&mut packet); - let shred_res = Shred::new_from_serialized_shred(packet.data().to_vec()); + let shred_res = Shred::new_from_serialized_shred(packet.data(..).unwrap().to_vec()); assert_matches!( shred.parent(), Err(Error::InvalidParentOffset { @@ -1029,9 +1029,12 @@ mod tests { assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - layout::get_reference_tick(packet.data()).unwrap() + layout::get_reference_tick(packet.data(..).unwrap()).unwrap() + ); + assert_eq!( + layout::get_slot(packet.data(..).unwrap()), + Some(shred.slot()) ); - assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) @@ -1070,9 +1073,12 @@ mod tests { assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - layout::get_reference_tick(packet.data()).unwrap() + layout::get_reference_tick(packet.data(..).unwrap()).unwrap() + ); + assert_eq!( + layout::get_slot(packet.data(..).unwrap()), + Some(shred.slot()) ); - assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) @@ -1116,7 +1122,10 @@ mod tests { packet.meta.size = payload.len(); assert_eq!(shred.bytes_to_store(), payload); assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); - assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); + assert_eq!( + layout::get_slot(packet.data(..).unwrap()), + Some(shred.slot()) + ); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) diff --git a/ledger/src/sigverify_shreds.rs b/ledger/src/sigverify_shreds.rs index 163e733442..aecad26aa6 100644 --- a/ledger/src/sigverify_shreds.rs +++ b/ledger/src/sigverify_shreds.rs @@ -281,7 +281,7 @@ fn sign_shred_cpu(keypair: &Keypair, packet: &mut Packet) { packet.meta.size >= sig.end, "packet is not large enough for a signature" ); - let signature = keypair.sign_message(&packet.data()[msg]); + let signature = keypair.sign_message(packet.data(msg).unwrap()); trace!("signature {:?}", signature); packet.buffer_mut()[sig].copy_from_slice(signature.as_ref()); } diff --git a/perf/src/sigverify.rs b/perf/src/sigverify.rs index 1691b04264..6a7be319cb 100644 --- a/perf/src/sigverify.rs +++ b/perf/src/sigverify.rs @@ -12,7 +12,7 @@ use { }, ahash::AHasher, rand::{thread_rng, Rng}, - rayon::ThreadPool, + rayon::{prelude::*, ThreadPool}, solana_metrics::inc_new_counter_debug, solana_rayon_threadlimit::get_thread_count, solana_sdk::{ @@ -114,10 +114,13 @@ pub fn init() { } } -fn verify_packet(packet: &mut Packet, reject_non_vote: bool) { +/// Returns true if the signatrue on the packet verifies. +/// Caller must do packet.set_discard(true) if this returns false. +#[must_use] +fn verify_packet(packet: &mut Packet, reject_non_vote: bool) -> bool { // If this packet was already marked as discard, drop it if packet.meta.discard() { - return; + return false; } let packet_offsets = get_packet_offsets(packet, 0, reject_non_vote); @@ -126,36 +129,38 @@ fn verify_packet(packet: &mut Packet, reject_non_vote: bool) { let msg_start = packet_offsets.msg_start as usize; if packet_offsets.sig_len == 0 { - packet.meta.set_discard(true); - return; + return false; } if packet.meta.size <= msg_start { - packet.meta.set_discard(true); - return; + return false; } - let msg_end = packet.meta.size; for _ in 0..packet_offsets.sig_len { let pubkey_end = pubkey_start.saturating_add(size_of::()); - let sig_end = sig_start.saturating_add(size_of::()); - - // get_packet_offsets should ensure pubkey_end and sig_end do - // not overflow packet.meta.size - - let signature = Signature::new(&packet.data()[sig_start..sig_end]); - - if !signature.verify( - &packet.data()[pubkey_start..pubkey_end], - &packet.data()[msg_start..msg_end], - ) { - packet.meta.set_discard(true); - return; + let sig_end = match sig_start.checked_add(size_of::()) { + Some(sig_end) => sig_end, + None => return false, + }; + let signature = match packet.data(sig_start..sig_end) { + Some(signature) => Signature::new(signature), + None => return false, + }; + let pubkey = match packet.data(pubkey_start..pubkey_end) { + Some(pubkey) => pubkey, + None => return false, + }; + let message = match packet.data(msg_start..) { + Some(message) => message, + None => return false, + }; + if !signature.verify(pubkey, message) { + return false; } - pubkey_start = pubkey_end; sig_start = sig_end; } + true } pub fn count_packets_in_batches(batches: &[PacketBatch]) -> usize { @@ -202,9 +207,10 @@ fn do_get_packet_offsets( .ok_or(PacketError::InvalidLen)?; // read the length of Transaction.signatures (serialized with short_vec) - let (sig_len_untrusted, sig_size) = - decode_shortu16_len(packet.data()).map_err(|_| PacketError::InvalidShortVec)?; - + let (sig_len_untrusted, sig_size) = packet + .data(..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(PacketError::InvalidShortVec)?; // Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty. // Ultimately, the actual sigverify will determine the uncertainty. let msg_start_offset = sig_len_untrusted @@ -222,7 +228,9 @@ fn do_get_packet_offsets( // next byte indicates if the transaction is versioned. If the top bit // is set, the remaining bits encode a version number. If the top bit is // not set, this byte is the first byte of the message header. - let message_prefix = packet.data()[msg_start_offset]; + let message_prefix = *packet + .data(msg_start_offset) + .ok_or(PacketError::InvalidSignatureLen)?; if message_prefix & MESSAGE_VERSION_PREFIX != 0 { let version = message_prefix & !MESSAGE_VERSION_PREFIX; match version { @@ -252,8 +260,9 @@ fn do_get_packet_offsets( .ok_or(PacketError::InvalidSignatureLen)?; // read MessageHeader.num_required_signatures (serialized with u8) - let sig_len_maybe_trusted = packet.data()[msg_header_offset]; - + let sig_len_maybe_trusted = *packet + .data(msg_header_offset) + .ok_or(PacketError::InvalidSignatureLen)?; let message_account_keys_len_offset = msg_header_offset .checked_add(MESSAGE_HEADER_LENGTH) .ok_or(PacketError::InvalidSignatureLen)?; @@ -263,7 +272,11 @@ fn do_get_packet_offsets( // num_readonly_signed_accounts, the first account is not debitable, and cannot be charged // required transaction fees. let readonly_signer_offset = msg_header_offset_plus_one; - if sig_len_maybe_trusted <= packet.data()[readonly_signer_offset] { + if sig_len_maybe_trusted + <= *packet + .data(readonly_signer_offset) + .ok_or(PacketError::InvalidSignatureLen)? + { return Err(PacketError::PayerNotWritable); } @@ -272,10 +285,10 @@ fn do_get_packet_offsets( } // read the length of Message.account_keys (serialized with short_vec) - let (pubkey_len, pubkey_len_size) = - decode_shortu16_len(&packet.data()[message_account_keys_len_offset..]) - .map_err(|_| PacketError::InvalidShortVec)?; - + let (pubkey_len, pubkey_len_size) = packet + .data(message_account_keys_len_offset..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(PacketError::InvalidShortVec)?; let pubkey_start = message_account_keys_len_offset .checked_add(pubkey_len_size) .ok_or(PacketError::InvalidPubkeyLen)?; @@ -311,19 +324,17 @@ fn do_get_packet_offsets( pub fn check_for_tracer_packet(packet: &mut Packet) -> bool { let first_pubkey_start: usize = TRACER_KEY_OFFSET_IN_TRANSACTION; - let maybe_first_pubkey_end = first_pubkey_start - .checked_add(size_of::()) - .filter(|v| v <= &packet.meta.size); + let first_pubkey_end = match first_pubkey_start.checked_add(size_of::()) { + Some(offset) => offset, + None => return false, + }; // Check for tracer pubkey - if let Some(first_pubkey_end) = maybe_first_pubkey_end { - let is_tracer_packet = - &packet.data()[first_pubkey_start..first_pubkey_end] == TRACER_KEY.as_ref(); - if is_tracer_packet { + match packet.data(first_pubkey_start..first_pubkey_end) { + Some(pubkey) if pubkey == TRACER_KEY.as_ref() => { packet.meta.flags |= PacketFlags::TRACER_PACKET; + true } - is_tracer_packet - } else { - false + _ => false, } } @@ -370,10 +381,10 @@ fn check_for_simple_vote_transaction( .filter(|v| *v <= packet.meta.size) .ok_or(PacketError::InvalidLen)?; - let (instruction_len, instruction_len_size) = - decode_shortu16_len(&packet.data()[instructions_len_offset..]) - .map_err(|_| PacketError::InvalidLen)?; - + let (instruction_len, instruction_len_size) = packet + .data(instructions_len_offset..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(PacketError::InvalidLen)?; // skip if has more than 1 instruction if instruction_len != 1 { return Err(PacketError::InvalidProgramLen); @@ -389,7 +400,11 @@ fn check_for_simple_vote_transaction( .filter(|v| *v <= packet.meta.size) .ok_or(PacketError::InvalidLen)?; - let instruction_program_id_index: usize = usize::from(packet.data()[instruction_start]); + let instruction_program_id_index: usize = usize::from( + *packet + .data(instruction_start) + .ok_or(PacketError::InvalidLen)?, + ); if instruction_program_id_index >= packet_offsets.pubkey_len as usize { return Err(PacketError::InvalidProgramIdIndex); @@ -403,7 +418,9 @@ fn check_for_simple_vote_transaction( .checked_add(size_of::()) .ok_or(PacketError::InvalidLen)?; - if &packet.data()[instruction_program_id_start..instruction_program_id_end] + if packet + .data(instruction_program_id_start..instruction_program_id_end) + .ok_or(PacketError::InvalidLen)? == solana_sdk::vote::program::id().as_ref() { packet.meta.flags |= PacketFlags::SIMPLE_VOTE_TX; @@ -507,7 +524,7 @@ impl Deduper { /// Compute hash from packet data, returns (hash, bin_pos). fn compute_hash(&self, packet: &Packet) -> (u64, usize) { let mut hasher = AHasher::new_with_keys(self.seed.0, self.seed.1); - hasher.write(packet.data()); + hasher.write(packet.data(..).unwrap_or_default()); let h = hasher.finish(); let len = self.filter.len(); let pos = (usize::try_from(h).unwrap()).wrapping_rem(len); @@ -590,20 +607,20 @@ pub fn shrink_batches(batches: &mut Vec) { } pub fn ed25519_verify_cpu(batches: &mut [PacketBatch], reject_non_vote: bool, packet_count: usize) { - use rayon::prelude::*; debug!("CPU ECDSA for {}", packet_count); PAR_THREAD_POOL.install(|| { batches.into_par_iter().for_each(|batch| { - batch - .par_iter_mut() - .for_each(|p| verify_packet(p, reject_non_vote)) + batch.par_iter_mut().for_each(|packet| { + if !packet.meta.discard() && !verify_packet(packet, reject_non_vote) { + packet.meta.set_discard(true); + } + }) }); }); inc_new_counter_debug!("ed25519_verify_cpu", packet_count); } pub fn ed25519_verify_disabled(batches: &mut [PacketBatch]) { - use rayon::prelude::*; let packet_count = count_packets_in_batches(batches); debug!("disabled ECDSA for {}", packet_count); batches @@ -759,17 +776,20 @@ mod tests { use { super::*, crate::{ - packet::{to_packet_batches, Packet, PacketBatch, PACKETS_PER_BATCH}, + packet::{to_packet_batches, Packet, PacketBatch, PACKETS_PER_BATCH, PACKET_DATA_SIZE}, sigverify::{self, PacketOffsets}, test_tx::{new_test_vote_tx, test_multisig_tx, test_tx}, }, bincode::{deserialize, serialize}, + curve25519_dalek::{edwards::CompressedEdwardsY, scalar::Scalar}, + rand::{thread_rng, Rng}, solana_sdk::{ instruction::CompiledInstruction, message::{Message, MessageHeader}, - signature::{Keypair, Signature}, + signature::{Keypair, Signature, Signer}, transaction::Transaction, }, + std::sync::atomic::{AtomicU64, Ordering}, }; const SIG_OFFSET: usize = 1; @@ -893,8 +913,7 @@ mod tests { let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::InvalidPubkeyLen)); - verify_packet(&mut packet, false); - assert!(packet.meta.discard()); + assert!(!verify_packet(&mut packet, false)); packet.meta.set_discard(false); let mut batches = generate_packet_batches(&packet, 1, 1); @@ -906,7 +925,6 @@ mod tests { fn test_pubkey_len() { // See that the verify cannot walk off the end of the packet // trying to index into the account_keys to access pubkey. - use solana_sdk::signer::{keypair::Keypair, Signer}; solana_logger::setup(); const NUM_SIG: usize = 17; @@ -929,8 +947,7 @@ mod tests { let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::InvalidPubkeyLen)); - verify_packet(&mut packet, false); - assert!(packet.meta.discard()); + assert!(!verify_packet(&mut packet, false)); packet.meta.set_discard(false); let mut batches = generate_packet_batches(&packet, 1, 1); @@ -1022,7 +1039,7 @@ mod tests { // set message version to 0 let msg_start = legacy_offsets.msg_start as usize; - let msg_bytes = packet.data()[msg_start..].to_vec(); + let msg_bytes = packet.data(msg_start..).unwrap().to_vec(); packet.buffer_mut()[msg_start] = MESSAGE_VERSION_PREFIX; packet.meta.size += 1; let msg_end = packet.meta.size; @@ -1039,7 +1056,6 @@ mod tests { #[test] fn test_system_transaction_data_layout() { - use crate::packet::PACKET_DATA_SIZE; let mut tx0 = test_tx(); tx0.message.instructions[0].data = vec![1, 2, 3]; let message0a = tx0.message_data(); @@ -1145,7 +1161,7 @@ mod tests { // jumble some data to test failure if modify_data { - packet.buffer_mut()[20] = packet.data()[20].wrapping_add(10); + packet.buffer_mut()[20] = packet.data(20).unwrap().wrapping_add(10); } let mut batches = generate_packet_batches(&packet, n, 2); @@ -1211,7 +1227,7 @@ mod tests { let num_batches = 3; let mut batches = generate_packet_batches(&packet, n, num_batches); - packet.buffer_mut()[40] = packet.data()[40].wrapping_add(8); + packet.buffer_mut()[40] = packet.data(40).unwrap().wrapping_add(8); batches[0].push(packet); @@ -1237,7 +1253,6 @@ mod tests { #[test] fn test_verify_fuzz() { - use rand::{thread_rng, Rng}; solana_logger::setup(); let tx = test_multisig_tx(); @@ -1255,8 +1270,10 @@ mod tests { let packet = thread_rng().gen_range(0, batches[batch].len()); let offset = thread_rng().gen_range(0, batches[batch][packet].meta.size); let add = thread_rng().gen_range(0, 255); - batches[batch][packet].buffer_mut()[offset] = - batches[batch][packet].data()[offset].wrapping_add(add); + batches[batch][packet].buffer_mut()[offset] = batches[batch][packet] + .data(offset) + .unwrap() + .wrapping_add(add); } let batch_to_disable = thread_rng().gen_range(0, batches.len()); @@ -1288,13 +1305,6 @@ mod tests { #[test] fn test_get_checked_scalar() { solana_logger::setup(); - use { - curve25519_dalek::scalar::Scalar, - rand::{thread_rng, Rng}, - rayon::prelude::*, - std::sync::atomic::{AtomicU64, Ordering}, - }; - if perf_libs::api().is_none() { return; } @@ -1330,13 +1340,6 @@ mod tests { #[test] fn test_ge_small_order() { solana_logger::setup(); - use { - curve25519_dalek::edwards::CompressedEdwardsY, - rand::{thread_rng, Rng}, - rayon::prelude::*, - std::sync::atomic::{AtomicU64, Ordering}, - }; - if perf_libs::api().is_none() { return; } @@ -1530,7 +1533,7 @@ mod tests { .filter(|p| !p.meta.discard()) .for_each(|p| start.push(p.clone())) }); - start.sort_by(|a, b| a.data().cmp(b.data())); + start.sort_by(|a, b| a.data(..).cmp(&b.data(..))); let packet_count = count_valid_packets(&batches, |_| ()); shrink_batches(&mut batches); @@ -1542,7 +1545,7 @@ mod tests { .filter(|p| !p.meta.discard()) .for_each(|p| end.push(p.clone())) }); - end.sort_by(|a, b| a.data().cmp(b.data())); + end.sort_by(|a, b| a.data(..).cmp(&b.data(..))); let packet_count2 = count_valid_packets(&batches, |_| ()); assert_eq!(packet_count, packet_count2); assert_eq!(start, end); diff --git a/sdk/src/packet.rs b/sdk/src/packet.rs index eae3aa83ea..2cca65d6a5 100644 --- a/sdk/src/packet.rs +++ b/sdk/src/packet.rs @@ -5,6 +5,7 @@ use { std::{ fmt, io, net::{IpAddr, Ipv4Addr, SocketAddr}, + slice::SliceIndex, }, }; @@ -39,7 +40,7 @@ pub struct Meta { #[repr(C)] pub struct Packet { // Bytes past Packet.meta.size are not valid to read from. - // Use Packet.data() to read from the buffer. + // Use Packet.data(index) to read from the buffer. buffer: [u8; PACKET_DATA_SIZE], pub meta: Meta, } @@ -50,10 +51,14 @@ impl Packet { } /// Returns an immutable reference to the underlying buffer up to - /// Packet.meta.size. The rest of the buffer is not valid to read from. + /// packet.meta.size. The rest of the buffer is not valid to read from. + /// packet.data(..) returns packet.buffer.get(..packet.meta.size). #[inline] - pub fn data(&self) -> &[u8] { - &self.buffer[..self.meta.size] + pub fn data(&self, index: I) -> Option<&>::Output> + where + I: SliceIndex<[u8]>, + { + self.buffer.get(..self.meta.size)?.get(index) } /// Returns a mutable reference to the entirety of the underlying buffer to @@ -88,10 +93,9 @@ impl Packet { pub fn deserialize_slice(&self, index: I) -> Result where T: serde::de::DeserializeOwned, - I: std::slice::SliceIndex<[u8], Output = [u8]>, + I: SliceIndex<[u8], Output = [u8]>, { - let data = self.data(); - let bytes = data.get(index).ok_or(bincode::ErrorKind::SizeLimit)?; + let bytes = self.data(index).ok_or(bincode::ErrorKind::SizeLimit)?; bincode::options() .with_limit(PACKET_DATA_SIZE as u64) .with_fixint_encoding() @@ -123,7 +127,7 @@ impl Default for Packet { impl PartialEq for Packet { fn eq(&self, other: &Packet) -> bool { - self.meta == other.meta && self.data() == other.data() + self.meta == other.meta && self.data(..) == other.data(..) } } diff --git a/streamer/src/nonblocking/sendmmsg.rs b/streamer/src/nonblocking/sendmmsg.rs index 299eb4fb56..8721937e25 100644 --- a/streamer/src/nonblocking/sendmmsg.rs +++ b/streamer/src/nonblocking/sendmmsg.rs @@ -138,9 +138,13 @@ mod tests { let packet = Packet::default(); - let sent = multi_target_send(&sender, packet.data(), &[&addr, &addr2, &addr3, &addr4]) - .await - .ok(); + let sent = multi_target_send( + &sender, + packet.data(..).unwrap(), + &[&addr, &addr2, &addr3, &addr4], + ) + .await + .ok(); assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; diff --git a/streamer/src/packet.rs b/streamer/src/packet.rs index 7077e2fdfe..ea8e346f8d 100644 --- a/streamer/src/packet.rs +++ b/streamer/src/packet.rs @@ -67,7 +67,9 @@ pub fn send_to( for p in batch.iter() { let addr = p.meta.socket_addr(); if socket_addr_space.check(&addr) { - socket.send_to(p.data(), &addr)?; + if let Some(data) = p.data(..) { + socket.send_to(data, &addr)?; + } } } Ok(()) diff --git a/streamer/src/sendmmsg.rs b/streamer/src/sendmmsg.rs index dbef5323c7..8147ac9e18 100644 --- a/streamer/src/sendmmsg.rs +++ b/streamer/src/sendmmsg.rs @@ -242,7 +242,12 @@ mod tests { let packet = Packet::default(); - let sent = multi_target_send(&sender, packet.data(), &[&addr, &addr2, &addr3, &addr4]).ok(); + let sent = multi_target_send( + &sender, + packet.data(..).unwrap(), + &[&addr, &addr2, &addr3, &addr4], + ) + .ok(); assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; diff --git a/streamer/src/streamer.rs b/streamer/src/streamer.rs index f7fe780cae..42a5fbbc4f 100644 --- a/streamer/src/streamer.rs +++ b/streamer/src/streamer.rs @@ -279,7 +279,7 @@ impl StreamerSendStats { fn record(&mut self, pkt: &Packet) { let ent = self.host_map.entry(pkt.meta.addr).or_default(); ent.count += 1; - ent.bytes += pkt.data().len() as u64; + ent.bytes += pkt.data(..).map(<[u8]>::len).unwrap_or_default() as u64; } } @@ -296,7 +296,8 @@ fn recv_send( } let packets = packet_batch.iter().filter_map(|pkt| { let addr = pkt.meta.socket_addr(); - socket_addr_space.check(&addr).then(|| (pkt.data(), addr)) + let data = pkt.data(..)?; + socket_addr_space.check(&addr).then(|| (data, addr)) }); batch_send(sock, &packets.collect::>())?; Ok(())