diff --git a/net-utils/src/lib.rs b/net-utils/src/lib.rs index 739d1fda59..141852108c 100644 --- a/net-utils/src/lib.rs +++ b/net-utils/src/lib.rs @@ -1,13 +1,17 @@ //! The `net_utils` module assists with networking -use log::*; -use rand::{thread_rng, Rng}; -use socket2::{Domain, SockAddr, Socket, Type}; -use std::collections::{BTreeMap, BTreeSet}; -use std::io::{self, Read, Write}; -use std::net::{IpAddr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}; -use std::sync::mpsc::channel; -use std::time::Duration; -use url::Url; +use { + log::*, + rand::{thread_rng, Rng}, + socket2::{Domain, SockAddr, Socket, Type}, + std::{ + collections::{BTreeMap, HashSet}, + io::{self, Read, Write}, + net::{IpAddr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}, + sync::{mpsc::channel, Arc, RwLock}, + time::{Duration, Instant}, + }, + url::Url, +}; mod ip_echo_server; use ip_echo_server::IpEchoServerMessage; @@ -204,21 +208,38 @@ fn do_verify_reachable_ports( .map_err(|err| warn!("ip_echo_server request failed: {}", err)); // Spawn threads at once! + let reachable_ports = Arc::new(RwLock::new(HashSet::new())); let thread_handles: Vec<_> = checked_socket_iter .map(|udp_socket| { let port = udp_socket.local_addr().unwrap().port(); let udp_socket = udp_socket.try_clone().expect("Unable to clone udp socket"); + let reachable_ports = reachable_ports.clone(); std::thread::spawn(move || { - let mut buf = [0; 1]; + let start = Instant::now(); + let original_read_timeout = udp_socket.read_timeout().unwrap(); - udp_socket.set_read_timeout(Some(timeout)).unwrap(); - let recv_result = udp_socket.recv(&mut buf); - debug!( - "Waited for incoming datagram on udp/{}: {:?}", - port, recv_result - ); + udp_socket + .set_read_timeout(Some(Duration::from_millis(250))) + .unwrap(); + loop { + if reachable_ports.read().unwrap().contains(&port) + || Instant::now().duration_since(start) >= timeout + { + break; + } + + let recv_result = udp_socket.recv(&mut [0; 1]); + debug!( + "Waited for incoming datagram on udp/{}: {:?}", + port, recv_result + ); + + if recv_result.is_ok() { + reachable_ports.write().unwrap().insert(port); + break; + } + } udp_socket.set_read_timeout(original_read_timeout).unwrap(); - recv_result.map(|_| port).ok() }) }) .collect(); @@ -227,11 +248,11 @@ fn do_verify_reachable_ports( // Separate from the above by collect()-ing as an intermediately step to make the iterator // eager not lazy so that joining happens here at once after creating bunch of threads // at once. - let reachable_ports: BTreeSet<_> = thread_handles - .into_iter() - .filter_map(|t| t.join().unwrap()) - .collect(); + for thread in thread_handles { + thread.join().unwrap(); + } + let reachable_ports = reachable_ports.read().unwrap().clone(); if reachable_ports.len() == checked_ports.len() { info!( "checked udp ports: {:?}, reachable udp ports: {:?}",