From a5792885ca3af699737ae81c95347cdef54fc471 Mon Sep 17 00:00:00 2001 From: Jon Cinque Date: Fri, 20 May 2022 22:40:40 +0200 Subject: [PATCH] streamer: Add nonblocking versions of sendmmsg / recvmmsg (#25415) --- streamer/src/lib.rs | 1 + streamer/src/nonblocking/mod.rs | 2 + streamer/src/nonblocking/recvmmsg.rs | 195 +++++++++++++++++++ streamer/src/nonblocking/sendmmsg.rs | 274 +++++++++++++++++++++++++++ 4 files changed, 472 insertions(+) create mode 100644 streamer/src/nonblocking/mod.rs create mode 100644 streamer/src/nonblocking/recvmmsg.rs create mode 100644 streamer/src/nonblocking/sendmmsg.rs diff --git a/streamer/src/lib.rs b/streamer/src/lib.rs index 3fbad9a2df..22d14b5932 100644 --- a/streamer/src/lib.rs +++ b/streamer/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::integer_arithmetic)] +pub mod nonblocking; pub mod packet; pub mod quic; pub mod recvmmsg; diff --git a/streamer/src/nonblocking/mod.rs b/streamer/src/nonblocking/mod.rs new file mode 100644 index 0000000000..4c7b0d5995 --- /dev/null +++ b/streamer/src/nonblocking/mod.rs @@ -0,0 +1,2 @@ +pub mod recvmmsg; +pub mod sendmmsg; diff --git a/streamer/src/nonblocking/recvmmsg.rs b/streamer/src/nonblocking/recvmmsg.rs new file mode 100644 index 0000000000..0313630a3f --- /dev/null +++ b/streamer/src/nonblocking/recvmmsg.rs @@ -0,0 +1,195 @@ +//! The `recvmmsg` module provides a nonblocking recvmmsg() API implementation + +use { + crate::{ + packet::{Meta, Packet}, + recvmmsg::NUM_RCVMMSGS, + }, + std::{cmp, io}, + tokio::net::UdpSocket, +}; + +pub async fn recv_mmsg( + socket: &UdpSocket, + packets: &mut [Packet], +) -> io::Result { + debug_assert!(packets.iter().all(|pkt| pkt.meta == Meta::default())); + let count = cmp::min(NUM_RCVMMSGS, packets.len()); + socket.readable().await?; + let mut i = 0; + for p in packets.iter_mut().take(count) { + p.meta.size = 0; + match socket.try_recv_from(&mut p.data) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + break; + } + Err(e) => { + return Err(e); + } + Ok((nrecv, from)) => { + p.meta.size = nrecv; + p.meta.set_addr(&from); + } + } + i += 1; + } + Ok(i) +} + +#[cfg(test)] +mod tests { + use { + crate::{nonblocking::recvmmsg::*, packet::PACKET_DATA_SIZE}, + std::{net::SocketAddr, time::Instant}, + tokio::net::UdpSocket, + }; + + type TestConfig = (UdpSocket, SocketAddr, UdpSocket, SocketAddr); + + async fn test_setup_reader_sender(ip_str: &str) -> io::Result { + let reader = UdpSocket::bind(ip_str).await?; + let addr = reader.local_addr()?; + let sender = UdpSocket::bind(ip_str).await?; + let saddr = sender.local_addr()?; + Ok((reader, addr, sender, saddr)) + } + + const TEST_NUM_MSGS: usize = 32; + + async fn test_one_iter((reader, addr, sender, saddr): TestConfig) { + let sent = TEST_NUM_MSGS - 1; + for _ in 0..sent { + let data = [0; PACKET_DATA_SIZE]; + sender.send_to(&data[..], &addr).await.unwrap(); + } + + let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(sent, recv); + for packet in packets.iter().take(recv) { + assert_eq!(packet.meta.size, PACKET_DATA_SIZE); + assert_eq!(packet.meta.addr(), saddr); + } + } + + #[tokio::test] + async fn test_recv_mmsg_one_iter() { + test_one_iter(test_setup_reader_sender("127.0.0.1:0").await.unwrap()).await; + + match test_setup_reader_sender("::1:0").await { + Ok(config) => test_one_iter(config).await, + Err(e) => warn!("Failed to configure IPv6: {:?}", e), + } + } + + async fn test_multi_iter((reader, addr, sender, saddr): TestConfig) { + let sent = TEST_NUM_MSGS + 10; + for _ in 0..sent { + let data = [0; PACKET_DATA_SIZE]; + sender.send_to(&data[..], &addr).await.unwrap(); + } + + let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(TEST_NUM_MSGS, recv); + for packet in packets.iter().take(recv) { + assert_eq!(packet.meta.size, PACKET_DATA_SIZE); + assert_eq!(packet.meta.addr(), saddr); + } + + packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(sent - TEST_NUM_MSGS, recv); + for packet in packets.iter().take(recv) { + assert_eq!(packet.meta.size, PACKET_DATA_SIZE); + assert_eq!(packet.meta.addr(), saddr); + } + } + + #[tokio::test] + async fn test_recv_mmsg_multi_iter() { + test_multi_iter(test_setup_reader_sender("127.0.0.1:0").await.unwrap()).await; + + match test_setup_reader_sender("::1:0").await { + Ok(config) => test_multi_iter(config).await, + Err(e) => warn!("Failed to configure IPv6: {:?}", e), + } + } + + #[tokio::test] + async fn test_recv_mmsg_multi_iter_timeout() { + let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr = reader.local_addr().unwrap(); + let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let saddr = sender.local_addr().unwrap(); + let sent = TEST_NUM_MSGS; + for _ in 0..sent { + let data = [0; PACKET_DATA_SIZE]; + sender.send_to(&data[..], &addr).await.unwrap(); + } + + let start = Instant::now(); + let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(TEST_NUM_MSGS, recv); + for packet in packets.iter().take(recv) { + assert_eq!(packet.meta.size, PACKET_DATA_SIZE); + assert_eq!(packet.meta.addr(), saddr); + } + + packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); + let _recv = recv_mmsg(&reader, &mut packets[..]).await; + assert!(start.elapsed().as_secs() < 5); + } + + #[tokio::test] + async fn test_recv_mmsg_multi_addrs() { + let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr = reader.local_addr().unwrap(); + + let sender1 = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let saddr1 = sender1.local_addr().unwrap(); + let sent1 = TEST_NUM_MSGS - 1; + + let sender2 = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let saddr2 = sender2.local_addr().unwrap(); + let sent2 = TEST_NUM_MSGS + 1; + + for _ in 0..sent1 { + let data = [0; PACKET_DATA_SIZE]; + sender1.send_to(&data[..], &addr).await.unwrap(); + } + + for _ in 0..sent2 { + let data = [0; PACKET_DATA_SIZE]; + sender2.send_to(&data[..], &addr).await.unwrap(); + } + + let mut packets = vec![Packet::default(); TEST_NUM_MSGS]; + + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(TEST_NUM_MSGS, recv); + for packet in packets.iter().take(sent1) { + assert_eq!(packet.meta.size, PACKET_DATA_SIZE); + assert_eq!(packet.meta.addr(), saddr1); + } + for packet in packets.iter().skip(sent1).take(recv - sent1) { + assert_eq!(packet.meta.size, PACKET_DATA_SIZE); + assert_eq!(packet.meta.addr(), saddr2); + } + + packets + .iter_mut() + .for_each(|pkt| pkt.meta = Meta::default()); + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(sent1 + sent2 - TEST_NUM_MSGS, recv); + for packet in packets.iter().take(recv) { + assert_eq!(packet.meta.size, PACKET_DATA_SIZE); + assert_eq!(packet.meta.addr(), saddr2); + } + } +} diff --git a/streamer/src/nonblocking/sendmmsg.rs b/streamer/src/nonblocking/sendmmsg.rs new file mode 100644 index 0000000000..6797bb9fbf --- /dev/null +++ b/streamer/src/nonblocking/sendmmsg.rs @@ -0,0 +1,274 @@ +//! The `sendmmsg` module provides a nonblocking sendmmsg() API implementation + +use { + crate::sendmmsg::SendPktsError, + futures_util::future::join_all, + std::{borrow::Borrow, iter::repeat, net::SocketAddr}, + tokio::net::UdpSocket, +}; + +pub async fn batch_send(sock: &UdpSocket, packets: &[(T, S)]) -> Result<(), SendPktsError> +where + S: Borrow, + T: AsRef<[u8]>, +{ + let mut num_failed = 0; + let mut erropt = None; + let futures = packets + .iter() + .map(|(p, a)| sock.send_to(p.as_ref(), a.borrow())) + .collect::>(); + let results = join_all(futures).await; + for result in results { + if let Err(e) = result { + num_failed += 1; + if erropt.is_none() { + erropt = Some(e); + } + } + } + + if let Some(err) = erropt { + Err(SendPktsError::IoError(err, num_failed)) + } else { + Ok(()) + } +} + +pub async fn multi_target_send( + sock: &UdpSocket, + packet: T, + dests: &[S], +) -> Result<(), SendPktsError> +where + S: Borrow, + T: AsRef<[u8]>, +{ + let dests = dests.iter().map(Borrow::borrow); + let pkts: Vec<_> = repeat(&packet).zip(dests).collect(); + batch_send(sock, &pkts).await +} + +#[cfg(test)] +mod tests { + use { + crate::{ + nonblocking::{ + recvmmsg::recv_mmsg, + sendmmsg::{batch_send, multi_target_send}, + }, + packet::Packet, + sendmmsg::SendPktsError, + }, + solana_sdk::packet::PACKET_DATA_SIZE, + std::{ + io::ErrorKind, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + }, + tokio::net::UdpSocket, + }; + + #[tokio::test] + async fn test_send_mmsg_one_dest() { + let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr = reader.local_addr().unwrap(); + let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + + let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect(); + let packet_refs: Vec<_> = packets.iter().map(|p| (&p[..], &addr)).collect(); + + let sent = batch_send(&sender, &packet_refs[..]).await.ok(); + assert_eq!(sent, Some(())); + + let mut packets = vec![Packet::default(); 32]; + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(32, recv); + } + + #[tokio::test] + async fn test_send_mmsg_multi_dest() { + let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr = reader.local_addr().unwrap(); + + let reader2 = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr2 = reader2.local_addr().unwrap(); + + let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + + let packets: Vec<_> = (0..32).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect(); + let packet_refs: Vec<_> = packets + .iter() + .enumerate() + .map(|(i, p)| { + if i < 16 { + (&p[..], &addr) + } else { + (&p[..], &addr2) + } + }) + .collect(); + + let sent = batch_send(&sender, &packet_refs[..]).await.ok(); + assert_eq!(sent, Some(())); + + let mut packets = vec![Packet::default(); 32]; + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(16, recv); + + let mut packets = vec![Packet::default(); 32]; + let recv = recv_mmsg(&reader2, &mut packets[..]).await.unwrap(); + assert_eq!(16, recv); + } + + #[tokio::test] + async fn test_multicast_msg() { + let reader = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr = reader.local_addr().unwrap(); + + let reader2 = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr2 = reader2.local_addr().unwrap(); + + let reader3 = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr3 = reader3.local_addr().unwrap(); + + let reader4 = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + let addr4 = reader4.local_addr().unwrap(); + + let sender = UdpSocket::bind("127.0.0.1:0").await.expect("bind"); + + let packet = Packet::default(); + + let sent = multi_target_send( + &sender, + &packet.data[..packet.meta.size], + &[&addr, &addr2, &addr3, &addr4], + ) + .await + .ok(); + assert_eq!(sent, Some(())); + + let mut packets = vec![Packet::default(); 32]; + let recv = recv_mmsg(&reader, &mut packets[..]).await.unwrap(); + assert_eq!(1, recv); + + let mut packets = vec![Packet::default(); 32]; + let recv = recv_mmsg(&reader2, &mut packets[..]).await.unwrap(); + assert_eq!(1, recv); + + let mut packets = vec![Packet::default(); 32]; + let recv = recv_mmsg(&reader3, &mut packets[..]).await.unwrap(); + assert_eq!(1, recv); + + let mut packets = vec![Packet::default(); 32]; + let recv = recv_mmsg(&reader4, &mut packets[..]).await.unwrap(); + assert_eq!(1, recv); + } + + #[tokio::test] + async fn test_intermediate_failures_mismatched_bind() { + let packets: Vec<_> = (0..3).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect(); + let ip4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080); + let ip6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080); + let packet_refs: Vec<_> = vec![ + (&packets[0][..], &ip4), + (&packets[1][..], &ip6), + (&packets[2][..], &ip4), + ]; + let dest_refs: Vec<_> = vec![&ip4, &ip6, &ip4]; + + let sender = UdpSocket::bind("0.0.0.0:0").await.expect("bind"); + if let Err(SendPktsError::IoError(_, num_failed)) = + batch_send(&sender, &packet_refs[..]).await + { + assert_eq!(num_failed, 1); + } + if let Err(SendPktsError::IoError(_, num_failed)) = + multi_target_send(&sender, &packets[0], &dest_refs).await + { + assert_eq!(num_failed, 1); + } + } + + #[tokio::test] + async fn test_intermediate_failures_unreachable_address() { + let packets: Vec<_> = (0..5).map(|_| vec![0u8; PACKET_DATA_SIZE]).collect(); + let ipv4local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080); + let ipv4broadcast = SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), 8080); + let sender = UdpSocket::bind("0.0.0.0:0").await.expect("bind"); + + // test intermediate failures for batch_send + let packet_refs: Vec<_> = vec![ + (&packets[0][..], &ipv4local), + (&packets[1][..], &ipv4broadcast), + (&packets[2][..], &ipv4local), + (&packets[3][..], &ipv4broadcast), + (&packets[4][..], &ipv4local), + ]; + if let Err(SendPktsError::IoError(ioerror, num_failed)) = + batch_send(&sender, &packet_refs[..]).await + { + assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied)); + assert_eq!(num_failed, 2); + } + + // test leading and trailing failures for batch_send + let packet_refs: Vec<_> = vec![ + (&packets[0][..], &ipv4broadcast), + (&packets[1][..], &ipv4local), + (&packets[2][..], &ipv4broadcast), + (&packets[3][..], &ipv4local), + (&packets[4][..], &ipv4broadcast), + ]; + if let Err(SendPktsError::IoError(ioerror, num_failed)) = + batch_send(&sender, &packet_refs[..]).await + { + assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied)); + assert_eq!(num_failed, 3); + } + + // test consecutive intermediate failures for batch_send + let packet_refs: Vec<_> = vec![ + (&packets[0][..], &ipv4local), + (&packets[1][..], &ipv4local), + (&packets[2][..], &ipv4broadcast), + (&packets[3][..], &ipv4broadcast), + (&packets[4][..], &ipv4local), + ]; + if let Err(SendPktsError::IoError(ioerror, num_failed)) = + batch_send(&sender, &packet_refs[..]).await + { + assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied)); + assert_eq!(num_failed, 2); + } + + // test intermediate failures for multi_target_send + let dest_refs: Vec<_> = vec![ + &ipv4local, + &ipv4broadcast, + &ipv4local, + &ipv4broadcast, + &ipv4local, + ]; + if let Err(SendPktsError::IoError(ioerror, num_failed)) = + multi_target_send(&sender, &packets[0], &dest_refs).await + { + assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied)); + assert_eq!(num_failed, 2); + } + + // test leading and trailing failures for multi_target_send + let dest_refs: Vec<_> = vec![ + &ipv4broadcast, + &ipv4local, + &ipv4broadcast, + &ipv4local, + &ipv4broadcast, + ]; + if let Err(SendPktsError::IoError(ioerror, num_failed)) = + multi_target_send(&sender, &packets[0], &dest_refs).await + { + assert!(matches!(ioerror.kind(), ErrorKind::PermissionDenied)); + assert_eq!(num_failed, 3); + } + } +}