streamer: Add nonblocking versions of sendmmsg / recvmmsg (#25415)

This commit is contained in:
Jon Cinque 2022-05-20 22:40:40 +02:00 committed by GitHub
parent 851958f77a
commit a5792885ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 472 additions and 0 deletions

View File

@ -1,4 +1,5 @@
#![allow(clippy::integer_arithmetic)]
pub mod nonblocking;
pub mod packet;
pub mod quic;
pub mod recvmmsg;

View File

@ -0,0 +1,2 @@
pub mod recvmmsg;
pub mod sendmmsg;

View File

@ -0,0 +1,195 @@
//! The `recvmmsg` module provides a nonblocking recvmmsg() API implementation
use {
crate::{
packet::{Meta, Packet},
recvmmsg::NUM_RCVMMSGS,
},
std::{cmp, io},
tokio::net::UdpSocket,
};
pub async fn recv_mmsg(
socket: &UdpSocket,
packets: &mut [Packet],
) -> io::Result</*num packets:*/ usize> {
debug_assert!(packets.iter().all(|pkt| pkt.meta == Meta::default()));
let count = cmp::min(NUM_RCVMMSGS, packets.len());
socket.readable().await?;
let mut i = 0;
for p in packets.iter_mut().take(count) {
p.meta.size = 0;
match socket.try_recv_from(&mut p.data) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
break;
}
Err(e) => {
return Err(e);
}
Ok((nrecv, from)) => {
p.meta.size = nrecv;
p.meta.set_addr(&from);
}
}
i += 1;
}
Ok(i)
}
#[cfg(test)]
mod tests {
use {
crate::{nonblocking::recvmmsg::*, packet::PACKET_DATA_SIZE},
std::{net::SocketAddr, time::Instant},
tokio::net::UdpSocket,
};
type TestConfig = (UdpSocket, SocketAddr, UdpSocket, SocketAddr);
async fn test_setup_reader_sender(ip_str: &str) -> io::Result<TestConfig> {
let reader = UdpSocket::bind(ip_str).await?;
let addr = reader.local_addr()?;
let sender = UdpSocket::bind(ip_str).await?;
let saddr = sender.local_addr()?;
Ok((reader, addr, sender, saddr))
}
const TEST_NUM_MSGS: usize = 32;
async fn test_one_iter((reader, addr, sender, saddr): TestConfig) {
let sent = TEST_NUM_MSGS - 1;
for _ in 0..sent {
let data = [0; PACKET_DATA_SIZE];
sender.send_to(&data[..], &addr).await.unwrap();
}
let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(sent, recv);
for packet in packets.iter().take(recv) {
assert_eq!(packet.meta.size, PACKET_DATA_SIZE);
assert_eq!(packet.meta.addr(), saddr);
}
}
#[tokio::test]
async fn test_recv_mmsg_one_iter() {
test_one_iter(test_setup_reader_sender("127.0.0.1:0").await.unwrap()).await;
match test_setup_reader_sender("::1:0").await {
Ok(config) => test_one_iter(config).await,
Err(e) => warn!("Failed to configure IPv6: {:?}", e),
}
}
async fn test_multi_iter((reader, addr, sender, saddr): TestConfig) {
let sent = TEST_NUM_MSGS + 10;
for _ in 0..sent {
let data = [0; PACKET_DATA_SIZE];
sender.send_to(&data[..], &addr).await.unwrap();
}
let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(TEST_NUM_MSGS, recv);
for packet in packets.iter().take(recv) {
assert_eq!(packet.meta.size, PACKET_DATA_SIZE);
assert_eq!(packet.meta.addr(), saddr);
}
packets
.iter_mut()
.for_each(|pkt| pkt.meta = Meta::default());
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(sent - TEST_NUM_MSGS, recv);
for packet in packets.iter().take(recv) {
assert_eq!(packet.meta.size, PACKET_DATA_SIZE);
assert_eq!(packet.meta.addr(), saddr);
}
}
#[tokio::test]
async fn test_recv_mmsg_multi_iter() {
test_multi_iter(test_setup_reader_sender("127.0.0.1:0").await.unwrap()).await;
match test_setup_reader_sender("::1:0").await {
Ok(config) => test_multi_iter(config).await,
Err(e) => warn!("Failed to configure IPv6: {:?}", e),
}
}
#[tokio::test]
async fn test_recv_mmsg_multi_iter_timeout() {
let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr = reader.local_addr().unwrap();
let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let saddr = sender.local_addr().unwrap();
let sent = TEST_NUM_MSGS;
for _ in 0..sent {
let data = [0; PACKET_DATA_SIZE];
sender.send_to(&data[..], &addr).await.unwrap();
}
let start = Instant::now();
let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(TEST_NUM_MSGS, recv);
for packet in packets.iter().take(recv) {
assert_eq!(packet.meta.size, PACKET_DATA_SIZE);
assert_eq!(packet.meta.addr(), saddr);
}
packets
.iter_mut()
.for_each(|pkt| pkt.meta = Meta::default());
let _recv = recv_mmsg(&reader, &mut packets[..]).await;
assert!(start.elapsed().as_secs() < 5);
}
#[tokio::test]
async fn test_recv_mmsg_multi_addrs() {
let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr = reader.local_addr().unwrap();
let sender1 = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let saddr1 = sender1.local_addr().unwrap();
let sent1 = TEST_NUM_MSGS - 1;
let sender2 = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let saddr2 = sender2.local_addr().unwrap();
let sent2 = TEST_NUM_MSGS + 1;
for _ in 0..sent1 {
let data = [0; PACKET_DATA_SIZE];
sender1.send_to(&data[..], &addr).await.unwrap();
}
for _ in 0..sent2 {
let data = [0; PACKET_DATA_SIZE];
sender2.send_to(&data[..], &addr).await.unwrap();
}
let mut packets = vec![Packet::default(); TEST_NUM_MSGS];
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(TEST_NUM_MSGS, recv);
for packet in packets.iter().take(sent1) {
assert_eq!(packet.meta.size, PACKET_DATA_SIZE);
assert_eq!(packet.meta.addr(), saddr1);
}
for packet in packets.iter().skip(sent1).take(recv - sent1) {
assert_eq!(packet.meta.size, PACKET_DATA_SIZE);
assert_eq!(packet.meta.addr(), saddr2);
}
packets
.iter_mut()
.for_each(|pkt| pkt.meta = Meta::default());
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(sent1 + sent2 - TEST_NUM_MSGS, recv);
for packet in packets.iter().take(recv) {
assert_eq!(packet.meta.size, PACKET_DATA_SIZE);
assert_eq!(packet.meta.addr(), saddr2);
}
}
}

View File

@ -0,0 +1,274 @@
//! The `sendmmsg` module provides a nonblocking sendmmsg() API implementation
use {
crate::sendmmsg::SendPktsError,
futures_util::future::join_all,
std::{borrow::Borrow, iter::repeat, net::SocketAddr},
tokio::net::UdpSocket,
};
pub async fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
S: Borrow<SocketAddr>,
T: AsRef<[u8]>,
{
let mut num_failed = 0;
let mut erropt = None;
let futures = packets
.iter()
.map(|(p, a)| sock.send_to(p.as_ref(), a.borrow()))
.collect::<Vec<_>>();
let results = join_all(futures).await;
for result in results {
if let Err(e) = result {
num_failed += 1;
if erropt.is_none() {
erropt = Some(e);
}
}
}
if let Some(err) = erropt {
Err(SendPktsError::IoError(err, num_failed))
} else {
Ok(())
}
}
pub async fn multi_target_send<S, T>(
sock: &UdpSocket,
packet: T,
dests: &[S],
) -> Result<(), SendPktsError>
where
S: Borrow<SocketAddr>,
T: AsRef<[u8]>,
{
let dests = dests.iter().map(Borrow::borrow);
let pkts: Vec<_> = repeat(&packet).zip(dests).collect();
batch_send(sock, &pkts).await
}
#[cfg(test)]
mod tests {
use {
crate::{
nonblocking::{
recvmmsg::recv_mmsg,
sendmmsg::{batch_send, multi_target_send},
},
packet::Packet,
sendmmsg::SendPktsError,
},
solana_sdk::packet::PACKET_DATA_SIZE,
std::{
io::ErrorKind,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
},
tokio::net::UdpSocket,
};
#[tokio::test]
async fn test_send_mmsg_one_dest() {
let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr = reader.local_addr().unwrap();
let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
let packet_refs: Vec<_> = packets.iter().map(|p| (&p[..], &addr)).collect();
let sent = batch_send(&sender, &packet_refs[..]).await.ok();
assert_eq!(sent, Some(()));
let mut packets = vec![Packet::default(); 32];
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(32, recv);
}
#[tokio::test]
async fn test_send_mmsg_multi_dest() {
let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr = reader.local_addr().unwrap();
let reader2 = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr2 = reader2.local_addr().unwrap();
let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
let packet_refs: Vec<_> = packets
.iter()
.enumerate()
.map(|(i, p)| {
if i < 16 {
(&p[..], &addr)
} else {
(&p[..], &addr2)
}
})
.collect();
let sent = batch_send(&sender, &packet_refs[..]).await.ok();
assert_eq!(sent, Some(()));
let mut packets = vec![Packet::default(); 32];
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(16, recv);
let mut packets = vec![Packet::default(); 32];
let recv = recv_mmsg(&reader2, &mut packets[..]).await.unwrap();
assert_eq!(16, recv);
}
#[tokio::test]
async fn test_multicast_msg() {
let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr = reader.local_addr().unwrap();
let reader2 = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr2 = reader2.local_addr().unwrap();
let reader3 = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr3 = reader3.local_addr().unwrap();
let reader4 = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let addr4 = reader4.local_addr().unwrap();
let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
let packet = Packet::default();
let sent = multi_target_send(
&sender,
&packet.data[..packet.meta.size],
&[&addr, &addr2, &addr3, &addr4],
)
.await
.ok();
assert_eq!(sent, Some(()));
let mut packets = vec![Packet::default(); 32];
let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap();
assert_eq!(1, recv);
let mut packets = vec![Packet::default(); 32];
let recv = recv_mmsg(&reader2, &mut packets[..]).await.unwrap();
assert_eq!(1, recv);
let mut packets = vec![Packet::default(); 32];
let recv = recv_mmsg(&reader3, &mut packets[..]).await.unwrap();
assert_eq!(1, recv);
let mut packets = vec![Packet::default(); 32];
let recv = recv_mmsg(&reader4, &mut packets[..]).await.unwrap();
assert_eq!(1, recv);
}
#[tokio::test]
async fn test_intermediate_failures_mismatched_bind() {
let packets: Vec<_> = (0..3).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
let packet_refs: Vec<_> = vec![
(&packets[0][..], &ip4),
(&packets[1][..], &ip6),
(&packets[2][..], &ip4),
];
let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4];
let sender = UdpSocket::bind("0.0.0.0:0").await.expect("bind");
if let Err(SendPktsError::IoError(_, num_failed)) =
batch_send(&sender, &packet_refs[..]).await
{
assert_eq!(num_failed, 1);
}
if let Err(SendPktsError::IoError(_, num_failed)) =
multi_target_send(&sender, &packets[0], &dest_refs).await
{
assert_eq!(num_failed, 1);
}
}
#[tokio::test]
async fn test_intermediate_failures_unreachable_address() {
let packets: Vec<_> = (0..5).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect();
let ipv4local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
let ipv4broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
let sender = UdpSocket::bind("0.0.0.0:0").await.expect("bind");
// test intermediate failures for batch_send
let packet_refs: Vec<_> = vec![
(&packets[0][..], &ipv4local),
(&packets[1][..], &ipv4broadcast),
(&packets[2][..], &ipv4local),
(&packets[3][..], &ipv4broadcast),
(&packets[4][..], &ipv4local),
];
if let Err(SendPktsError::IoError(ioerror, num_failed)) =
batch_send(&sender, &packet_refs[..]).await
{
assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
assert_eq!(num_failed, 2);
}
// test leading and trailing failures for batch_send
let packet_refs: Vec<_> = vec![
(&packets[0][..], &ipv4broadcast),
(&packets[1][..], &ipv4local),
(&packets[2][..], &ipv4broadcast),
(&packets[3][..], &ipv4local),
(&packets[4][..], &ipv4broadcast),
];
if let Err(SendPktsError::IoError(ioerror, num_failed)) =
batch_send(&sender, &packet_refs[..]).await
{
assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
assert_eq!(num_failed, 3);
}
// test consecutive intermediate failures for batch_send
let packet_refs: Vec<_> = vec![
(&packets[0][..], &ipv4local),
(&packets[1][..], &ipv4local),
(&packets[2][..], &ipv4broadcast),
(&packets[3][..], &ipv4broadcast),
(&packets[4][..], &ipv4local),
];
if let Err(SendPktsError::IoError(ioerror, num_failed)) =
batch_send(&sender, &packet_refs[..]).await
{
assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
assert_eq!(num_failed, 2);
}
// test intermediate failures for multi_target_send
let dest_refs: Vec<_> = vec![
&ipv4local,
&ipv4broadcast,
&ipv4local,
&ipv4broadcast,
&ipv4local,
];
if let Err(SendPktsError::IoError(ioerror, num_failed)) =
multi_target_send(&sender, &packets[0], &dest_refs).await
{
assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
assert_eq!(num_failed, 2);
}
// test leading and trailing failures for multi_target_send
let dest_refs: Vec<_> = vec![
&ipv4broadcast,
&ipv4local,
&ipv4broadcast,
&ipv4local,
&ipv4broadcast,
];
if let Err(SendPktsError::IoError(ioerror, num_failed)) =
multi_target_send(&sender, &packets[0], &dest_refs).await
{
assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied));
assert_eq!(num_failed, 3);
}
}
}