diff --git a/Cargo.lock b/Cargo.lock index ebe8a8c82..fc9b97648 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4511,7 +4511,6 @@ name = "solana-net-utils" version = "1.6.0" dependencies = [ "bincode", - "bytes 0.4.12", "clap", "log 0.4.11", "nix 0.19.0", @@ -4522,7 +4521,7 @@ dependencies = [ "solana-clap-utils", "solana-logger 1.6.0", "solana-version", - "tokio 0.1.22", + "tokio 0.3.5", "url 2.1.1", ] diff --git a/core/src/validator.rs b/core/src/validator.rs index acc4d8323..dd5f317ac 100644 --- a/core/src/validator.rs +++ b/core/src/validator.rs @@ -756,7 +756,7 @@ impl Validator { self.completed_data_sets_service .join() .expect("completed_data_sets_service"); - self.ip_echo_server.shutdown_now(); + self.ip_echo_server.shutdown_background(); } } diff --git a/net-utils/Cargo.toml b/net-utils/Cargo.toml index a95df4833..4304d4310 100644 --- a/net-utils/Cargo.toml +++ b/net-utils/Cargo.toml @@ -10,7 +10,6 @@ edition = "2018" [dependencies] bincode = "1.3.1" -bytes = "0.4" clap = "2.33.1" log = "0.4.11" nix = "0.19.0" @@ -21,7 +20,7 @@ socket2 = "0.3.17" solana-clap-utils = { path = "../clap-utils", version = "1.6.0" } solana-logger = { path = "../logger", version = "1.6.0" } solana-version = { path = "../version", version = "1.6.0" } -tokio = "0.1" +tokio = { version = "0.3", features = ["full"] } url = "2.1.1" [lib] diff --git a/net-utils/src/ip_echo_server.rs b/net-utils/src/ip_echo_server.rs index 94346ef2c..7b656fd2f 100644 --- a/net-utils/src/ip_echo_server.rs +++ b/net-utils/src/ip_echo_server.rs @@ -1,15 +1,23 @@ -use crate::{ip_echo_server_reply_length, HEADER_LENGTH}; -use bytes::Bytes; -use log::*; -use serde_derive::{Deserialize, Serialize}; -use std::{io, net::SocketAddr, time::Duration}; -use tokio::{net::TcpListener, prelude::*, reactor::Handle, runtime::Runtime}; +use { + crate::{ip_echo_server_reply_length, HEADER_LENGTH}, + log::*, + serde_derive::{Deserialize, Serialize}, + std::{io, net::SocketAddr, time::Duration}, + tokio::{ + net::{TcpListener, TcpStream}, + prelude::*, + runtime::{self, Runtime}, + time::timeout, + }, +}; pub type IpEchoServer = Runtime; pub const MAX_PORT_COUNT_PER_MESSAGE: usize = 4; -#[derive(Serialize, Deserialize, Default)] +const IO_TIMEOUT: Duration = Duration::from_secs(5); + +#[derive(Serialize, Deserialize, Default, Debug)] pub(crate) struct IpEchoServerMessage { tcp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], // Fixed size list of ports to avoid vec serde udp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], // Fixed size list of ports to avoid vec serde @@ -34,173 +42,111 @@ pub(crate) fn ip_echo_server_request_length() -> usize { + REQUEST_TERMINUS_LENGTH } +async fn process_connection(mut socket: TcpStream, peer_addr: SocketAddr) -> io::Result<()> { + info!("connection from {:?}", peer_addr); + + let mut data = vec![0u8; ip_echo_server_request_length()]; + let (mut reader, mut writer) = socket.split(); + + let _ = timeout(IO_TIMEOUT, reader.read_exact(&mut data)).await??; + drop(reader); + + let request_header: String = data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect(); + if request_header != "\0\0\0\0" { + // Explicitly check for HTTP GET/POST requests to more gracefully handle + // the case where a user accidentally tried to use a gossip entrypoint in + // place of a JSON RPC URL: + if request_header == "GET " || request_header == "POST" { + // Send HTTP error response + timeout( + IO_TIMEOUT, + writer.write_all(b"HTTP/1.1 400 Bad Request\nContent-length: 0\n\n"), + ) + .await??; + return Ok(()); + } + return Err(io::Error::new( + io::ErrorKind::Other, + format!("Bad request header: {}", request_header), + )); + } + + let msg = + bincode::deserialize::(&data[HEADER_LENGTH..]).map_err(|err| { + io::Error::new( + io::ErrorKind::Other, + format!("Failed to deserialize IpEchoServerMessage: {:?}", err), + ) + })?; + + trace!("request: {:?}", msg); + + // Fire a datagram at each non-zero UDP port + match std::net::UdpSocket::bind("0.0.0.0:0") { + Ok(udp_socket) => { + for udp_port in &msg.udp_ports { + if *udp_port != 0 { + match udp_socket.send_to(&[0], SocketAddr::from((peer_addr.ip(), *udp_port))) { + Ok(_) => debug!("Successful send_to udp/{}", udp_port), + Err(err) => info!("Failed to send_to udp/{}: {}", udp_port, err), + } + } + } + } + Err(err) => { + warn!("Failed to bind local udp socket: {}", err); + } + } + + // Try to connect to each non-zero TCP port + for tcp_port in &msg.tcp_ports { + if *tcp_port != 0 { + debug!("Connecting to tcp/{}", tcp_port); + + let tcp_stream = timeout( + IO_TIMEOUT, + TcpStream::connect(&SocketAddr::new(peer_addr.ip(), *tcp_port)), + ) + .await??; + + debug!("Connection established to tcp/{}", *tcp_port); + let _ = tcp_stream.shutdown(std::net::Shutdown::Both); + } + } + + // "\0\0\0\0" header is added to ensure a valid response will never + // conflict with the first four bytes of a valid HTTP response. + let mut bytes = vec![0u8; ip_echo_server_reply_length()]; + bincode::serialize_into(&mut bytes[HEADER_LENGTH..], &peer_addr.ip()).unwrap(); + trace!("response: {:?}", bytes); + writer.write_all(&bytes).await +} + +async fn run_echo_server(tcp_listener: std::net::TcpListener) { + info!("bound to {:?}", tcp_listener.local_addr().unwrap()); + let tcp_listener = + TcpListener::from_std(tcp_listener).expect("Failed to convert std::TcpListener"); + + loop { + match tcp_listener.accept().await { + Ok((socket, peer_addr)) => { + runtime::Handle::current().spawn(async move { + if let Err(err) = process_connection(socket, peer_addr).await { + info!("session failed: {:?}", err); + } + }); + } + Err(err) => warn!("listener accept failed: {:?}", err), + } + } +} + /// Starts a simple TCP server on the given port that echos the IP address of any peer that /// connects. Used by |get_public_ip_addr| -pub fn ip_echo_server(tcp: std::net::TcpListener) -> IpEchoServer { - info!("bound to {:?}", tcp.local_addr()); - let tcp = - TcpListener::from_std(tcp, &Handle::default()).expect("Failed to convert std::TcpListener"); +pub fn ip_echo_server(tcp_listener: std::net::TcpListener) -> IpEchoServer { + tcp_listener.set_nonblocking(true).unwrap(); - let server = tcp - .incoming() - .map_err(|err| warn!("accept failed: {:?}", err)) - .filter_map(|socket| match socket.peer_addr() { - Ok(peer_addr) => { - info!("connection from {:?}", peer_addr); - Some((peer_addr, socket)) - } - Err(err) => { - info!("peer_addr failed for {:?}: {:?}", socket, err); - None - } - }) - .for_each(move |(peer_addr, socket)| { - let data = vec![0u8; ip_echo_server_request_length()]; - let (reader, writer) = socket.split(); - - let processor = tokio::io::read_exact(reader, data) - .and_then(move |(_, data)| { - if data.len() < HEADER_LENGTH { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("Request too short, received {} bytes", data.len()), - )); - } - let request_header: String = - data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect(); - if request_header != "\0\0\0\0" { - // Explicitly check for HTTP GET/POST requests to more gracefully handle - // the case where a user accidentally tried to use a gossip entrypoint in - // place of a JSON RPC URL: - if request_header == "GET " || request_header == "POST" { - return Ok(None); // None -> Send HTTP error response - } - return Err(io::Error::new( - io::ErrorKind::Other, - format!("Bad request header: {}", request_header), - )); - } - - let expected_len = - bincode::serialized_size(&IpEchoServerMessage::default()).unwrap() as usize; - let actual_len = data[HEADER_LENGTH..].len(); - if actual_len < expected_len { - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "Request too short, actual {} < expected {}", - actual_len, expected_len - ), - )); - } - - bincode::deserialize::(&data[HEADER_LENGTH..]) - .map(Some) - .map_err(|err| { - io::Error::new( - io::ErrorKind::Other, - format!("Failed to deserialize IpEchoServerMessage: {:?}", err), - ) - }) - }) - .and_then(move |maybe_msg| { - match maybe_msg { - None => None, // Send HTTP error response - Some(msg) => { - // Fire a datagram at each non-zero UDP port - if !msg.udp_ports.is_empty() { - match std::net::UdpSocket::bind("0.0.0.0:0") { - Ok(udp_socket) => { - for udp_port in &msg.udp_ports { - if *udp_port != 0 { - match udp_socket.send_to( - &[0], - SocketAddr::from((peer_addr.ip(), *udp_port)), - ) { - Ok(_) => debug!( - "Successful send_to udp/{}", - udp_port - ), - Err(err) => info!( - "Failed to send_to udp/{}: {}", - udp_port, err - ), - } - } - } - } - Err(err) => { - warn!("Failed to bind local udp socket: {}", err); - } - } - } - - // Try to connect to each non-zero TCP port - let tcp_futures: Vec<_> = - msg.tcp_ports - .iter() - .filter_map(|tcp_port| { - let tcp_port = *tcp_port; - if tcp_port == 0 { - None - } else { - Some( - tokio::net::TcpStream::connect(&SocketAddr::new( - peer_addr.ip(), - tcp_port, - )) - .and_then(move |tcp_stream| { - debug!( - "Connection established to tcp/{}", - tcp_port - ); - let _ = tcp_stream - .shutdown(std::net::Shutdown::Both); - Ok(()) - }) - .timeout(Duration::from_secs(5)) - .or_else(move |err| { - Err(io::Error::new( - io::ErrorKind::Other, - format!( - "Connection timeout to {}: {:?}", - tcp_port, err - ), - )) - }), - ) - } - }) - .collect(); - Some(future::join_all(tcp_futures)) - } - } - }) - .and_then(move |valid_request| { - let bytes = if valid_request.is_none() { - Bytes::from("HTTP/1.1 400 Bad Request\nContent-length: 0\n\n") - } else { - // "\0\0\0\0" header is added to ensure a valid response will never - // conflict with the first four bytes of a valid HTTP response. - let mut bytes = vec![0u8; ip_echo_server_reply_length()]; - bincode::serialize_into(&mut bytes[HEADER_LENGTH..], &peer_addr.ip()) - .unwrap(); - Bytes::from(bytes) - }; - tokio::io::write_all(writer, bytes) - }) - .timeout(Duration::from_secs(5)) - .then(|result| { - if let Err(err) = result { - info!("Session failed: {:?}", err); - } - Ok(()) - }); - - tokio::spawn(processor) - }); - - let mut rt = Runtime::new().expect("Failed to create Runtime"); - rt.spawn(server); - rt + let runtime = Runtime::new().expect("Failed to create Runtime"); + runtime.spawn(run_echo_server(tcp_listener)); + runtime }