solana/streamer/src/nonblocking/sendmmsg.rs

276 lines
9.2 KiB
Rust

//! The `sendmmsg` module provides a nonblocking sendmmsg() API implementation
use {
crate::sendmmsg::SendPktsError,
futures_util::future::join_all,
std::{borrow::Borrow, iter::repeat, net::SocketAddr},
tokio::net::UdpSocket,
};
/// Sends each `(payload, destination)` pair in `packets` through `sock`.
///
/// Every send is started up front and all of them are awaited together, so a
/// failure on one packet does not stop the remaining packets from being attempted.
/// An empty `packets` slice trivially succeeds.
///
/// # Errors
///
/// If at least one send fails, returns [`SendPktsError::IoError`] holding the
/// first error in input order and the total number of failed sends.
pub async fn batch_send<S, T>(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    T: AsRef<[u8]>,
{
    let outcomes = join_all(
        packets
            .iter()
            .map(|(payload, dest)| sock.send_to(payload.as_ref(), dest.borrow())),
    )
    .await;

    // `join_all` preserves input order, so the first `Err` yielded here is the
    // earliest failing packet; the rest only contribute to the count.
    let mut failures = outcomes.into_iter().filter_map(Result::err);
    match failures.next() {
        None => Ok(()),
        Some(first) => Err(SendPktsError::IoError(first, 1 + failures.count())),
    }
}
/// Sends the same `packet` payload to every address in `dests`.
///
/// This is a convenience wrapper around [`batch_send`]. The payload is borrowed
/// once per destination rather than copied, and the sends are dispatched together.
///
/// # Errors
///
/// Same as [`batch_send`]: the first I/O error plus the number of failed sends.
pub async fn multi_target_send<S, T>(
    sock: &UdpSocket,
    packet: T,
    dests: &[S],
) -> Result<(), SendPktsError>
where
    S: Borrow<SocketAddr>,
    T: AsRef<[u8]>,
{
    let payload = &packet;
    let targets: Vec<(&T, &SocketAddr)> = repeat(payload)
        .zip(dests.iter().map(Borrow::<SocketAddr>::borrow))
        .collect();
    batch_send(sock, &targets[..]).await
}
#[cfg(test)]
mod tests {
    use {
        crate::{
            nonblocking::{
                recvmmsg::{recv_mmsg, recv_mmsg_exact},
                sendmmsg::{batch_send, multi_target_send},
            },
            packet::Packet,
            sendmmsg::SendPktsError,
        },
        assert_matches::assert_matches,
        solana_sdk::packet::PACKET_DATA_SIZE,
        std::{
            io::ErrorKind,
            net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
        },
        tokio::net::UdpSocket,
    };

    /// Binds a UDP socket to an ephemeral loopback port and returns it with its address.
    async fn bind_loopback() -> (UdpSocket, SocketAddr) {
        let sock = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
        let addr = sock.local_addr().unwrap();
        (sock, addr)
    }

    /// Builds `count` zero-filled payloads of `PACKET_DATA_SIZE` bytes each.
    fn zeroed_payloads(count: usize) -> Vec<Vec<u8>> {
        (0..count).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect()
    }

    /// Checks the error kind and failure count when `result` is an I/O failure.
    ///
    /// A successful result is accepted as-is, because whether these sends fail
    /// depends on the host's network configuration.
    fn assert_permission_denied_failures(
        result: Result<(), SendPktsError>,
        expected_failed: usize,
    ) {
        if let Err(SendPktsError::IoError(ioerror, num_failed)) = result {
            assert_matches!(ioerror.kind(), ErrorKind::PermissionDenied);
            assert_eq!(num_failed, expected_failed);
        }
    }

    #[tokio::test]
    async fn test_send_mmsg_one_dest() {
        let (reader, addr) = bind_loopback().await;
        let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
        let payloads = zeroed_payloads(32);
        let batch: Vec<_> = payloads.iter().map(|p| (&p[..], &addr)).collect();
        assert_eq!(batch_send(&sender, &batch[..]).await.ok(), Some(()));

        let mut packets = vec![Packet::default(); 32];
        assert_eq!(
            recv_mmsg_exact(&reader, &mut packets[..]).await.unwrap(),
            32
        );
    }

    #[tokio::test]
    async fn test_send_mmsg_multi_dest() {
        let (reader, addr) = bind_loopback().await;
        let (reader2, addr2) = bind_loopback().await;
        let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
        let payloads = zeroed_payloads(32);
        // First half goes to `reader`, second half to `reader2`.
        let batch: Vec<_> = payloads
            .iter()
            .enumerate()
            .map(|(i, p)| (&p[..], if i < 16 { &addr } else { &addr2 }))
            .collect();
        assert_eq!(batch_send(&sender, &batch[..]).await.ok(), Some(()));

        for sock in [&reader, &reader2] {
            let mut packets = vec![Packet::default(); 16];
            assert_eq!(recv_mmsg_exact(sock, &mut packets[..]).await.unwrap(), 16);
        }
    }

    #[tokio::test]
    async fn test_multicast_msg() {
        let mut readers = Vec::with_capacity(4);
        let mut addrs = Vec::with_capacity(4);
        for _ in 0..4 {
            let (reader, addr) = bind_loopback().await;
            readers.push(reader);
            addrs.push(addr);
        }
        let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind");
        let packet = Packet::default();
        let dests: Vec<&SocketAddr> = addrs.iter().collect();
        let sent = multi_target_send(&sender, packet.data(..).unwrap(), &dests[..])
            .await
            .ok();
        assert_eq!(sent, Some(()));

        for reader in &readers {
            let mut packets = vec![Packet::default(); 32];
            assert_eq!(recv_mmsg(reader, &mut packets[..]).await.unwrap(), 1);
        }
    }

    #[tokio::test]
    async fn test_intermediate_failures_mismatched_bind() {
        let payloads = zeroed_payloads(3);
        let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
        // An IPv4-bound socket cannot reach the IPv6 destination in the middle.
        let dests = [&ip4, &ip6, &ip4];
        let batch: Vec<_> = payloads
            .iter()
            .zip(dests)
            .map(|(p, dest)| (&p[..], dest))
            .collect();
        let sender = UdpSocket::bind("0.0.0.0:0").await.expect("bind");

        if let Err(SendPktsError::IoError(_, num_failed)) = batch_send(&sender, &batch[..]).await {
            assert_eq!(num_failed, 1);
        }
        if let Err(SendPktsError::IoError(_, num_failed)) =
            multi_target_send(&sender, &payloads[0], &dests[..]).await
        {
            assert_eq!(num_failed, 1);
        }
    }

    #[tokio::test]
    async fn test_intermediate_failures_unreachable_address() {
        let payloads = zeroed_payloads(5);
        let ipv4local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
        let ipv4broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080);
        let sender = UdpSocket::bind("0.0.0.0:0").await.expect("bind");
        let (local, bcast) = (&ipv4local, &ipv4broadcast);

        // batch_send: intermediate, leading/trailing, and consecutive failures.
        let batch_cases = [
            ([local, bcast, local, bcast, local], 2),
            ([bcast, local, bcast, local, bcast], 3),
            ([local, local, bcast, bcast, local], 2),
        ];
        for (dests, expected_failed) in batch_cases {
            let batch: Vec<_> = payloads
                .iter()
                .zip(dests)
                .map(|(p, dest)| (&p[..], dest))
                .collect();
            assert_permission_denied_failures(
                batch_send(&sender, &batch[..]).await,
                expected_failed,
            );
        }

        // multi_target_send: intermediate and leading/trailing failures.
        let multi_cases = [
            ([local, bcast, local, bcast, local], 2),
            ([bcast, local, bcast, local, bcast], 3),
        ];
        for (dests, expected_failed) in multi_cases {
            assert_permission_denied_failures(
                multi_target_send(&sender, &payloads[0], &dests[..]).await,
                expected_failed,
            );
        }
    }
}