limits read access into Packet data to Packet.meta.size (#25484)

Bytes past Packet.meta.size are not valid to read from.

The commit makes the buffer field private and instead provides two
methods:
* Packet::data() which returns an immutable reference to the underlying
  buffer up to Packet.meta.size. The rest of the buffer is not valid to
  read from.
* Packet::buffer_mut() which returns a mutable reference to the entirety
  of the underlying buffer to write into. The caller is responsible to
  update Packet.meta.size after writing to the buffer.
This commit is contained in:
behzad nouri 2022-05-25 16:52:54 +00:00 committed by GitHub
parent f10c80b49f
commit 880684565c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 112 additions and 121 deletions

View File

@ -37,7 +37,7 @@ fn producer(addr: &SocketAddr, exit: Arc<AtomicBool>) -> JoinHandle<()> {
for p in packet_batch.iter() {
let a = p.meta.socket_addr();
assert!(p.meta.size <= PACKET_DATA_SIZE);
send.send_to(&p.data[..p.meta.size], &a).unwrap();
send.send_to(p.data(), &a).unwrap();
num += 1;
}
assert_eq!(num, 10);

View File

@ -508,7 +508,7 @@ impl BankingStage {
.iter()
.filter_map(|p| {
if !p.meta.forwarded() && data_budget.take(p.meta.size) {
Some(p.data[..p.meta.size].to_vec())
Some(p.data().to_vec())
} else {
None
}

View File

@ -26,8 +26,7 @@ impl Default for PacketHasher {
impl PacketHasher {
pub(crate) fn hash_packet(&self, packet: &Packet) -> u64 {
let size = packet.data.len().min(packet.meta.size);
self.hash_data(&packet.data[..size])
self.hash_data(packet.data())
}
pub(crate) fn hash_shred(&self, shred: &Shred) -> u64 {

View File

@ -28,13 +28,14 @@ pub fn repair_response_packet_from_bytes(
nonce: Nonce,
) -> Option<Packet> {
let mut packet = Packet::default();
packet.meta.size = bytes.len() + SIZE_OF_NONCE;
if packet.meta.size > packet.data.len() {
let size = bytes.len() + SIZE_OF_NONCE;
if size > packet.buffer_mut().len() {
return None;
}
packet.meta.size = size;
packet.meta.set_socket_addr(dest);
packet.data[..bytes.len()].copy_from_slice(&bytes);
let mut wr = io::Cursor::new(&mut packet.data[bytes.len()..]);
packet.buffer_mut()[..bytes.len()].copy_from_slice(&bytes);
let mut wr = io::Cursor::new(&mut packet.buffer_mut()[bytes.len()..]);
bincode::serialize_into(&mut wr, &nonce).expect("Buffer not large enough to fit nonce");
Some(packet)
}

View File

@ -814,7 +814,7 @@ mod tests {
.into_iter()
.filter_map(|p| {
assert_eq!(repair_response::nonce(p).unwrap(), nonce);
Shred::new_from_serialized_shred(p.data.to_vec()).ok()
Shred::new_from_serialized_shred(p.data().to_vec()).ok()
})
.collect();
assert!(!rv.is_empty());
@ -898,7 +898,7 @@ mod tests {
.into_iter()
.filter_map(|p| {
assert_eq!(repair_response::nonce(p).unwrap(), nonce);
Shred::new_from_serialized_shred(p.data.to_vec()).ok()
Shred::new_from_serialized_shred(p.data().to_vec()).ok()
})
.collect();
assert_eq!(rv[0].index(), 1);
@ -1347,7 +1347,7 @@ mod tests {
fn verify_responses<'a>(request: &ShredRepairType, packets: impl Iterator<Item = &'a Packet>) {
for packet in packets {
let shred_payload = packet.data.to_vec();
let shred_payload = packet.data().to_vec();
let shred = Shred::new_from_serialized_shred(shred_payload).unwrap();
request.verify_response(&shred);
}

View File

@ -120,7 +120,7 @@ pub mod tests {
let keypair = Keypair::new();
shred.sign(&keypair);
batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta.size = shred.payload().len();
let mut shred = Shred::new_from_data(
@ -134,7 +134,7 @@ pub mod tests {
0xc0de,
);
shred.sign(&keypair);
batches[1][0].data[0..shred.payload().len()].copy_from_slice(shred.payload());
batches[1][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[1][0].meta.size = shred.payload().len();
let expected: HashSet<u64> = [0xc0de_dead, 0xdead_c0de].iter().cloned().collect();
@ -169,7 +169,7 @@ pub mod tests {
0xc0de,
);
shred.sign(&leader_keypair);
batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta.size = shred.payload().len();
let mut shred = Shred::new_from_data(
@ -184,7 +184,7 @@ pub mod tests {
);
let wrong_keypair = Keypair::new();
shred.sign(&wrong_keypair);
batches[0][1].data[0..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][1].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][1].meta.size = shred.payload().len();
let num_packets = solana_perf::sigverify::count_packets_in_batches(&batches);

View File

@ -356,14 +356,11 @@ pub fn deserialize_packets<'a>(
/// Read the transaction message from packet data
pub fn packet_message(packet: &Packet) -> Result<&[u8], DeserializedPacketError> {
let (sig_len, sig_size) =
decode_shortu16_len(&packet.data).map_err(DeserializedPacketError::ShortVecError)?;
decode_shortu16_len(packet.data()).map_err(DeserializedPacketError::ShortVecError)?;
sig_len
.checked_mul(size_of::<Signature>())
.and_then(|v| v.checked_add(sig_size))
.map(|msg_start| {
let msg_end = packet.meta.size;
&packet.data[msg_start..msg_end]
})
.and_then(|msg_start| packet.data().get(msg_start..))
.ok_or(DeserializedPacketError::SignatureOverflowed(sig_size))
}

View File

@ -23,7 +23,7 @@ use {
solana_perf::packet::{Packet, PacketBatch},
solana_rayon_threadlimit::get_thread_count,
solana_runtime::{bank::Bank, bank_forks::BankForks},
solana_sdk::{clock::Slot, packet::PACKET_DATA_SIZE, pubkey::Pubkey},
solana_sdk::{clock::Slot, pubkey::Pubkey},
std::{
cmp::Reverse,
collections::{HashMap, HashSet},
@ -363,11 +363,7 @@ where
inc_new_counter_debug!("streamer-recv_window-invalid_or_unnecessary_packet", 1);
return None;
}
// shred fetch stage should be sending packets
// with sufficiently large buffers. Needed to ensure
// call to `new_from_serialized_shred` is safe.
assert_eq!(packet.data.len(), PACKET_DATA_SIZE);
let serialized_shred = packet.data.to_vec();
let serialized_shred = packet.data().to_vec();
let shred = Shred::new_from_serialized_shred(serialized_shred).ok()?;
if !shred_filter(&shred, working_bank.clone(), last_root) {
return None;

View File

@ -260,7 +260,7 @@ pub fn cluster_info_retransmit() {
let retransmit_peers: Vec<_> = peers.iter().collect();
retransmit_to(
&retransmit_peers,
&p.data[..p.meta.size],
p.data(),
&tn1,
false,
&SocketAddrSpace::Unspecified,
@ -270,7 +270,7 @@ pub fn cluster_info_retransmit() {
.map(|s| {
let mut p = Packet::default();
s.set_read_timeout(Some(Duration::new(1, 0))).unwrap();
let res = s.recv_from(&mut p.data);
let res = s.recv_from(p.buffer_mut());
res.is_err() //true if failed to receive the retransmit packet
})
.collect();

View File

@ -287,7 +287,7 @@ impl Shred {
pub fn copy_to_packet(&self, packet: &mut Packet) {
let payload = self.payload();
let size = payload.len();
packet.data[..size].copy_from_slice(&payload[..]);
packet.buffer_mut()[..size].copy_from_slice(&payload[..]);
packet.meta.size = size;
}
@ -575,7 +575,7 @@ pub fn get_shred_slot_index_type(
}
};
let shred_type = match ShredType::try_from(p.data[OFFSET_OF_SHRED_TYPE]) {
let shred_type = match ShredType::try_from(p.data()[OFFSET_OF_SHRED_TYPE]) {
Err(_) => {
stats.bad_shred_type += 1;
return None;
@ -733,7 +733,7 @@ mod tests {
let shred = Shred::new_from_data(10, 0, 1000, &[1, 2, 3], ShredFlags::empty(), 0, 1, 0);
let mut packet = Packet::default();
shred.copy_to_packet(&mut packet);
let shred_res = Shred::new_from_serialized_shred(packet.data.to_vec());
let shred_res = Shred::new_from_serialized_shred(packet.data().to_vec());
assert_matches!(
shred.parent(),
Err(Error::InvalidParentOffset {
@ -825,7 +825,7 @@ mod tests {
200, // version
);
shred.copy_to_packet(&mut packet);
packet.data[OFFSET_OF_SHRED_TYPE] = u8::MAX;
packet.buffer_mut()[OFFSET_OF_SHRED_TYPE] = u8::MAX;
assert_eq!(None, get_shred_slot_index_type(&packet, &mut stats));
assert_eq!(1, stats.bad_shred_type);
@ -892,13 +892,13 @@ mod tests {
data.iter().skip(skip).copied()
});
let mut packet = Packet::default();
packet.data[..payload.len()].copy_from_slice(&payload);
packet.buffer_mut()[..payload.len()].copy_from_slice(&payload);
packet.meta.size = payload.len();
assert_eq!(shred.bytes_to_store(), payload);
assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap());
assert_eq!(
shred.reference_tick(),
Shred::reference_tick_from_data(&packet.data).unwrap()
Shred::reference_tick_from_data(packet.data()).unwrap()
);
assert_eq!(Shred::get_slot_from_packet(&packet), Some(shred.slot()));
assert_eq!(
@ -933,13 +933,13 @@ mod tests {
assert_matches!(shred.sanitize(), Ok(()));
let payload = bs58_decode(PAYLOAD);
let mut packet = Packet::default();
packet.data[..payload.len()].copy_from_slice(&payload);
packet.buffer_mut()[..payload.len()].copy_from_slice(&payload);
packet.meta.size = payload.len();
assert_eq!(shred.bytes_to_store(), payload);
assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap());
assert_eq!(
shred.reference_tick(),
Shred::reference_tick_from_data(&packet.data).unwrap()
Shred::reference_tick_from_data(packet.data()).unwrap()
);
assert_eq!(Shred::get_slot_from_packet(&packet), Some(shred.slot()));
assert_eq!(
@ -981,7 +981,7 @@ mod tests {
parity_shard.iter().skip(skip).copied()
});
let mut packet = Packet::default();
packet.data[..payload.len()].copy_from_slice(&payload);
packet.buffer_mut()[..payload.len()].copy_from_slice(&payload);
packet.meta.size = payload.len();
assert_eq!(shred.bytes_to_store(), payload);
assert_eq!(shred, Shred::new_from_serialized_shred(payload).unwrap());

View File

@ -62,12 +62,9 @@ pub fn verify_shred_cpu(packet: &Packet, slot_leaders: &HashMap<u64, [u8; 32]>)
};
trace!("slot {}", slot);
let pubkey = slot_leaders.get(&slot)?;
if packet.meta.size < sig_end {
return Some(0);
}
let signature = Signature::new(&packet.data[sig_start..sig_end]);
let signature = Signature::new(packet.data().get(sig_start..sig_end)?);
trace!("signature {}", signature);
if !signature.verify(pubkey, &packet.data[msg_start..msg_end]) {
if !signature.verify(pubkey, packet.data().get(msg_start..msg_end)?) {
return Some(0);
}
Some(1)
@ -307,14 +304,9 @@ fn sign_shred_cpu(keypair: &Keypair, packet: &mut Packet) {
let sig_start = 0;
let sig_end = sig_start + size_of::<Signature>();
let msg_start = sig_end;
let msg_end = packet.meta.size;
assert!(
packet.meta.size >= msg_end,
"packet is not large enough for a signature"
);
let signature = keypair.sign_message(&packet.data[msg_start..msg_end]);
let signature = keypair.sign_message(&packet.data()[msg_start..]);
trace!("signature {:?}", signature);
packet.data[0..sig_end].copy_from_slice(signature.as_ref());
packet.buffer_mut()[..sig_end].copy_from_slice(signature.as_ref());
}
pub fn sign_shreds_cpu(keypair: &Keypair, batches: &mut [PacketBatch]) {
@ -443,7 +435,7 @@ pub fn sign_shreds_gpu(
let sig_ix = packet_ix + num_packets;
let sig_start = sig_ix * sig_size;
let sig_end = sig_start + sig_size;
packet.data[0..sig_size]
packet.buffer_mut()[..sig_size]
.copy_from_slice(&signatures_out[sig_start..sig_end]);
});
});
@ -476,7 +468,7 @@ pub mod tests {
let keypair = Keypair::new();
shred.sign(&keypair);
trace!("signature {}", shred.signature());
packet.data[0..shred.payload().len()].copy_from_slice(shred.payload());
packet.buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
packet.meta.size = shred.payload().len();
let leader_slots = [(slot, keypair.pubkey().to_bytes())]
@ -520,7 +512,7 @@ pub mod tests {
let keypair = Keypair::new();
shred.sign(&keypair);
batches[0].resize(1, Packet::default());
batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta.size = shred.payload().len();
let leader_slots = [(slot, keypair.pubkey().to_bytes())]
@ -574,7 +566,7 @@ pub mod tests {
let keypair = Keypair::new();
shred.sign(&keypair);
batches[0].resize(1, Packet::default());
batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta.size = shred.payload().len();
let leader_slots = [
@ -685,7 +677,7 @@ pub mod tests {
0xc0de,
);
batches[0].resize(1, Packet::default());
batches[0][0].data[0..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].buffer_mut()[..shred.payload().len()].copy_from_slice(shred.payload());
batches[0][0].meta.size = shred.payload().len();
let pubkeys = [

View File

@ -142,11 +142,11 @@ fn verify_packet(packet: &mut Packet, reject_non_vote: bool) {
// get_packet_offsets should ensure pubkey_end and sig_end do
// not overflow packet.meta.size
let signature = Signature::new(&packet.data[sig_start..sig_end]);
let signature = Signature::new(&packet.data()[sig_start..sig_end]);
if !signature.verify(
&packet.data[pubkey_start..pubkey_end],
&packet.data[msg_start..msg_end],
&packet.data()[pubkey_start..pubkey_end],
&packet.data()[msg_start..msg_end],
) {
packet.meta.set_discard(true);
return;
@ -154,7 +154,7 @@ fn verify_packet(packet: &mut Packet, reject_non_vote: bool) {
// Check for tracer pubkey
if !packet.meta.is_tracer_packet()
&& &packet.data[pubkey_start..pubkey_end] == TRACER_KEY.as_ref()
&& &packet.data()[pubkey_start..pubkey_end] == TRACER_KEY.as_ref()
{
packet.meta.flags |= PacketFlags::TRACER_PACKET;
}
@ -202,7 +202,7 @@ fn do_get_packet_offsets(
// read the length of Transaction.signatures (serialized with short_vec)
let (sig_len_untrusted, sig_size) =
decode_shortu16_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?;
decode_shortu16_len(packet.data()).map_err(|_| PacketError::InvalidShortVec)?;
// Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty.
// Ultimately, the actual sigverify will determine the uncertainty.
@ -221,7 +221,7 @@ fn do_get_packet_offsets(
// next byte indicates if the transaction is versioned. If the top bit
// is set, the remaining bits encode a version number. If the top bit is
// not set, this byte is the first byte of the message header.
let message_prefix = packet.data[msg_start_offset];
let message_prefix = packet.data()[msg_start_offset];
if message_prefix & MESSAGE_VERSION_PREFIX != 0 {
let version = message_prefix & !MESSAGE_VERSION_PREFIX;
match version {
@ -251,7 +251,7 @@ fn do_get_packet_offsets(
.ok_or(PacketError::InvalidSignatureLen)?;
// read MessageHeader.num_required_signatures (serialized with u8)
let sig_len_maybe_trusted = packet.data[msg_header_offset];
let sig_len_maybe_trusted = packet.data()[msg_header_offset];
let message_account_keys_len_offset = msg_header_offset
.checked_add(MESSAGE_HEADER_LENGTH)
@ -262,7 +262,7 @@ fn do_get_packet_offsets(
// num_readonly_signed_accounts, the first account is not debitable, and cannot be charged
// required transaction fees.
let readonly_signer_offset = msg_header_offset_plus_one;
if sig_len_maybe_trusted <= packet.data[readonly_signer_offset] {
if sig_len_maybe_trusted <= packet.data()[readonly_signer_offset] {
return Err(PacketError::PayerNotWritable);
}
@ -272,7 +272,7 @@ fn do_get_packet_offsets(
// read the length of Message.account_keys (serialized with short_vec)
let (pubkey_len, pubkey_len_size) =
decode_shortu16_len(&packet.data[message_account_keys_len_offset..])
decode_shortu16_len(&packet.data()[message_account_keys_len_offset..])
.map_err(|_| PacketError::InvalidShortVec)?;
let pubkey_start = message_account_keys_len_offset
@ -352,7 +352,7 @@ fn check_for_simple_vote_transaction(
.ok_or(PacketError::InvalidLen)?;
let (instruction_len, instruction_len_size) =
decode_shortu16_len(&packet.data[instructions_len_offset..])
decode_shortu16_len(&packet.data()[instructions_len_offset..])
.map_err(|_| PacketError::InvalidLen)?;
// skip if has more than 1 instruction
@ -370,7 +370,7 @@ fn check_for_simple_vote_transaction(
.filter(|v| *v <= packet.meta.size)
.ok_or(PacketError::InvalidLen)?;
let instruction_program_id_index: usize = usize::from(packet.data[instruction_start]);
let instruction_program_id_index: usize = usize::from(packet.data()[instruction_start]);
if instruction_program_id_index >= packet_offsets.pubkey_len as usize {
return Err(PacketError::InvalidProgramIdIndex);
@ -384,7 +384,7 @@ fn check_for_simple_vote_transaction(
.checked_add(size_of::<Pubkey>())
.ok_or(PacketError::InvalidLen)?;
if &packet.data[instruction_program_id_start..instruction_program_id_end]
if &packet.data()[instruction_program_id_start..instruction_program_id_end]
== solana_sdk::vote::program::id().as_ref()
{
packet.meta.flags |= PacketFlags::SIMPLE_VOTE_TX;
@ -492,7 +492,7 @@ impl Deduper {
return 1;
}
let mut hasher = AHasher::new_with_keys(self.seed.0, self.seed.1);
hasher.write(&packet.data[0..packet.meta.size]);
hasher.write(packet.data());
let hash = hasher.finish();
let len = self.filter.len();
let pos = (usize::try_from(hash).unwrap()).wrapping_rem(len);
@ -846,8 +846,8 @@ mod tests {
let tx = test_tx();
let mut packet = Packet::from_data(None, tx).unwrap();
packet.data[0] = 0xff;
packet.data[1] = 0xff;
packet.buffer_mut()[0] = 0xff;
packet.buffer_mut()[1] = 0xff;
packet.meta.size = 2;
let res = sigverify::do_get_packet_offsets(&packet, 0);
@ -919,7 +919,7 @@ mod tests {
let mut packet = Packet::from_data(None, tx).unwrap();
// Make the signatures len huge
packet.data[0] = 0x7f;
packet.buffer_mut()[0] = 0x7f;
let res = sigverify::do_get_packet_offsets(&packet, 0);
assert_eq!(res, Err(PacketError::InvalidSignatureLen));
@ -931,10 +931,10 @@ mod tests {
let mut packet = Packet::from_data(None, tx).unwrap();
// Make the signatures len huge
packet.data[0] = 0xff;
packet.data[1] = 0xff;
packet.data[2] = 0xff;
packet.data[3] = 0xff;
packet.buffer_mut()[0] = 0xff;
packet.buffer_mut()[1] = 0xff;
packet.buffer_mut()[2] = 0xff;
packet.buffer_mut()[3] = 0xff;
let res = sigverify::do_get_packet_offsets(&packet, 0);
assert_eq!(res, Err(PacketError::InvalidShortVec));
@ -948,7 +948,7 @@ mod tests {
let res = sigverify::do_get_packet_offsets(&packet, 0);
// make pubkey len huge
packet.data[res.unwrap().pubkey_start as usize - 1] = 0x7f;
packet.buffer_mut()[res.unwrap().pubkey_start as usize - 1] = 0x7f;
let res = sigverify::do_get_packet_offsets(&packet, 0);
assert_eq!(res, Err(PacketError::InvalidPubkeyLen));
@ -982,7 +982,7 @@ mod tests {
let res = sigverify::do_get_packet_offsets(&packet, 0);
// set message version to 1
packet.data[res.unwrap().msg_start as usize] = MESSAGE_VERSION_PREFIX + 1;
packet.buffer_mut()[res.unwrap().msg_start as usize] = MESSAGE_VERSION_PREFIX + 1;
let res = sigverify::do_get_packet_offsets(&packet, 0);
assert_eq!(res, Err(PacketError::UnsupportedVersion));
@ -997,10 +997,11 @@ mod tests {
// set message version to 0
let msg_start = legacy_offsets.msg_start as usize;
let msg_bytes = packet.data[msg_start..packet.meta.size].to_vec();
packet.data[msg_start] = MESSAGE_VERSION_PREFIX;
let msg_bytes = packet.data()[msg_start..].to_vec();
packet.buffer_mut()[msg_start] = MESSAGE_VERSION_PREFIX;
packet.meta.size += 1;
packet.data[msg_start + 1..packet.meta.size].copy_from_slice(&msg_bytes);
let msg_end = packet.meta.size;
packet.buffer_mut()[msg_start + 1..msg_end].copy_from_slice(&msg_bytes);
let offsets = sigverify::do_get_packet_offsets(&packet, 0).unwrap();
let expected_offsets = {
@ -1119,7 +1120,7 @@ mod tests {
// jumble some data to test failure
if modify_data {
packet.data[20] = packet.data[20].wrapping_add(10);
packet.buffer_mut()[20] = packet.data()[20].wrapping_add(10);
}
let mut batches = generate_packet_batches(&packet, n, 2);
@ -1185,7 +1186,7 @@ mod tests {
let num_batches = 3;
let mut batches = generate_packet_batches(&packet, n, num_batches);
packet.data[40] = packet.data[40].wrapping_add(8);
packet.buffer_mut()[40] = packet.data()[40].wrapping_add(8);
batches[0].push(packet);
@ -1229,8 +1230,8 @@ mod tests {
let packet = thread_rng().gen_range(0, batches[batch].len());
let offset = thread_rng().gen_range(0, batches[batch][packet].meta.size);
let add = thread_rng().gen_range(0, 255);
batches[batch][packet].data[offset] =
batches[batch][packet].data[offset].wrapping_add(add);
batches[batch][packet].buffer_mut()[offset] =
batches[batch][packet].data()[offset].wrapping_add(add);
}
let batch_to_disable = thread_rng().gen_range(0, batches.len());
@ -1504,7 +1505,7 @@ mod tests {
.filter(|p| !p.meta.discard())
.for_each(|p| start.push(p.clone()))
});
start.sort_by_key(|p| p.data);
start.sort_by(|a, b| a.data().cmp(b.data()));
let packet_count = count_valid_packets(&batches, |_| ());
let res = shrink_batches(&mut batches);
@ -1517,7 +1518,7 @@ mod tests {
.filter(|p| !p.meta.discard())
.for_each(|p| end.push(p.clone()))
});
end.sort_by_key(|p| p.data);
end.sort_by(|a, b| a.data().cmp(b.data()));
let packet_count2 = count_valid_packets(&batches, |_| ());
assert_eq!(packet_count, packet_count2);
assert_eq!(start, end);

View File

@ -38,13 +38,30 @@ pub struct Meta {
#[derive(Clone, Eq)]
#[repr(C)]
pub struct Packet {
pub data: [u8; PACKET_DATA_SIZE],
// Bytes past Packet.meta.size are not valid to read from.
// Use Packet.data() to read from the buffer.
buffer: [u8; PACKET_DATA_SIZE],
pub meta: Meta,
}
impl Packet {
pub fn new(data: [u8; PACKET_DATA_SIZE], meta: Meta) -> Self {
Self { data, meta }
pub fn new(buffer: [u8; PACKET_DATA_SIZE], meta: Meta) -> Self {
Self { buffer, meta }
}
/// Returns an immutable reference to the underlying buffer up to
/// Packet.meta.size. The rest of the buffer is not valid to read from.
#[inline]
pub fn data(&self) -> &[u8] {
&self.buffer[..self.meta.size]
}
/// Returns a mutable reference to the entirety of the underlying buffer to
/// write into. The caller is responsible for updating Packet.meta.size
/// after writing to the buffer.
#[inline]
pub fn buffer_mut(&mut self) -> &mut [u8] {
&mut self.buffer[..]
}
pub fn from_data<T: Serialize>(dest: Option<&SocketAddr>, data: T) -> Result<Self> {
@ -58,7 +75,7 @@ impl Packet {
dest: Option<&SocketAddr>,
data: &T,
) -> Result<()> {
let mut wr = io::Cursor::new(&mut packet.data[..]);
let mut wr = io::Cursor::new(packet.buffer_mut());
bincode::serialize_into(&mut wr, data)?;
let len = wr.position() as usize;
packet.meta.size = len;
@ -73,7 +90,7 @@ impl Packet {
T: serde::de::DeserializeOwned,
I: std::slice::SliceIndex<[u8], Output = [u8]>,
{
let data = &self.data[0..self.meta.size];
let data = self.data();
let bytes = data.get(index).ok_or(bincode::ErrorKind::SizeLimit)?;
bincode::options()
.with_limit(PACKET_DATA_SIZE as u64)
@ -98,7 +115,7 @@ impl fmt::Debug for Packet {
impl Default for Packet {
fn default() -> Packet {
Packet {
data: unsafe { std::mem::MaybeUninit::uninit().assume_init() },
buffer: unsafe { std::mem::MaybeUninit::uninit().assume_init() },
meta: Meta::default(),
}
}
@ -106,9 +123,7 @@ impl Default for Packet {
impl PartialEq for Packet {
fn eq(&self, other: &Packet) -> bool {
let self_data: &[u8] = self.data.as_ref();
let other_data: &[u8] = other.data.as_ref();
self.meta == other.meta && self_data[..self.meta.size] == other_data[..self.meta.size]
self.meta == other.meta && self.data() == other.data()
}
}

View File

@ -19,7 +19,7 @@ pub async fn recv_mmsg(
let mut i = 0;
for p in packets.iter_mut().take(count) {
p.meta.size = 0;
match socket.try_recv_from(&mut p.data) {
match socket.try_recv_from(p.buffer_mut()) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
break;
}

View File

@ -138,13 +138,9 @@ mod tests {
let packet = Packet::default();
let sent = multi_target_send(
&sender,
&packet.data[..packet.meta.size],
&[&addr, &addr2, &addr3, &addr4],
)
.await
.ok();
let sent = multi_target_send(&sender, packet.data(), &[&addr, &addr2, &addr3, &addr4])
.await
.ok();
assert_eq!(sent, Some(()));
let mut packets = vec![Packet::default(); 32];

View File

@ -67,7 +67,7 @@ pub fn send_to(
for p in batch.iter() {
let addr = p.meta.socket_addr();
if socket_addr_space.check(&addr) {
socket.send_to(&p.data[..p.meta.size], &addr)?;
socket.send_to(p.data(), &addr)?;
}
}
Ok(())
@ -135,14 +135,14 @@ mod tests {
let mut p2 = Packet::default();
p1.meta.size = 1;
p1.data[0] = 0;
p1.buffer_mut()[0] = 0;
p2.meta.size = 1;
p2.data[0] = 0;
p2.buffer_mut()[0] = 0;
assert!(p1 == p2);
p2.data[0] = 4;
p2.buffer_mut()[0] = 4;
assert!(p1 != p2);
}

View File

@ -189,7 +189,7 @@ fn handle_chunk(
if let Some(batch) = maybe_batch.as_mut() {
let end = chunk.offset as usize + chunk.bytes.len();
batch[0].data[chunk.offset as usize..end].copy_from_slice(&chunk.bytes);
batch[0].buffer_mut()[chunk.offset as usize..end].copy_from_slice(&chunk.bytes);
batch[0].meta.size = std::cmp::max(batch[0].meta.size, end);
stats.total_chunks_received.fetch_add(1, Ordering::Relaxed);
}

View File

@ -22,7 +22,7 @@ pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num
let count = cmp::min(NUM_RCVMMSGS, packets.len());
for p in packets.iter_mut().take(count) {
p.meta.size = 0;
match socket.recv_from(&mut p.data) {
match socket.recv_from(p.buffer_mut()) {
Err(_) if i > 0 => {
break;
}
@ -84,9 +84,10 @@ pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result</*num p
for (packet, hdr, iov, addr) in
izip!(packets.iter_mut(), &mut hdrs, &mut iovs, &mut addrs).take(count)
{
let buffer = packet.buffer_mut();
*iov = iovec {
iov_base: packet.data.as_mut_ptr() as *mut libc::c_void,
iov_len: packet.data.len(),
iov_base: buffer.as_mut_ptr() as *mut libc::c_void,
iov_len: buffer.len(),
};
hdr.msg_hdr.msg_name = addr as *mut _ as *mut _;
hdr.msg_hdr.msg_namelen = SOCKADDR_STORAGE_SIZE as socklen_t;

View File

@ -242,12 +242,7 @@ mod tests {
let packet = Packet::default();
let sent = multi_target_send(
&sender,
&packet.data[..packet.meta.size],
&[&addr, &addr2, &addr3, &addr4],
)
.ok();
let sent = multi_target_send(&sender, packet.data(), &[&addr, &addr2, &addr3, &addr4]).ok();
assert_eq!(sent, Some(()));
let mut packets = vec![Packet::default(); 32];

View File

@ -279,7 +279,7 @@ impl StreamerSendStats {
fn record(&mut self, pkt: &Packet) {
let ent = self.host_map.entry(pkt.meta.addr).or_default();
ent.count += 1;
ent.bytes += pkt.data.len() as u64;
ent.bytes += pkt.data().len() as u64;
}
}
@ -296,9 +296,7 @@ fn recv_send(
}
let packets = packet_batch.iter().filter_map(|pkt| {
let addr = pkt.meta.socket_addr();
socket_addr_space
.check(&addr)
.then(|| (&pkt.data[..pkt.meta.size], addr))
socket_addr_space.check(&addr).then(|| (pkt.data(), addr))
});
batch_send(sock, &packets.collect::<Vec<_>>())?;
Ok(())
@ -478,7 +476,7 @@ mod test {
for i in 0..NUM_PACKETS {
let mut p = Packet::default();
{
p.data[0] = i as u8;
p.buffer_mut()[0] = i as u8;
p.meta.size = PACKET_DATA_SIZE;
p.meta.set_socket_addr(&addr);
}