From 90df6237c63ce12aecf92106412dd05512e79980 Mon Sep 17 00:00:00 2001
From: Pankaj Garg
Date: Thu, 13 Sep 2018 14:41:28 -0700
Subject: [PATCH] Implements recvmmsg() for UDP packets (#1161)

* Implemented recvmmsg() for UDP packets

- This change implements binding between libc API for recvmmsg()
- The function can receive multiple packets using one system call

Fixes #1141

* Added unit tests for recvmmsg()

* Added recv_mmsg() wrapper for non Linux OS

* Address review comments for recvmmsg()

* Remove unnecessary imports

* Moved target specific dependencies to the function
---
Review notes (text in this section is not part of the commit message):
 - src/recvmmsg.rs: restored the `io::Result<usize>` return type on both
   recv_mmsg() variants.
 - src/recvmmsg.rs: `msg_namelen` is now the size of ONE `sockaddr_in` instead
   of the size of the whole address array.
 - src/recvmmsg.rs: added doc/SAFETY comments and dropped an unused `mut`.

 Cargo.toml      |   1 +
 src/lib.rs      |   2 +
 src/packet.rs   |  24 +++----
 src/recvmmsg.rs | 208 ++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 223 insertions(+), 12 deletions(-)
 create mode 100644 src/recvmmsg.rs

diff --git a/Cargo.toml b/Cargo.toml
index 2cca6495d0..9e6d3502f4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -97,6 +97,7 @@ sys-info = "0.5.6"
 tokio = "0.1"
 tokio-codec = "0.1"
 untrusted = "0.6.2"
+libc = "0.2.43"
 
 [[bench]]
 name = "bank"
diff --git a/src/lib.rs b/src/lib.rs
index 485c899525..4364588d00 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -36,6 +36,7 @@ pub mod packet;
 pub mod payment_plan;
 pub mod record_stage;
 pub mod recorder;
+pub mod recvmmsg;
 pub mod replicate_stage;
 pub mod request;
 pub mod request_processor;
@@ -69,6 +70,7 @@ extern crate jsonrpc_core;
 #[macro_use]
 extern crate jsonrpc_macros;
 extern crate jsonrpc_http_server;
+extern crate libc;
 #[macro_use]
 extern crate log;
 extern crate nix;
diff --git a/src/packet.rs b/src/packet.rs
index 4d3b262464..626ba6a9fb 100644
--- a/src/packet.rs
+++ b/src/packet.rs
@@ -3,6 +3,7 @@ use bincode::{deserialize, serialize};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use counter::Counter;
 use log::Level;
+use recvmmsg::{recv_mmsg, NUM_RCVMMSGS};
 use result::{Error, Result};
 use serde::Serialize;
 use signature::Pubkey;
@@ -251,31 +252,30 @@ impl Packets {
         //  * read until it fails
         //  * set it back to blocking before returning
         socket.set_nonblocking(false)?;
-        for p in &mut self.packets {
-            p.meta.size = 0;
-            trace!("receiving on {}", socket.local_addr().unwrap());
-            match socket.recv_from(&mut p.data) {
+        trace!("receiving on {}", socket.local_addr().unwrap());
+        loop {
+            match recv_mmsg(socket, &mut self.packets[i..]) {
                 Err(_) if i > 0 => {
                     inc_new_counter_info!("packets-recv_count", i);
                     debug!("got {:?} messages on {}", i, socket.local_addr().unwrap());
-                    break;
+                    socket.set_nonblocking(true)?;
+                    return Ok(i);
                 }
                 Err(e) => {
                     trace!("recv_from err {:?}", e);
                     return Err(Error::IO(e));
                 }
-                Ok((nrecv, from)) => {
-                    p.meta.size = nrecv;
-                    p.meta.set_addr(&from);
-                    trace!("got {} bytes from {}", nrecv, from);
-                    if i == 0 {
+                Ok(npkts) => {
+                    trace!("got {} packets", npkts);
+                    i += npkts;
+                    if npkts != NUM_RCVMMSGS {
                         socket.set_nonblocking(true)?;
+                        inc_new_counter_info!("packets-recv_count", i);
+                        return Ok(i);
                     }
                 }
             }
-            i += 1;
         }
-        Ok(i)
     }
     pub fn recv_from(&mut self, socket: &UdpSocket) -> Result<()> {
         let sz = self.run_read_from(socket)?;
diff --git a/src/recvmmsg.rs b/src/recvmmsg.rs
new file mode 100644
index 0000000000..908713fb85
--- /dev/null
+++ b/src/recvmmsg.rs
@@ -0,0 +1,208 @@
+//! The `recvmmsg` module provides recvmmsg() API implementation
+
+use packet::Packet;
+use std::cmp;
+use std::io;
+use std::net::UdpSocket;
+
+/// Maximum number of packets received by a single `recv_mmsg()` call.
+pub const NUM_RCVMMSGS: usize = 16;
+
+/// Fallback for targets without `recvmmsg(2)`: fills up to `NUM_RCVMMSGS` entries of
+/// `packets`, using one `recv_from()` call per packet.
+///
+/// The socket is put into blocking mode before the first receive and into non-blocking
+/// mode as soon as one packet has arrived; it is left in whichever mode was set last.
+/// Returns the number of leading entries of `packets` that were filled in. A receive
+/// error is returned only when no packet has been received yet; after that it just ends
+/// the batch. Errors from changing the blocking mode are always returned.
+#[cfg(not(target_os = "linux"))]
+pub fn recv_mmsg(socket: &UdpSocket, packets: &mut [Packet]) -> io::Result<usize> {
+    let mut i = 0;
+    socket.set_nonblocking(false)?;
+    let count = cmp::min(NUM_RCVMMSGS, packets.len());
+    for n in 0..count {
+        let p = &mut packets[n];
+        p.meta.size = 0;
+        match socket.recv_from(&mut p.data) {
+            Err(_) if i > 0 => {
+                break;
+            }
+            Err(e) => {
+                return Err(e);
+            }
+            Ok((nrecv, from)) => {
+                p.meta.size = nrecv;
+                p.meta.set_addr(&from);
+                if i == 0 {
+                    socket.set_nonblocking(true)?;
+                }
+            }
+        }
+        i += 1;
+    }
+    Ok(i)
+}
+
+/// Receives up to `NUM_RCVMMSGS` packets from `sock` with a single `recvmmsg(2)` call.
+///
+/// The call is made with `MSG_WAITFORONE` and a one-second timeout. For every received
+/// packet the payload is written to `data`, its length to `meta.size` and the sender to
+/// the packet's address (senders are interpreted as IPv4 addresses).
+/// Returns the number of leading entries of `packets` that were filled in, or the OS
+/// error if `recvmmsg()` fails.
+#[cfg(target_os = "linux")]
+pub fn recv_mmsg(sock: &UdpSocket, packets: &mut [Packet]) -> io::Result<usize> {
+    use libc::{
+        c_void, iovec, mmsghdr, recvmmsg, sockaddr_in, socklen_t, time_t, timespec, MSG_WAITFORONE,
+    };
+    use nix::sys::socket::InetAddr;
+    use std::mem;
+    use std::os::unix::io::AsRawFd;
+
+    // SAFETY: `mmsghdr`, `iovec` and `sockaddr_in` are plain C structs (integers and raw
+    // pointers), for which the all-zero bit pattern is a valid value.
+    let mut hdrs: [mmsghdr; NUM_RCVMMSGS] = unsafe { mem::zeroed() };
+    let mut iovs: [iovec; NUM_RCVMMSGS] = unsafe { mem::zeroed() };
+    let mut addr: [sockaddr_in; NUM_RCVMMSGS] = unsafe { mem::zeroed() };
+    // Size of ONE address slot: `msg_namelen` tells the kernel how much room `msg_name`
+    // has, so the size of the whole array would let a longer source address spill into
+    // the neighbouring slots (or past the end of the array).
+    let addrlen = mem::size_of::<sockaddr_in>() as socklen_t;
+
+    let sock_fd = sock.as_raw_fd();
+
+    let count = cmp::min(iovs.len(), packets.len());
+
+    for i in 0..count {
+        iovs[i].iov_base = packets[i].data.as_mut_ptr() as *mut c_void;
+        iovs[i].iov_len = packets[i].data.len();
+
+        hdrs[i].msg_hdr.msg_name = &mut addr[i] as *mut _ as *mut _;
+        hdrs[i].msg_hdr.msg_namelen = addrlen;
+        hdrs[i].msg_hdr.msg_iov = &mut iovs[i];
+        hdrs[i].msg_hdr.msg_iovlen = 1;
+    }
+    let mut ts = timespec {
+        tv_sec: 1 as time_t,
+        tv_nsec: 0,
+    };
+
+    // SAFETY: the first `count` headers point at `iovs`/`addr` slots and packet buffers
+    // that outlive the call, and `count` never exceeds the length of those arrays.
+    let npkts =
+        match unsafe { recvmmsg(sock_fd, &mut hdrs[0], count as u32, MSG_WAITFORONE, &mut ts) } {
+            -1 => return Err(io::Error::last_os_error()),
+            n => {
+                for i in 0..n as usize {
+                    let p = &mut packets[i];
+                    p.meta.size = hdrs[i].msg_len as usize;
+                    let inet_addr = InetAddr::V4(addr[i]);
+                    p.meta.set_addr(&inet_addr.to_std());
+                }
+                n as usize
+            }
+        };
+
+    Ok(npkts)
+}
+
+#[cfg(test)]
+mod tests {
+    use packet::PACKET_DATA_SIZE;
+    use recvmmsg::*;
+
+    #[test]
+    pub fn test_recv_mmsg_one_iter() {
+        let reader = UdpSocket::bind("127.0.0.1:0").expect("bind");
+        let addr = reader.local_addr().unwrap();
+        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");
+        let saddr = sender.local_addr().unwrap();
+        let sent = NUM_RCVMMSGS - 1;
+        for _ in 0..sent {
+            let data = [0; PACKET_DATA_SIZE];
+            sender.send_to(&data[..], &addr).unwrap();
+        }
+
+        let mut packets = vec![Packet::default(); NUM_RCVMMSGS];
+        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
+        assert_eq!(sent, recv);
+        for i in 0..recv {
+            assert_eq!(packets[i].meta.size, PACKET_DATA_SIZE);
+            assert_eq!(packets[i].meta.addr(), saddr);
+        }
+    }
+
+    #[test]
+    pub fn test_recv_mmsg_multi_iter() {
+        let reader = UdpSocket::bind("127.0.0.1:0").expect("bind");
+        let addr = reader.local_addr().unwrap();
+        let sender = UdpSocket::bind("127.0.0.1:0").expect("bind");
+        let saddr = sender.local_addr().unwrap();
+        let sent = NUM_RCVMMSGS + 10;
+        for _ in 0..sent {
+            let data = [0; PACKET_DATA_SIZE];
+            sender.send_to(&data[..], &addr).unwrap();
+        }
+
+        let mut packets = vec![Packet::default(); NUM_RCVMMSGS * 2];
+        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
+        assert_eq!(NUM_RCVMMSGS, recv);
+        for i in 0..recv {
+            assert_eq!(packets[i].meta.size, PACKET_DATA_SIZE);
+            assert_eq!(packets[i].meta.addr(), saddr);
+        }
+
+        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
+        assert_eq!(sent - NUM_RCVMMSGS, recv);
+        for i in 0..recv {
+            assert_eq!(packets[i].meta.size, PACKET_DATA_SIZE);
+            assert_eq!(packets[i].meta.addr(), saddr);
+        }
+    }
+
+    #[test]
+    pub fn test_recv_mmsg_multi_addrs() {
+        let reader = UdpSocket::bind("127.0.0.1:0").expect("bind");
+        let addr = reader.local_addr().unwrap();
+
+        let sender1 = UdpSocket::bind("127.0.0.1:0").expect("bind");
+        let saddr1 = sender1.local_addr().unwrap();
+        let sent1 = NUM_RCVMMSGS - 1;
+
+        let sender2 = UdpSocket::bind("127.0.0.1:0").expect("bind");
+        let saddr2 = sender2.local_addr().unwrap();
+        let sent2 = NUM_RCVMMSGS + 1;
+
+        for _ in 0..sent1 {
+            let data = [0; PACKET_DATA_SIZE];
+            sender1.send_to(&data[..], &addr).unwrap();
+        }
+
+        for _ in 0..sent2 {
+            let data = [0; PACKET_DATA_SIZE];
+            sender2.send_to(&data[..], &addr).unwrap();
+        }
+
+        let mut packets = vec![Packet::default(); NUM_RCVMMSGS * 2];
+
+        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
+        assert_eq!(NUM_RCVMMSGS, recv);
+        for i in 0..sent1 {
+            assert_eq!(packets[i].meta.size, PACKET_DATA_SIZE);
+            assert_eq!(packets[i].meta.addr(), saddr1);
+        }
+
+        for i in sent1..recv {
+            assert_eq!(packets[i].meta.size, PACKET_DATA_SIZE);
+            assert_eq!(packets[i].meta.addr(), saddr2);
+        }
+
+        let recv = recv_mmsg(&reader, &mut packets[..]).unwrap();
+        assert_eq!(sent1 + sent2 - NUM_RCVMMSGS, recv);
+        for i in 0..recv {
+            assert_eq!(packets[i].meta.size, PACKET_DATA_SIZE);
+            assert_eq!(packets[i].meta.addr(), saddr2);
+        }
+    }
+}