diff --git a/bench-streamer/src/main.rs b/bench-streamer/src/main.rs index f6079bc5c1..f03d29bf74 100644 --- a/bench-streamer/src/main.rs +++ b/bench-streamer/src/main.rs @@ -37,7 +37,8 @@ fn producer(addr: &SocketAddr, exit: Arc) -> JoinHandle<()> { for p in packet_batch.iter() { let a = p.meta.socket_addr(); assert!(p.meta.size <= PACKET_DATA_SIZE); - send.send_to(p.data(), &a).unwrap(); + let data = p.data(..).unwrap_or_default(); + send.send_to(data, &a).unwrap(); num += 1; } assert_eq!(num, 10); diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs index b3e90b0c47..071c665fb2 100644 --- a/core/src/banking_stage.rs +++ b/core/src/banking_stage.rs @@ -541,7 +541,7 @@ impl BankingStage { .iter() .filter_map(|p| { if !p.meta.forwarded() && data_budget.take(p.meta.size) { - Some(p.data().to_vec()) + Some(p.data(..)?.to_vec()) } else { None } diff --git a/core/src/packet_hasher.rs b/core/src/packet_hasher.rs index 8fcae1d73b..31280015c3 100644 --- a/core/src/packet_hasher.rs +++ b/core/src/packet_hasher.rs @@ -26,7 +26,7 @@ impl Default for PacketHasher { impl PacketHasher { pub(crate) fn hash_packet(&self, packet: &Packet) -> u64 { - self.hash_data(packet.data()) + self.hash_data(packet.data(..).unwrap_or_default()) } pub(crate) fn hash_shred(&self, shred: &Shred) -> u64 { diff --git a/core/src/serve_repair.rs b/core/src/serve_repair.rs index f9289c65c1..d17c1978e1 100644 --- a/core/src/serve_repair.rs +++ b/core/src/serve_repair.rs @@ -814,7 +814,7 @@ mod tests { .into_iter() .filter_map(|p| { assert_eq!(repair_response::nonce(p).unwrap(), nonce); - Shred::new_from_serialized_shred(p.data().to_vec()).ok() + Shred::new_from_serialized_shred(p.data(..).unwrap().to_vec()).ok() }) .collect(); assert!(!rv.is_empty()); @@ -898,7 +898,7 @@ mod tests { .into_iter() .filter_map(|p| { assert_eq!(repair_response::nonce(p).unwrap(), nonce); - Shred::new_from_serialized_shred(p.data().to_vec()).ok() + Shred::new_from_serialized_shred(p.data(..).unwrap().to_vec()).ok() }) .collect(); assert_eq!(rv[0].index(), 1); @@ -1347,7 +1347,7 @@ mod tests { fn verify_responses<'a>(request: &ShredRepairType, packets: impl Iterator) { for packet in packets { - let shred_payload = packet.data().to_vec(); + let shred_payload = packet.data(..).unwrap().to_vec(); let shred = Shred::new_from_serialized_shred(shred_payload).unwrap(); request.verify_response(&shred); } diff --git a/core/src/unprocessed_packet_batches.rs b/core/src/unprocessed_packet_batches.rs index 277ba3adf7..bfc8852111 100644 --- a/core/src/unprocessed_packet_batches.rs +++ b/core/src/unprocessed_packet_batches.rs @@ -368,12 +368,14 @@ pub fn deserialize_packets<'a>( /// Read the transaction message from packet data pub fn packet_message(packet: &Packet) -> Result<&[u8], DeserializedPacketError> { - let (sig_len, sig_size) = - decode_shortu16_len(packet.data()).map_err(DeserializedPacketError::ShortVecError)?; + let (sig_len, sig_size) = packet + .data(..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(DeserializedPacketError::ShortVecError(()))?; sig_len .checked_mul(size_of::()) .and_then(|v| v.checked_add(sig_size)) - .and_then(|msg_start| packet.data().get(msg_start..)) + .and_then(|msg_start| packet.data(msg_start..)) .ok_or(DeserializedPacketError::SignatureOverflowed(sig_size)) } diff --git a/core/src/window_service.rs b/core/src/window_service.rs index 0bb9d50381..7e582fa4ea 100644 --- a/core/src/window_service.rs +++ b/core/src/window_service.rs @@ -363,7 +363,7 @@ where inc_new_counter_debug!("streamer-recv_window-invalid_or_unnecessary_packet", 1); return None; } - let serialized_shred = packet.data().to_vec(); + let serialized_shred = packet.data(..)?.to_vec(); let shred = Shred::new_from_serialized_shred(serialized_shred).ok()?; if !shred_filter(&shred, working_bank.clone(), last_root) { return None; diff --git a/gossip/tests/gossip.rs b/gossip/tests/gossip.rs index b57a03cb94..f3e136cdba 100644 --- a/gossip/tests/gossip.rs +++ b/gossip/tests/gossip.rs @@ -260,7 +260,7 @@ pub fn cluster_info_retransmit() { let retransmit_peers: Vec<_> = peers.iter().collect(); retransmit_to( &retransmit_peers, - p.data(), + p.data(..).unwrap(), &tn1, false, &SocketAddrSpace::Unspecified, diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs index f7b80f6157..5462797a05 100644 --- a/ledger/src/shred.rs +++ b/ledger/src/shred.rs @@ -509,7 +509,7 @@ pub mod layout { use {super::*, std::ops::Range}; fn get_shred_size(packet: &Packet) -> Option { - let size = packet.data().len(); + let size = packet.data(..)?.len(); if packet.meta.repair() { size.checked_sub(SIZE_OF_NONCE) } else { @@ -519,7 +519,7 @@ pub mod layout { pub fn get_shred(packet: &Packet) -> Option<&[u8]> { let size = get_shred_size(packet)?; - let shred = packet.data().get(..size)?; + let shred = packet.data(..size)?; // Should at least have a signature. (size >= SIZE_OF_SIGNATURE).then(|| shred) } @@ -826,7 +826,7 @@ mod tests { let shred = Shred::new_from_data(10, 0, 1000, &[1, 2, 3], ShredFlags::empty(), 0, 1, 0); let mut packet = Packet::default(); shred.copy_to_packet(&mut packet); - let shred_res = Shred::new_from_serialized_shred(packet.data().to_vec()); + let shred_res = Shred::new_from_serialized_shred(packet.data(..).unwrap().to_vec()); assert_matches!( shred.parent(), Err(Error::InvalidParentOffset { @@ -1029,9 +1029,12 @@ mod tests { assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - layout::get_reference_tick(packet.data()).unwrap() + layout::get_reference_tick(packet.data(..).unwrap()).unwrap() + ); + assert_eq!( + layout::get_slot(packet.data(..).unwrap()), + Some(shred.slot()) ); - assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) @@ -1070,9 +1073,12 @@ mod tests { assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - layout::get_reference_tick(packet.data()).unwrap() + layout::get_reference_tick(packet.data(..).unwrap()).unwrap() + ); + assert_eq!( + layout::get_slot(packet.data(..).unwrap()), + Some(shred.slot()) ); - assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) @@ -1116,7 +1122,10 @@ mod tests { packet.meta.size = payload.len(); assert_eq!(shred.bytes_to_store(), payload); assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); - assert_eq!(layout::get_slot(packet.data()), Some(shred.slot())); + assert_eq!( + layout::get_slot(packet.data(..).unwrap()), + Some(shred.slot()) + ); assert_eq!( get_shred_slot_index_type(&packet, &mut ShredFetchStats::default()), Some((shred.slot(), shred.index(), shred.shred_type())) diff --git a/ledger/src/sigverify_shreds.rs b/ledger/src/sigverify_shreds.rs index 163e733442..aecad26aa6 100644 --- a/ledger/src/sigverify_shreds.rs +++ b/ledger/src/sigverify_shreds.rs @@ -281,7 +281,7 @@ fn sign_shred_cpu(keypair: &Keypair, packet: &mut Packet) { packet.meta.size >= sig.end, "packet is not large enough for a signature" ); - let signature = keypair.sign_message(&packet.data()[msg]); + let signature = keypair.sign_message(packet.data(msg).unwrap()); trace!("signature {:?}", signature); packet.buffer_mut()[sig].copy_from_slice(signature.as_ref()); } diff --git a/perf/src/sigverify.rs b/perf/src/sigverify.rs index 1691b04264..6a7be319cb 100644 --- a/perf/src/sigverify.rs +++ b/perf/src/sigverify.rs @@ -12,7 +12,7 @@ use { }, ahash::AHasher, rand::{thread_rng, Rng}, - rayon::ThreadPool, + rayon::{prelude::*, ThreadPool}, solana_metrics::inc_new_counter_debug, solana_rayon_threadlimit::get_thread_count, solana_sdk::{ @@ -114,10 +114,13 @@ pub fn init() { } } -fn verify_packet(packet: &mut Packet, reject_non_vote: bool) { +/// Returns true if the signatrue on the packet verifies. +/// Caller must do packet.set_discard(true) if this returns false. +#[must_use] +fn verify_packet(packet: &mut Packet, reject_non_vote: bool) -> bool { // If this packet was already marked as discard, drop it if packet.meta.discard() { - return; + return false; } let packet_offsets = get_packet_offsets(packet, 0, reject_non_vote); @@ -126,36 +129,38 @@ fn verify_packet(packet: &mut Packet, reject_non_vote: bool) { let msg_start = packet_offsets.msg_start as usize; if packet_offsets.sig_len == 0 { - packet.meta.set_discard(true); - return; + return false; } if packet.meta.size <= msg_start { - packet.meta.set_discard(true); - return; + return false; } - let msg_end = packet.meta.size; for _ in 0..packet_offsets.sig_len { let pubkey_end = pubkey_start.saturating_add(size_of::()); - let sig_end = sig_start.saturating_add(size_of::()); - - // get_packet_offsets should ensure pubkey_end and sig_end do - // not overflow packet.meta.size - - let signature = Signature::new(&packet.data()[sig_start..sig_end]); - - if !signature.verify( - &packet.data()[pubkey_start..pubkey_end], - &packet.data()[msg_start..msg_end], - ) { - packet.meta.set_discard(true); - return; + let sig_end = match sig_start.checked_add(size_of::()) { + Some(sig_end) => sig_end, + None => return false, + }; + let signature = match packet.data(sig_start..sig_end) { + Some(signature) => Signature::new(signature), + None => return false, + }; + let pubkey = match packet.data(pubkey_start..pubkey_end) { + Some(pubkey) => pubkey, + None => return false, + }; + let message = match packet.data(msg_start..) { + Some(message) => message, + None => return false, + }; + if !signature.verify(pubkey, message) { + return false; } - pubkey_start = pubkey_end; sig_start = sig_end; } + true } pub fn count_packets_in_batches(batches: &[PacketBatch]) -> usize { @@ -202,9 +207,10 @@ fn do_get_packet_offsets( .ok_or(PacketError::InvalidLen)?; // read the length of Transaction.signatures (serialized with short_vec) - let (sig_len_untrusted, sig_size) = - decode_shortu16_len(packet.data()).map_err(|_| PacketError::InvalidShortVec)?; - + let (sig_len_untrusted, sig_size) = packet + .data(..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(PacketError::InvalidShortVec)?; // Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty. // Ultimately, the actual sigverify will determine the uncertainty. let msg_start_offset = sig_len_untrusted @@ -222,7 +228,9 @@ fn do_get_packet_offsets( // next byte indicates if the transaction is versioned. If the top bit // is set, the remaining bits encode a version number. If the top bit is // not set, this byte is the first byte of the message header. - let message_prefix = packet.data()[msg_start_offset]; + let message_prefix = *packet + .data(msg_start_offset) + .ok_or(PacketError::InvalidSignatureLen)?; if message_prefix & MESSAGE_VERSION_PREFIX != 0 { let version = message_prefix & !MESSAGE_VERSION_PREFIX; match version { @@ -252,8 +260,9 @@ fn do_get_packet_offsets( .ok_or(PacketError::InvalidSignatureLen)?; // read MessageHeader.num_required_signatures (serialized with u8) - let sig_len_maybe_trusted = packet.data()[msg_header_offset]; - + let sig_len_maybe_trusted = *packet + .data(msg_header_offset) + .ok_or(PacketError::InvalidSignatureLen)?; let message_account_keys_len_offset = msg_header_offset .checked_add(MESSAGE_HEADER_LENGTH) .ok_or(PacketError::InvalidSignatureLen)?; @@ -263,7 +272,11 @@ fn do_get_packet_offsets( // num_readonly_signed_accounts, the first account is not debitable, and cannot be charged // required transaction fees. let readonly_signer_offset = msg_header_offset_plus_one; - if sig_len_maybe_trusted <= packet.data()[readonly_signer_offset] { + if sig_len_maybe_trusted + <= *packet + .data(readonly_signer_offset) + .ok_or(PacketError::InvalidSignatureLen)? + { return Err(PacketError::PayerNotWritable); } @@ -272,10 +285,10 @@ fn do_get_packet_offsets( } // read the length of Message.account_keys (serialized with short_vec) - let (pubkey_len, pubkey_len_size) = - decode_shortu16_len(&packet.data()[message_account_keys_len_offset..]) - .map_err(|_| PacketError::InvalidShortVec)?; - + let (pubkey_len, pubkey_len_size) = packet + .data(message_account_keys_len_offset..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(PacketError::InvalidShortVec)?; let pubkey_start = message_account_keys_len_offset .checked_add(pubkey_len_size) .ok_or(PacketError::InvalidPubkeyLen)?; @@ -311,19 +324,17 @@ fn do_get_packet_offsets( pub fn check_for_tracer_packet(packet: &mut Packet) -> bool { let first_pubkey_start: usize = TRACER_KEY_OFFSET_IN_TRANSACTION; - let maybe_first_pubkey_end = first_pubkey_start - .checked_add(size_of::()) - .filter(|v| v <= &packet.meta.size); + let first_pubkey_end = match first_pubkey_start.checked_add(size_of::()) { + Some(offset) => offset, + None => return false, + }; // Check for tracer pubkey - if let Some(first_pubkey_end) = maybe_first_pubkey_end { - let is_tracer_packet = - &packet.data()[first_pubkey_start..first_pubkey_end] == TRACER_KEY.as_ref(); - if is_tracer_packet { + match packet.data(first_pubkey_start..first_pubkey_end) { + Some(pubkey) if pubkey == TRACER_KEY.as_ref() => { packet.meta.flags |= PacketFlags::TRACER_PACKET; + true } - is_tracer_packet - } else { - false + _ => false, } } @@ -370,10 +381,10 @@ fn check_for_simple_vote_transaction( .filter(|v| *v <= packet.meta.size) .ok_or(PacketError::InvalidLen)?; - let (instruction_len, instruction_len_size) = - decode_shortu16_len(&packet.data()[instructions_len_offset..]) - .map_err(|_| PacketError::InvalidLen)?; - + let (instruction_len, instruction_len_size) = packet + .data(instructions_len_offset..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(PacketError::InvalidLen)?; // skip if has more than 1 instruction if instruction_len != 1 { return Err(PacketError::InvalidProgramLen); @@ -389,7 +400,11 @@ fn check_for_simple_vote_transaction( .filter(|v| *v <= packet.meta.size) .ok_or(PacketError::InvalidLen)?; - let instruction_program_id_index: usize = usize::from(packet.data()[instruction_start]); + let instruction_program_id_index: usize = usize::from( + *packet + .data(instruction_start) + .ok_or(PacketError::InvalidLen)?, + ); if instruction_program_id_index >= packet_offsets.pubkey_len as usize { return Err(PacketError::InvalidProgramIdIndex); @@ -403,7 +418,9 @@ fn check_for_simple_vote_transaction( .checked_add(size_of::()) .ok_or(PacketError::InvalidLen)?; - if &packet.data()[instruction_program_id_start..instruction_program_id_end] + if packet + .data(instruction_program_id_start..instruction_program_id_end) + .ok_or(PacketError::InvalidLen)? == solana_sdk::vote::program::id().as_ref() { packet.meta.flags |= PacketFlags::SIMPLE_VOTE_TX; @@ -507,7 +524,7 @@ impl Deduper { /// Compute hash from packet data, returns (hash, bin_pos). fn compute_hash(&self, packet: &Packet) -> (u64, usize) { let mut hasher = AHasher::new_with_keys(self.seed.0, self.seed.1); - hasher.write(packet.data()); + hasher.write(packet.data(..).unwrap_or_default()); let h = hasher.finish(); let len = self.filter.len(); let pos = (usize::try_from(h).unwrap()).wrapping_rem(len); @@ -590,20 +607,20 @@ pub fn shrink_batches(batches: &mut Vec) { } pub fn ed25519_verify_cpu(batches: &mut [PacketBatch], reject_non_vote: bool, packet_count: usize) { - use rayon::prelude::*; debug!("CPU ECDSA for {}", packet_count); PAR_THREAD_POOL.install(|| { batches.into_par_iter().for_each(|batch| { - batch - .par_iter_mut() - .for_each(|p| verify_packet(p, reject_non_vote)) + batch.par_iter_mut().for_each(|packet| { + if !packet.meta.discard() && !verify_packet(packet, reject_non_vote) { + packet.meta.set_discard(true); + } + }) }); }); inc_new_counter_debug!("ed25519_verify_cpu", packet_count); } pub fn ed25519_verify_disabled(batches: &mut [PacketBatch]) { - use rayon::prelude::*; let packet_count = count_packets_in_batches(batches); debug!("disabled ECDSA for {}", packet_count); batches @@ -759,17 +776,20 @@ mod tests { use { super::*, crate::{ - packet::{to_packet_batches, Packet, PacketBatch, PACKETS_PER_BATCH}, + packet::{to_packet_batches, Packet, PacketBatch, PACKETS_PER_BATCH, PACKET_DATA_SIZE}, sigverify::{self, PacketOffsets}, test_tx::{new_test_vote_tx, test_multisig_tx, test_tx}, }, bincode::{deserialize, serialize}, + curve25519_dalek::{edwards::CompressedEdwardsY, scalar::Scalar}, + rand::{thread_rng, Rng}, solana_sdk::{ instruction::CompiledInstruction, message::{Message, MessageHeader}, - signature::{Keypair, Signature}, + signature::{Keypair, Signature, Signer}, transaction::Transaction, }, + std::sync::atomic::{AtomicU64, Ordering}, }; const SIG_OFFSET: usize = 1; @@ -893,8 +913,7 @@ mod tests { let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::InvalidPubkeyLen)); - verify_packet(&mut packet, false); - assert!(packet.meta.discard()); + assert!(!verify_packet(&mut packet, false)); packet.meta.set_discard(false); let mut batches = generate_packet_batches(&packet, 1, 1); @@ -906,7 +925,6 @@ mod tests { fn test_pubkey_len() { // See that the verify cannot walk off the end of the packet // trying to index into the account_keys to access pubkey. - use solana_sdk::signer::{keypair::Keypair, Signer}; solana_logger::setup(); const NUM_SIG: usize = 17; @@ -929,8 +947,7 @@ mod tests { let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::InvalidPubkeyLen)); - verify_packet(&mut packet, false); - assert!(packet.meta.discard()); + assert!(!verify_packet(&mut packet, false)); packet.meta.set_discard(false); let mut batches = generate_packet_batches(&packet, 1, 1); @@ -1022,7 +1039,7 @@ mod tests { // set message version to 0 let msg_start = legacy_offsets.msg_start as usize; - let msg_bytes = packet.data()[msg_start..].to_vec(); + let msg_bytes = packet.data(msg_start..).unwrap().to_vec(); packet.buffer_mut()[msg_start] = MESSAGE_VERSION_PREFIX; packet.meta.size += 1; let msg_end = packet.meta.size; @@ -1039,7 +1056,6 @@ mod tests { #[test] fn test_system_transaction_data_layout() { - use crate::packet::PACKET_DATA_SIZE; let mut tx0 = test_tx(); tx0.message.instructions[0].data = vec![1, 2, 3]; let message0a = tx0.message_data(); @@ -1145,7 +1161,7 @@ mod tests { // jumble some data to test failure if modify_data { - packet.buffer_mut()[20] = packet.data()[20].wrapping_add(10); + packet.buffer_mut()[20] = packet.data(20).unwrap().wrapping_add(10); } let mut batches = generate_packet_batches(&packet, n, 2); @@ -1211,7 +1227,7 @@ mod tests { let num_batches = 3; let mut batches = generate_packet_batches(&packet, n, num_batches); - packet.buffer_mut()[40] = packet.data()[40].wrapping_add(8); + packet.buffer_mut()[40] = packet.data(40).unwrap().wrapping_add(8); batches[0].push(packet); @@ -1237,7 +1253,6 @@ mod tests { #[test] fn test_verify_fuzz() { - use rand::{thread_rng, Rng}; solana_logger::setup(); let tx = test_multisig_tx(); @@ -1255,8 +1270,10 @@ mod tests { let packet = thread_rng().gen_range(0, batches[batch].len()); let offset = thread_rng().gen_range(0, batches[batch][packet].meta.size); let add = thread_rng().gen_range(0, 255); - batches[batch][packet].buffer_mut()[offset] = - batches[batch][packet].data()[offset].wrapping_add(add); + batches[batch][packet].buffer_mut()[offset] = batches[batch][packet] + .data(offset) + .unwrap() + .wrapping_add(add); } let batch_to_disable = thread_rng().gen_range(0, batches.len()); @@ -1288,13 +1305,6 @@ mod tests { #[test] fn test_get_checked_scalar() { solana_logger::setup(); - use { - curve25519_dalek::scalar::Scalar, - rand::{thread_rng, Rng}, - rayon::prelude::*, - std::sync::atomic::{AtomicU64, Ordering}, - }; - if perf_libs::api().is_none() { return; } @@ -1330,13 +1340,6 @@ mod tests { #[test] fn test_ge_small_order() { solana_logger::setup(); - use { - curve25519_dalek::edwards::CompressedEdwardsY, - rand::{thread_rng, Rng}, - rayon::prelude::*, - std::sync::atomic::{AtomicU64, Ordering}, - }; - if perf_libs::api().is_none() { return; } @@ -1530,7 +1533,7 @@ mod tests { .filter(|p| !p.meta.discard()) .for_each(|p| start.push(p.clone())) }); - start.sort_by(|a, b| a.data().cmp(b.data())); + start.sort_by(|a, b| a.data(..).cmp(&b.data(..))); let packet_count = count_valid_packets(&batches, |_| ()); shrink_batches(&mut batches); @@ -1542,7 +1545,7 @@ mod tests { .filter(|p| !p.meta.discard()) .for_each(|p| end.push(p.clone())) }); - end.sort_by(|a, b| a.data().cmp(b.data())); + end.sort_by(|a, b| a.data(..).cmp(&b.data(..))); let packet_count2 = count_valid_packets(&batches, |_| ()); assert_eq!(packet_count, packet_count2); assert_eq!(start, end); diff --git a/sdk/src/packet.rs b/sdk/src/packet.rs index eae3aa83ea..2cca65d6a5 100644 --- a/sdk/src/packet.rs +++ b/sdk/src/packet.rs @@ -5,6 +5,7 @@ use { std::{ fmt, io, net::{IpAddr, Ipv4Addr, SocketAddr}, + slice::SliceIndex, }, }; @@ -39,7 +40,7 @@ pub struct Meta { #[repr(C)] pub struct Packet { // Bytes past Packet.meta.size are not valid to read from. - // Use Packet.data() to read from the buffer. + // Use Packet.data(index) to read from the buffer. buffer: [u8; PACKET_DATA_SIZE], pub meta: Meta, } @@ -50,10 +51,14 @@ impl Packet { } /// Returns an immutable reference to the underlying buffer up to - /// Packet.meta.size. The rest of the buffer is not valid to read from. + /// packet.meta.size. The rest of the buffer is not valid to read from. + /// packet.data(..) returns packet.buffer.get(..packet.meta.size). #[inline] - pub fn data(&self) -> &[u8] { - &self.buffer[..self.meta.size] + pub fn data(&self, index: I) -> Option<&>::Output> + where + I: SliceIndex<[u8]>, + { + self.buffer.get(..self.meta.size)?.get(index) } /// Returns a mutable reference to the entirety of the underlying buffer to @@ -88,10 +93,9 @@ impl Packet { pub fn deserialize_slice(&self, index: I) -> Result where T: serde::de::DeserializeOwned, - I: std::slice::SliceIndex<[u8], Output = [u8]>, + I: SliceIndex<[u8], Output = [u8]>, { - let data = self.data(); - let bytes = data.get(index).ok_or(bincode::ErrorKind::SizeLimit)?; + let bytes = self.data(index).ok_or(bincode::ErrorKind::SizeLimit)?; bincode::options() .with_limit(PACKET_DATA_SIZE as u64) .with_fixint_encoding() @@ -123,7 +127,7 @@ impl Default for Packet { impl PartialEq for Packet { fn eq(&self, other: &Packet) -> bool { - self.meta == other.meta && self.data() == other.data() + self.meta == other.meta && self.data(..) == other.data(..) } } diff --git a/streamer/src/nonblocking/sendmmsg.rs b/streamer/src/nonblocking/sendmmsg.rs index 299eb4fb56..8721937e25 100644 --- a/streamer/src/nonblocking/sendmmsg.rs +++ b/streamer/src/nonblocking/sendmmsg.rs @@ -138,9 +138,13 @@ mod tests { let packet = Packet::default(); - let sent = multi_target_send(&sender, packet.data(), &[&addr, &addr2, &addr3, &addr4]) - .await - .ok(); + let sent = multi_target_send( + &sender, + packet.data(..).unwrap(), + &[&addr, &addr2, &addr3, &addr4], + ) + .await + .ok(); assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; diff --git a/streamer/src/packet.rs b/streamer/src/packet.rs index 7077e2fdfe..ea8e346f8d 100644 --- a/streamer/src/packet.rs +++ b/streamer/src/packet.rs @@ -67,7 +67,9 @@ pub fn send_to( for p in batch.iter() { let addr = p.meta.socket_addr(); if socket_addr_space.check(&addr) { - socket.send_to(p.data(), &addr)?; + if let Some(data) = p.data(..) { + socket.send_to(data, &addr)?; + } } } Ok(()) diff --git a/streamer/src/sendmmsg.rs b/streamer/src/sendmmsg.rs index dbef5323c7..8147ac9e18 100644 --- a/streamer/src/sendmmsg.rs +++ b/streamer/src/sendmmsg.rs @@ -242,7 +242,12 @@ mod tests { let packet = Packet::default(); - let sent = multi_target_send(&sender, packet.data(), &[&addr, &addr2, &addr3, &addr4]).ok(); + let sent = multi_target_send( + &sender, + packet.data(..).unwrap(), + &[&addr, &addr2, &addr3, &addr4], + ) + .ok(); assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; diff --git a/streamer/src/streamer.rs b/streamer/src/streamer.rs index f7fe780cae..42a5fbbc4f 100644 --- a/streamer/src/streamer.rs +++ b/streamer/src/streamer.rs @@ -279,7 +279,7 @@ impl StreamerSendStats { fn record(&mut self, pkt: &Packet) { let ent = self.host_map.entry(pkt.meta.addr).or_default(); ent.count += 1; - ent.bytes += pkt.data().len() as u64; + ent.bytes += pkt.data(..).map(<[u8]>::len).unwrap_or_default() as u64; } } @@ -296,7 +296,8 @@ fn recv_send( } let packets = packet_batch.iter().filter_map(|pkt| { let addr = pkt.meta.socket_addr(); - socket_addr_space.check(&addr).then(|| (pkt.data(), addr)) + let data = pkt.data(..)?; + socket_addr_space.check(&addr).then(|| (data, addr)) }); batch_send(sock, &packets.collect::>())?; Ok(())