From 880684565c1c7f92fce2415de0daed4a8194dcfe Mon Sep 17 00:00:00 2001
From: behzad nouri
Date: Wed, 25 May 2022 16:52:54 +0000
Subject: [PATCH] limits read access into Packet data to Packet.meta.size
 (#25484)

Bytes past Packet.meta.size are not valid to read from.

The commit makes the buffer field private and instead provides two
methods:
* Packet::data() which returns an immutable reference to the underlying
  buffer up to Packet.meta.size. The rest of the buffer is not valid to
  read from.
* Packet::buffer_mut() which returns a mutable reference to the entirety
  of the underlying buffer to write into. The caller is responsible for
  updating Packet.meta.size after writing to the buffer.
---
 bench-streamer/src/main.rs             |  2 +-
 core/src/banking_stage.rs              |  2 +-
 core/src/packet_hasher.rs              |  3 +-
 core/src/repair_response.rs            |  9 ++--
 core/src/serve_repair.rs               |  6 +--
 core/src/sigverify_shreds.rs           |  8 ++--
 core/src/unprocessed_packet_batches.rs |  7 +--
 core/src/window_service.rs             |  8 +---
 gossip/tests/gossip.rs                 |  4 +-
 ledger/src/shred.rs                    | 18 ++++----
 ledger/src/sigverify_shreds.rs         | 26 ++++-------
 perf/src/sigverify.rs                  | 63 +++++++++++++-------------
 sdk/src/packet.rs                      | 33 ++++++++++----
 streamer/src/nonblocking/recvmmsg.rs   |  2 +-
 streamer/src/nonblocking/sendmmsg.rs   | 10 ++--
 streamer/src/packet.rs                 |  8 ++--
 streamer/src/quic.rs                   |  2 +-
 streamer/src/recvmmsg.rs               |  7 +--
 streamer/src/sendmmsg.rs               |  7 +--
 streamer/src/streamer.rs               |  8 ++--
 20 files changed, 112 insertions(+), 121 deletions(-)

diff --git a/bench-streamer/src/main.rs b/bench-streamer/src/main.rs
index 81d866360..f6079bc5c 100644
--- a/bench-streamer/src/main.rs
+++ b/bench-streamer/src/main.rs
@@ -37,7 +37,7 @@ fn producer(addr: &SocketAddr, exit: Arc<AtomicBool>) -> JoinHandle<()> {
         for p in packet_batch.iter() {
             let a = p.meta.socket_addr();
             assert!(p.meta.size <= PACKET_DATA_SIZE);
-            send.send_to(&p.data[..p.meta.size], &a).unwrap();
+            send.send_to(p.data(), &a).unwrap();
             num += 1;
         }
         assert_eq!(num, 10);
diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs
index 49510019e..0f0096312 100644
--- a/core/src/banking_stage.rs
+++ b/core/src/banking_stage.rs
@@ -508,7 +508,7 @@ impl BankingStage {
             .iter()
             .filter_map(|p| {
                 if !p.meta.forwarded() && data_budget.take(p.meta.size) {
-                    Some(p.data[..p.meta.size].to_vec())
+                    Some(p.data().to_vec())
                 } else {
                     None
                 }
diff --git a/core/src/packet_hasher.rs b/core/src/packet_hasher.rs
index 51746445b..8fcae1d73 100644
--- a/core/src/packet_hasher.rs
+++ b/core/src/packet_hasher.rs
@@ -26,8 +26,7 @@ impl Default for PacketHasher {
 
 impl PacketHasher {
     pub(crate) fn hash_packet(&self, packet: &Packet) -> u64 {
-        let size = packet.data.len().min(packet.meta.size);
-        self.hash_data(&packet.data[..size])
+        self.hash_data(packet.data())
     }
 
     pub(crate) fn hash_shred(&self, shred: &Shred) -> u64 {
diff --git a/core/src/repair_response.rs b/core/src/repair_response.rs
index e8f8347f3..5d36831fb 100644
--- a/core/src/repair_response.rs
+++ b/core/src/repair_response.rs
@@ -28,13 +28,14 @@ pub fn repair_response_packet_from_bytes(
     nonce: Nonce,
 ) -> Option<Packet> {
     let mut packet = Packet::default();
-    packet.meta.size = bytes.len() + SIZE_OF_NONCE;
-    if packet.meta.size > packet.data.len() {
+    let size = bytes.len() + SIZE_OF_NONCE;
+    if size > packet.buffer_mut().len() {
         return None;
     }
+    packet.meta.size = size;
    packet.meta.set_socket_addr(dest);
-    packet.data[..bytes.len()].copy_from_slice(&bytes);
-    let mut wr = io::Cursor::new(&mut packet.data[bytes.len()..]);
+    packet.buffer_mut()[..bytes.len()].copy_from_slice(&bytes);
+
let mut wr = io::Cursor::new(&mut packet.buffer_mut()[bytes.len()..]); bincode::serialize_into(&mut wr, &nonce).expect("Buffer not large enough to fit nonce"); Some(packet) } diff --git a/core/src/serve_repair.rs b/core/src/serve_repair.rs index b51d47e70..f9289c65c 100644 --- a/core/src/serve_repair.rs +++ b/core/src/serve_repair.rs @@ -814,7 +814,7 @@ mod tests { .into_iter() .filter_map(|p| { assert_eq!(repair_response::nonce(p).unwrap(), nonce); - Shred::new_from_serialized_shred(p.data.to_vec()).ok() + Shred::new_from_serialized_shred(p.data().to_vec()).ok() }) .collect(); assert!(!rv.is_empty()); @@ -898,7 +898,7 @@ mod tests { .into_iter() .filter_map(|p| { assert_eq!(repair_response::nonce(p).unwrap(), nonce); - Shred::new_from_serialized_shred(p.data.to_vec()).ok() + Shred::new_from_serialized_shred(p.data().to_vec()).ok() }) .collect(); assert_eq!(rv[0].index(), 1); @@ -1347,7 +1347,7 @@ mod tests { fn verify_responses<'a>(request: &ShredRepairType, packets: impl Iterator) { for packet in packets { - let shred_payload = packet.data.to_vec(); + let shred_payload = packet.data().to_vec(); let shred = Shred::new_from_serialized_shred(shred_payload).unwrap(); request.verify_response(&shred); } diff --git a/core/src/sigverify_shreds.rs b/core/src/sigverify_shreds.rs index 5c6f503cd..261eec390 100644 --- a/core/src/sigverify_shreds.rs +++ b/core/src/sigverify_shreds.rs @@ -120,7 +120,7 @@ pub mod tests { let keypair = Keypair::new(); shred.sign(&keypair); - batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload()); + batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); batches[0][0].meta.size = shred.payload().len(); let mut shred = Shred::new_from_data( @@ -134,7 +134,7 @@ pub mod tests { 0xc0de, ); shred.sign(&keypair); - batches[1][0].data[0..shred.payload().len()].copy_from_slice(shred.payload()); + batches[1][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); batches[1][0].meta.size = shred.payload().len(); let expected: HashSet = [0xc0de_dead, 0xdead_c0de].iter().cloned().collect(); @@ -169,7 +169,7 @@ pub mod tests { 0xc0de, ); shred.sign(&leader_keypair); - batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload()); + batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); batches[0][0].meta.size = shred.payload().len(); let mut shred = Shred::new_from_data( @@ -184,7 +184,7 @@ pub mod tests { ); let wrong_keypair = Keypair::new(); shred.sign(&wrong_keypair); - batches[0][1].data[0..shred.payload().len()].copy_from_slice(shred.payload()); + batches[0][1].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); batches[0][1].meta.size = shred.payload().len(); let num_packets = solana_perf::sigverify::count_packets_in_batches(&batches); diff --git a/core/src/unprocessed_packet_batches.rs b/core/src/unprocessed_packet_batches.rs index 3d44f60cd..e255d2935 100644 --- a/core/src/unprocessed_packet_batches.rs +++ b/core/src/unprocessed_packet_batches.rs @@ -356,14 +356,11 @@ pub fn deserialize_packets<'a>( /// Read the transaction message from packet data pub fn packet_message(packet: &Packet) -> Result<&[u8], DeserializedPacketError> { let (sig_len, sig_size) = - decode_shortu16_len(&packet.data).map_err(DeserializedPacketError::ShortVecError)?; + decode_shortu16_len(packet.data()).map_err(DeserializedPacketError::ShortVecError)?; sig_len .checked_mul(size_of::()) .and_then(|v| v.checked_add(sig_size)) - .map(|msg_start| { - let msg_end = 
packet.meta.size; - &packet.data[msg_start..msg_end] - }) + .and_then(|msg_start| packet.data().get(msg_start..)) .ok_or(DeserializedPacketError::SignatureOverflowed(sig_size)) } diff --git a/core/src/window_service.rs b/core/src/window_service.rs index 4998daec4..0bb9d5038 100644 --- a/core/src/window_service.rs +++ b/core/src/window_service.rs @@ -23,7 +23,7 @@ use { solana_perf::packet::{Packet, PacketBatch}, solana_rayon_threadlimit::get_thread_count, solana_runtime::{bank::Bank, bank_forks::BankForks}, - solana_sdk::{clock::Slot, packet::PACKET_DATA_SIZE, pubkey::Pubkey}, + solana_sdk::{clock::Slot, pubkey::Pubkey}, std::{ cmp::Reverse, collections::{HashMap, HashSet}, @@ -363,11 +363,7 @@ where inc_new_counter_debug!("streamer-recv_window-invalid_or_unnecessary_packet", 1); return None; } - // shred fetch stage should be sending packets - // with sufficiently large buffers. Needed to ensure - // call to `new_from_serialized_shred` is safe. - assert_eq!(packet.data.len(), PACKET_DATA_SIZE); - let serialized_shred = packet.data.to_vec(); + let serialized_shred = packet.data().to_vec(); let shred = Shred::new_from_serialized_shred(serialized_shred).ok()?; if !shred_filter(&shred, working_bank.clone(), last_root) { return None; diff --git a/gossip/tests/gossip.rs b/gossip/tests/gossip.rs index 00044f154..b57a03cb9 100644 --- a/gossip/tests/gossip.rs +++ b/gossip/tests/gossip.rs @@ -260,7 +260,7 @@ pub fn cluster_info_retransmit() { let retransmit_peers: Vec<_> = peers.iter().collect(); retransmit_to( &retransmit_peers, - &p.data[..p.meta.size], + p.data(), &tn1, false, &SocketAddrSpace::Unspecified, @@ -270,7 +270,7 @@ pub fn cluster_info_retransmit() { .map(|s| { let mut p = Packet::default(); s.set_read_timeout(Some(Duration::new(1, 0))).unwrap(); - let res = s.recv_from(&mut p.data); + let res = s.recv_from(p.buffer_mut()); res.is_err() //true if failed to receive the retransmit packet }) .collect(); diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs index d65defdcd..72d52a337 100644 --- a/ledger/src/shred.rs +++ b/ledger/src/shred.rs @@ -287,7 +287,7 @@ impl Shred { pub fn copy_to_packet(&self, packet: &mut Packet) { let payload = self.payload(); let size = payload.len(); - packet.data[..size].copy_from_slice(&payload[..]); + packet.buffer_mut()[..size].copy_from_slice(&payload[..]); packet.meta.size = size; } @@ -575,7 +575,7 @@ pub fn get_shred_slot_index_type( } }; - let shred_type = match ShredType::try_from(p.data[OFFSET_OF_SHRED_TYPE]) { + let shred_type = match ShredType::try_from(p.data()[OFFSET_OF_SHRED_TYPE]) { Err(_) => { stats.bad_shred_type += 1; return None; @@ -733,7 +733,7 @@ mod tests { let shred = Shred::new_from_data(10, 0, 1000, &[1, 2, 3], ShredFlags::empty(), 0, 1, 0); let mut packet = Packet::default(); shred.copy_to_packet(&mut packet); - let shred_res = Shred::new_from_serialized_shred(packet.data.to_vec()); + let shred_res = Shred::new_from_serialized_shred(packet.data().to_vec()); assert_matches!( shred.parent(), Err(Error::InvalidParentOffset { @@ -825,7 +825,7 @@ mod tests { 200, // version ); shred.copy_to_packet(&mut packet); - packet.data[OFFSET_OF_SHRED_TYPE] = u8::MAX; + packet.buffer_mut()[OFFSET_OF_SHRED_TYPE] = u8::MAX; assert_eq!(None, get_shred_slot_index_type(&packet, &mut stats)); assert_eq!(1, stats.bad_shred_type); @@ -892,13 +892,13 @@ mod tests { data.iter().skip(skip).copied() }); let mut packet = Packet::default(); - packet.data[..payload.len()].copy_from_slice(&payload); + 
packet.buffer_mut()[..payload.len()].copy_from_slice(&payload); packet.meta.size = payload.len(); assert_eq!(shred.bytes_to_store(), payload); assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - Shred::reference_tick_from_data(&packet.data).unwrap() + Shred::reference_tick_from_data(packet.data()).unwrap() ); assert_eq!(Shred::get_slot_from_packet(&packet), Some(shred.slot())); assert_eq!( @@ -933,13 +933,13 @@ mod tests { assert_matches!(shred.sanitize(), Ok(())); let payload = bs58_decode(PAYLOAD); let mut packet = Packet::default(); - packet.data[..payload.len()].copy_from_slice(&payload); + packet.buffer_mut()[..payload.len()].copy_from_slice(&payload); packet.meta.size = payload.len(); assert_eq!(shred.bytes_to_store(), payload); assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); assert_eq!( shred.reference_tick(), - Shred::reference_tick_from_data(&packet.data).unwrap() + Shred::reference_tick_from_data(packet.data()).unwrap() ); assert_eq!(Shred::get_slot_from_packet(&packet), Some(shred.slot())); assert_eq!( @@ -981,7 +981,7 @@ mod tests { parity_shard.iter().skip(skip).copied() }); let mut packet = Packet::default(); - packet.data[..payload.len()].copy_from_slice(&payload); + packet.buffer_mut()[..payload.len()].copy_from_slice(&payload); packet.meta.size = payload.len(); assert_eq!(shred.bytes_to_store(), payload); assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap()); diff --git a/ledger/src/sigverify_shreds.rs b/ledger/src/sigverify_shreds.rs index 1b903920e..4f769680f 100644 --- a/ledger/src/sigverify_shreds.rs +++ b/ledger/src/sigverify_shreds.rs @@ -62,12 +62,9 @@ pub fn verify_shred_cpu(packet: &Packet, slot_leaders: &HashMap) }; trace!("slot {}", slot); let pubkey = slot_leaders.get(&slot)?; - if packet.meta.size < sig_end { - return Some(0); - } - let signature = Signature::new(&packet.data[sig_start..sig_end]); + let signature = Signature::new(packet.data().get(sig_start..sig_end)?); trace!("signature {}", signature); - if !signature.verify(pubkey, &packet.data[msg_start..msg_end]) { + if !signature.verify(pubkey, packet.data().get(msg_start..msg_end)?) 
{ return Some(0); } Some(1) @@ -307,14 +304,9 @@ fn sign_shred_cpu(keypair: &Keypair, packet: &mut Packet) { let sig_start = 0; let sig_end = sig_start + size_of::(); let msg_start = sig_end; - let msg_end = packet.meta.size; - assert!( - packet.meta.size >= msg_end, - "packet is not large enough for a signature" - ); - let signature = keypair.sign_message(&packet.data[msg_start..msg_end]); + let signature = keypair.sign_message(&packet.data()[msg_start..]); trace!("signature {:?}", signature); - packet.data[0..sig_end].copy_from_slice(signature.as_ref()); + packet.buffer_mut()[..sig_end].copy_from_slice(signature.as_ref()); } pub fn sign_shreds_cpu(keypair: &Keypair, batches: &mut [PacketBatch]) { @@ -443,7 +435,7 @@ pub fn sign_shreds_gpu( let sig_ix = packet_ix + num_packets; let sig_start = sig_ix * sig_size; let sig_end = sig_start + sig_size; - packet.data[0..sig_size] + packet.buffer_mut()[..sig_size] .copy_from_slice(&signatures_out[sig_start..sig_end]); }); }); @@ -476,7 +468,7 @@ pub mod tests { let keypair = Keypair::new(); shred.sign(&keypair); trace!("signature {}", shred.signature()); - packet.data[0..shred.payload().len()].copy_from_slice(shred.payload()); + packet.buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); packet.meta.size = shred.payload().len(); let leader_slots = [(slot, keypair.pubkey().to_bytes())] @@ -520,7 +512,7 @@ pub mod tests { let keypair = Keypair::new(); shred.sign(&keypair); batches[0].resize(1, Packet::default()); - batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload()); + batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); batches[0][0].meta.size = shred.payload().len(); let leader_slots = [(slot, keypair.pubkey().to_bytes())] @@ -574,7 +566,7 @@ pub mod tests { let keypair = Keypair::new(); shred.sign(&keypair); batches[0].resize(1, Packet::default()); - batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload()); + batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); batches[0][0].meta.size = shred.payload().len(); let leader_slots = [ @@ -685,7 +677,7 @@ pub mod tests { 0xc0de, ); batches[0].resize(1, Packet::default()); - batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload()); + batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload()); batches[0][0].meta.size = shred.payload().len(); let pubkeys = [ diff --git a/perf/src/sigverify.rs b/perf/src/sigverify.rs index 40546801c..f3badd8e3 100644 --- a/perf/src/sigverify.rs +++ b/perf/src/sigverify.rs @@ -142,11 +142,11 @@ fn verify_packet(packet: &mut Packet, reject_non_vote: bool) { // get_packet_offsets should ensure pubkey_end and sig_end do // not overflow packet.meta.size - let signature = Signature::new(&packet.data[sig_start..sig_end]); + let signature = Signature::new(&packet.data()[sig_start..sig_end]); if !signature.verify( - &packet.data[pubkey_start..pubkey_end], - &packet.data[msg_start..msg_end], + &packet.data()[pubkey_start..pubkey_end], + &packet.data()[msg_start..msg_end], ) { packet.meta.set_discard(true); return; @@ -154,7 +154,7 @@ fn verify_packet(packet: &mut Packet, reject_non_vote: bool) { // Check for tracer pubkey if !packet.meta.is_tracer_packet() - && &packet.data[pubkey_start..pubkey_end] == TRACER_KEY.as_ref() + && &packet.data()[pubkey_start..pubkey_end] == TRACER_KEY.as_ref() { packet.meta.flags |= PacketFlags::TRACER_PACKET; } @@ -202,7 +202,7 @@ fn do_get_packet_offsets( // read the length 
of Transaction.signatures (serialized with short_vec) let (sig_len_untrusted, sig_size) = - decode_shortu16_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?; + decode_shortu16_len(packet.data()).map_err(|_| PacketError::InvalidShortVec)?; // Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty. // Ultimately, the actual sigverify will determine the uncertainty. @@ -221,7 +221,7 @@ fn do_get_packet_offsets( // next byte indicates if the transaction is versioned. If the top bit // is set, the remaining bits encode a version number. If the top bit is // not set, this byte is the first byte of the message header. - let message_prefix = packet.data[msg_start_offset]; + let message_prefix = packet.data()[msg_start_offset]; if message_prefix & MESSAGE_VERSION_PREFIX != 0 { let version = message_prefix & !MESSAGE_VERSION_PREFIX; match version { @@ -251,7 +251,7 @@ fn do_get_packet_offsets( .ok_or(PacketError::InvalidSignatureLen)?; // read MessageHeader.num_required_signatures (serialized with u8) - let sig_len_maybe_trusted = packet.data[msg_header_offset]; + let sig_len_maybe_trusted = packet.data()[msg_header_offset]; let message_account_keys_len_offset = msg_header_offset .checked_add(MESSAGE_HEADER_LENGTH) @@ -262,7 +262,7 @@ fn do_get_packet_offsets( // num_readonly_signed_accounts, the first account is not debitable, and cannot be charged // required transaction fees. let readonly_signer_offset = msg_header_offset_plus_one; - if sig_len_maybe_trusted <= packet.data[readonly_signer_offset] { + if sig_len_maybe_trusted <= packet.data()[readonly_signer_offset] { return Err(PacketError::PayerNotWritable); } @@ -272,7 +272,7 @@ fn do_get_packet_offsets( // read the length of Message.account_keys (serialized with short_vec) let (pubkey_len, pubkey_len_size) = - decode_shortu16_len(&packet.data[message_account_keys_len_offset..]) + decode_shortu16_len(&packet.data()[message_account_keys_len_offset..]) .map_err(|_| PacketError::InvalidShortVec)?; let pubkey_start = message_account_keys_len_offset @@ -352,7 +352,7 @@ fn check_for_simple_vote_transaction( .ok_or(PacketError::InvalidLen)?; let (instruction_len, instruction_len_size) = - decode_shortu16_len(&packet.data[instructions_len_offset..]) + decode_shortu16_len(&packet.data()[instructions_len_offset..]) .map_err(|_| PacketError::InvalidLen)?; // skip if has more than 1 instruction @@ -370,7 +370,7 @@ fn check_for_simple_vote_transaction( .filter(|v| *v <= packet.meta.size) .ok_or(PacketError::InvalidLen)?; - let instruction_program_id_index: usize = usize::from(packet.data[instruction_start]); + let instruction_program_id_index: usize = usize::from(packet.data()[instruction_start]); if instruction_program_id_index >= packet_offsets.pubkey_len as usize { return Err(PacketError::InvalidProgramIdIndex); @@ -384,7 +384,7 @@ fn check_for_simple_vote_transaction( .checked_add(size_of::()) .ok_or(PacketError::InvalidLen)?; - if &packet.data[instruction_program_id_start..instruction_program_id_end] + if &packet.data()[instruction_program_id_start..instruction_program_id_end] == solana_sdk::vote::program::id().as_ref() { packet.meta.flags |= PacketFlags::SIMPLE_VOTE_TX; @@ -492,7 +492,7 @@ impl Deduper { return 1; } let mut hasher = AHasher::new_with_keys(self.seed.0, self.seed.1); - hasher.write(&packet.data[0..packet.meta.size]); + hasher.write(packet.data()); let hash = hasher.finish(); let len = self.filter.len(); let pos = (usize::try_from(hash).unwrap()).wrapping_rem(len); @@ -846,8 +846,8 @@ mod tests { 
let tx = test_tx(); let mut packet = Packet::from_data(None, tx).unwrap(); - packet.data[0] = 0xff; - packet.data[1] = 0xff; + packet.buffer_mut()[0] = 0xff; + packet.buffer_mut()[1] = 0xff; packet.meta.size = 2; let res = sigverify::do_get_packet_offsets(&packet, 0); @@ -919,7 +919,7 @@ mod tests { let mut packet = Packet::from_data(None, tx).unwrap(); // Make the signatures len huge - packet.data[0] = 0x7f; + packet.buffer_mut()[0] = 0x7f; let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::InvalidSignatureLen)); @@ -931,10 +931,10 @@ mod tests { let mut packet = Packet::from_data(None, tx).unwrap(); // Make the signatures len huge - packet.data[0] = 0xff; - packet.data[1] = 0xff; - packet.data[2] = 0xff; - packet.data[3] = 0xff; + packet.buffer_mut()[0] = 0xff; + packet.buffer_mut()[1] = 0xff; + packet.buffer_mut()[2] = 0xff; + packet.buffer_mut()[3] = 0xff; let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::InvalidShortVec)); @@ -948,7 +948,7 @@ mod tests { let res = sigverify::do_get_packet_offsets(&packet, 0); // make pubkey len huge - packet.data[res.unwrap().pubkey_start as usize - 1] = 0x7f; + packet.buffer_mut()[res.unwrap().pubkey_start as usize - 1] = 0x7f; let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::InvalidPubkeyLen)); @@ -982,7 +982,7 @@ mod tests { let res = sigverify::do_get_packet_offsets(&packet, 0); // set message version to 1 - packet.data[res.unwrap().msg_start as usize] = MESSAGE_VERSION_PREFIX + 1; + packet.buffer_mut()[res.unwrap().msg_start as usize] = MESSAGE_VERSION_PREFIX + 1; let res = sigverify::do_get_packet_offsets(&packet, 0); assert_eq!(res, Err(PacketError::UnsupportedVersion)); @@ -997,10 +997,11 @@ mod tests { // set message version to 0 let msg_start = legacy_offsets.msg_start as usize; - let msg_bytes = packet.data[msg_start..packet.meta.size].to_vec(); - packet.data[msg_start] = MESSAGE_VERSION_PREFIX; + let msg_bytes = packet.data()[msg_start..].to_vec(); + packet.buffer_mut()[msg_start] = MESSAGE_VERSION_PREFIX; packet.meta.size += 1; - packet.data[msg_start + 1..packet.meta.size].copy_from_slice(&msg_bytes); + let msg_end = packet.meta.size; + packet.buffer_mut()[msg_start + 1..msg_end].copy_from_slice(&msg_bytes); let offsets = sigverify::do_get_packet_offsets(&packet, 0).unwrap(); let expected_offsets = { @@ -1119,7 +1120,7 @@ mod tests { // jumble some data to test failure if modify_data { - packet.data[20] = packet.data[20].wrapping_add(10); + packet.buffer_mut()[20] = packet.data()[20].wrapping_add(10); } let mut batches = generate_packet_batches(&packet, n, 2); @@ -1185,7 +1186,7 @@ mod tests { let num_batches = 3; let mut batches = generate_packet_batches(&packet, n, num_batches); - packet.data[40] = packet.data[40].wrapping_add(8); + packet.buffer_mut()[40] = packet.data()[40].wrapping_add(8); batches[0].push(packet); @@ -1229,8 +1230,8 @@ mod tests { let packet = thread_rng().gen_range(0, batches[batch].len()); let offset = thread_rng().gen_range(0, batches[batch][packet].meta.size); let add = thread_rng().gen_range(0, 255); - batches[batch][packet].data[offset] = - batches[batch][packet].data[offset].wrapping_add(add); + batches[batch][packet].buffer_mut()[offset] = + batches[batch][packet].data()[offset].wrapping_add(add); } let batch_to_disable = thread_rng().gen_range(0, batches.len()); @@ -1504,7 +1505,7 @@ mod tests { .filter(|p| !p.meta.discard()) .for_each(|p| start.push(p.clone())) }); - start.sort_by_key(|p| 
p.data); + start.sort_by(|a, b| a.data().cmp(b.data())); let packet_count = count_valid_packets(&batches, |_| ()); let res = shrink_batches(&mut batches); @@ -1517,7 +1518,7 @@ mod tests { .filter(|p| !p.meta.discard()) .for_each(|p| end.push(p.clone())) }); - end.sort_by_key(|p| p.data); + end.sort_by(|a, b| a.data().cmp(b.data())); let packet_count2 = count_valid_packets(&batches, |_| ()); assert_eq!(packet_count, packet_count2); assert_eq!(start, end); diff --git a/sdk/src/packet.rs b/sdk/src/packet.rs index 516d91e10..eae3aa83e 100644 --- a/sdk/src/packet.rs +++ b/sdk/src/packet.rs @@ -38,13 +38,30 @@ pub struct Meta { #[derive(Clone, Eq)] #[repr(C)] pub struct Packet { - pub data: [u8; PACKET_DATA_SIZE], + // Bytes past Packet.meta.size are not valid to read from. + // Use Packet.data() to read from the buffer. + buffer: [u8; PACKET_DATA_SIZE], pub meta: Meta, } impl Packet { - pub fn new(data: [u8; PACKET_DATA_SIZE], meta: Meta) -> Self { - Self { data, meta } + pub fn new(buffer: [u8; PACKET_DATA_SIZE], meta: Meta) -> Self { + Self { buffer, meta } + } + + /// Returns an immutable reference to the underlying buffer up to + /// Packet.meta.size. The rest of the buffer is not valid to read from. + #[inline] + pub fn data(&self) -> &[u8] { + &self.buffer[..self.meta.size] + } + + /// Returns a mutable reference to the entirety of the underlying buffer to + /// write into. The caller is responsible for updating Packet.meta.size + /// after writing to the buffer. + #[inline] + pub fn buffer_mut(&mut self) -> &mut [u8] { + &mut self.buffer[..] } pub fn from_data(dest: Option<&SocketAddr>, data: T) -> Result { @@ -58,7 +75,7 @@ impl Packet { dest: Option<&SocketAddr>, data: &T, ) -> Result<()> { - let mut wr = io::Cursor::new(&mut packet.data[..]); + let mut wr = io::Cursor::new(packet.buffer_mut()); bincode::serialize_into(&mut wr, data)?; let len = wr.position() as usize; packet.meta.size = len; @@ -73,7 +90,7 @@ impl Packet { T: serde::de::DeserializeOwned, I: std::slice::SliceIndex<[u8], Output = [u8]>, { - let data = &self.data[0..self.meta.size]; + let data = self.data(); let bytes = data.get(index).ok_or(bincode::ErrorKind::SizeLimit)?; bincode::options() .with_limit(PACKET_DATA_SIZE as u64) @@ -98,7 +115,7 @@ impl fmt::Debug for Packet { impl Default for Packet { fn default() -> Packet { Packet { - data: unsafe { std::mem::MaybeUninit::uninit().assume_init() }, + buffer: unsafe { std::mem::MaybeUninit::uninit().assume_init() }, meta: Meta::default(), } } @@ -106,9 +123,7 @@ impl Default for Packet { impl PartialEq for Packet { fn eq(&self, other: &Packet) -> bool { - let self_data: &[u8] = self.data.as_ref(); - let other_data: &[u8] = other.data.as_ref(); - self.meta == other.meta && self_data[..self.meta.size] == other_data[..self.meta.size] + self.meta == other.meta && self.data() == other.data() } } diff --git a/streamer/src/nonblocking/recvmmsg.rs b/streamer/src/nonblocking/recvmmsg.rs index b34e74ce6..df2b08ff2 100644 --- a/streamer/src/nonblocking/recvmmsg.rs +++ b/streamer/src/nonblocking/recvmmsg.rs @@ -19,7 +19,7 @@ pub async fn recv_mmsg( let mut i = 0; for p in packets.iter_mut().take(count) { p.meta.size = 0; - match socket.try_recv_from(&mut p.data) { + match socket.try_recv_from(p.buffer_mut()) { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { break; } diff --git a/streamer/src/nonblocking/sendmmsg.rs b/streamer/src/nonblocking/sendmmsg.rs index 6797bb9fb..299eb4fb5 100644 --- a/streamer/src/nonblocking/sendmmsg.rs +++ b/streamer/src/nonblocking/sendmmsg.rs 
@@ -138,13 +138,9 @@ mod tests { let packet = Packet::default(); - let sent = multi_target_send( - &sender, - &packet.data[..packet.meta.size], - &[&addr, &addr2, &addr3, &addr4], - ) - .await - .ok(); + let sent = multi_target_send(&sender, packet.data(), &[&addr, &addr2, &addr3, &addr4]) + .await + .ok(); assert_eq!(sent, Some(())); let mut packets = vec![Packet::default(); 32]; diff --git a/streamer/src/packet.rs b/streamer/src/packet.rs index c528442d7..7077e2fdf 100644 --- a/streamer/src/packet.rs +++ b/streamer/src/packet.rs @@ -67,7 +67,7 @@ pub fn send_to( for p in batch.iter() { let addr = p.meta.socket_addr(); if socket_addr_space.check(&addr) { - socket.send_to(&p.data[..p.meta.size], &addr)?; + socket.send_to(p.data(), &addr)?; } } Ok(()) @@ -135,14 +135,14 @@ mod tests { let mut p2 = Packet::default(); p1.meta.size = 1; - p1.data[0] = 0; + p1.buffer_mut()[0] = 0; p2.meta.size = 1; - p2.data[0] = 0; + p2.buffer_mut()[0] = 0; assert!(p1 == p2); - p2.data[0] = 4; + p2.buffer_mut()[0] = 4; assert!(p1 != p2); } diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs index 6429d8d1b..162510194 100644 --- a/streamer/src/quic.rs +++ b/streamer/src/quic.rs @@ -189,7 +189,7 @@ fn handle_chunk( if let Some(batch) = maybe_batch.as_mut() { let end = chunk.offset as usize + chunk.bytes.len(); - batch[0].data[chunk.offset as usize..end].copy_from_slice(&chunk.bytes); + batch[0].buffer_mut()[chunk.offset as usize..end].copy_from_slice(&chunk.bytes); batch[0].meta.size = std::cmp::max(batch[0].meta.size, end); stats.total_chunks_received.fetch_add(1, Ordering::Relaxed); } diff --git a/streamer/src/recvmmsg.rs b/streamer/src/recvmmsg.rs index ea6d6bf6a..bb2691009 100644 --- a/streamer/src/recvmmsg.rs +++ b/streamer/src/recvmmsg.rs @@ -22,7 +22,7 @@ pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result 0 => { break; } @@ -84,9 +84,10 @@ pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result>())?; Ok(()) @@ -478,7 +476,7 @@ mod test { for i in 0..NUM_PACKETS { let mut p = Packet::default(); { - p.data[0] = i as u8; + p.buffer_mut()[0] = i as u8; p.meta.size = PACKET_DATA_SIZE; p.meta.set_socket_addr(&addr); }
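// ---------------------------------------------------------------------------
// Editorial note, not part of the patch: a minimal sketch of the access
// pattern the new API enforces, assuming the Packet definition introduced in
// sdk/src/packet.rs above. Writes go through Packet::buffer_mut() and are
// followed by an update to Packet.meta.size; reads go through Packet::data(),
// which is bounded by meta.size. The helper names `copy_into_packet` and
// `first_four_bytes` are illustrative only and do not exist in the codebase.
// ---------------------------------------------------------------------------

use solana_sdk::packet::{Packet, PACKET_DATA_SIZE};

fn copy_into_packet(packet: &mut Packet, payload: &[u8]) -> Option<()> {
    // Reject payloads that cannot fit in the fixed-size buffer.
    if payload.len() > PACKET_DATA_SIZE {
        return None;
    }
    // Write into the full underlying buffer...
    packet.buffer_mut()[..payload.len()].copy_from_slice(payload);
    // ...then record how many bytes are valid to read back.
    packet.meta.size = payload.len();
    Some(())
}

fn first_four_bytes(packet: &Packet) -> Option<&[u8]> {
    // Slicing happens on data(), so a packet shorter than four bytes yields
    // None instead of exposing stale bytes past meta.size.
    packet.data().get(..4)
}

fn main() {
    let mut packet = Packet::default();
    copy_into_packet(&mut packet, b"hello").unwrap();
    // data() is limited to meta.size: five bytes, not PACKET_DATA_SIZE.
    assert_eq!(packet.data().len(), 5);
    assert_eq!(first_four_bytes(&packet), Some(&b"hell"[..]));
}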