From 7a9884c831a64568259c24213fcfb83ac7c17429 Mon Sep 17 00:00:00 2001 From: sakridge Date: Wed, 9 Mar 2022 10:52:31 +0100 Subject: [PATCH] Quic limit connections (#23283) * quic server limit connections * bump per_ip * Review comments * Make the connections per port --- core/src/tpu.rs | 4 + streamer/src/quic.rs | 340 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 316 insertions(+), 28 deletions(-) diff --git a/core/src/tpu.rs b/core/src/tpu.rs index a058075d0..a1deaf571 100644 --- a/core/src/tpu.rs +++ b/core/src/tpu.rs @@ -36,6 +36,9 @@ use { pub const DEFAULT_TPU_COALESCE_MS: u64 = 5; +// allow multiple connections for NAT and any open/close overlap +pub const MAX_QUIC_CONNECTIONS_PER_IP: usize = 8; + pub struct TpuSockets { pub transactions: Vec, pub transaction_forwards: Vec, @@ -108,6 +111,7 @@ impl Tpu { cluster_info.my_contact_info().tpu.ip(), packet_sender, exit.clone(), + MAX_QUIC_CONNECTIONS_PER_IP, ) .unwrap(); diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs index 30657c494..f2f0a0738 100644 --- a/streamer/src/quic.rs +++ b/streamer/src/quic.rs @@ -3,22 +3,24 @@ use { futures_util::stream::StreamExt, pem::Pem, pkcs8::{der::Document, AlgorithmIdentifier, ObjectIdentifier}, - quinn::{Endpoint, EndpointConfig, ServerConfig}, + quinn::{Endpoint, EndpointConfig, IncomingUniStreams, ServerConfig}, rcgen::{CertificateParams, DistinguishedName, DnType, SanType}, solana_perf::packet::PacketBatch, solana_sdk::{ packet::{Packet, PACKET_DATA_SIZE}, signature::Keypair, + timing, }, std::{ + collections::{hash_map::Entry, HashMap}, error::Error, net::{IpAddr, SocketAddr, UdpSocket}, sync::{ - atomic::{AtomicBool, Ordering}, - Arc, + atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, + Arc, Mutex, }, thread, - time::Duration, + time::{Duration, Instant}, }, tokio::{ runtime::{Builder, Runtime}, @@ -120,8 +122,12 @@ fn new_cert_params(identity_keypair: &Keypair, san: IpAddr) -> CertificateParams cert_params } -pub fn rt() -> Runtime { - Builder::new_current_thread().enable_all().build().unwrap() +fn rt() -> Runtime { + Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap() } #[derive(thiserror::Error, Debug)] @@ -190,12 +196,207 @@ fn handle_chunk( false } +#[derive(Debug)] +struct ConnectionEntry { + exit: Arc, + last_update: Arc, + port: u16, +} + +impl ConnectionEntry { + fn new(exit: Arc, last_update: Arc, port: u16) -> Self { + Self { + exit, + last_update, + port, + } + } + + fn last_update(&self) -> u64 { + self.last_update.load(Ordering::Relaxed) + } +} + +impl Drop for ConnectionEntry { + fn drop(&mut self) { + self.exit.store(true, Ordering::Relaxed); + } +} + +// Map of IP to list of connection entries +#[derive(Default, Debug)] +struct ConnectionTable { + table: HashMap>, + total_size: usize, +} + +// Prune the connection which has the oldest update +// Return number pruned +impl ConnectionTable { + fn prune_oldest(&mut self, max_size: usize) -> usize { + let mut num_pruned = 0; + while self.total_size > max_size { + let mut oldest = std::u64::MAX; + let mut oldest_ip = None; + for (ip, connections) in self.table.iter() { + for entry in connections { + let last_update = entry.last_update(); + if last_update < oldest { + oldest = last_update; + oldest_ip = Some(*ip); + } + } + } + self.table.remove(&oldest_ip.unwrap()); + self.total_size -= 1; + num_pruned += 1; + } + num_pruned + } + + fn try_add_connection( + &mut self, + addr: &SocketAddr, + last_update: u64, + max_connections_per_ip: usize, + ) -> Option<(Arc, Arc)> { + let connection_entry = self.table.entry(addr.ip()).or_insert_with(Vec::new); + let has_connection_capacity = connection_entry + .len() + .checked_add(1) + .map(|c| c <= max_connections_per_ip) + .unwrap_or(false); + if has_connection_capacity { + let exit = Arc::new(AtomicBool::new(false)); + let last_update = Arc::new(AtomicU64::new(last_update)); + connection_entry.push(ConnectionEntry::new( + exit.clone(), + last_update.clone(), + addr.port(), + )); + self.total_size += 1; + Some((last_update, exit)) + } else { + None + } + } + + fn remove_connection(&mut self, addr: &SocketAddr) { + if let Entry::Occupied(mut e) = self.table.entry(addr.ip()) { + let e_ref = e.get_mut(); + e_ref.retain(|connection| connection.port != addr.port()); + if e_ref.is_empty() { + e.remove_entry(); + } + self.total_size -= 1; + } + } +} + +#[derive(Default)] +struct StreamStats { + total_connections: AtomicUsize, + total_new_connections: AtomicUsize, + total_streams: AtomicUsize, + total_new_streams: AtomicUsize, + num_evictions: AtomicUsize, +} + +impl StreamStats { + fn report(&self) { + datapoint_info!( + "quic-connections", + ( + "active_connections", + self.total_connections.load(Ordering::Relaxed), + i64 + ), + ( + "active_streams", + self.total_streams.load(Ordering::Relaxed), + i64 + ), + ( + "new_connections", + self.total_new_connections.swap(0, Ordering::Relaxed), + i64 + ), + ( + "new_streams", + self.total_new_streams.swap(0, Ordering::Relaxed), + i64 + ), + ( + "evictions", + self.num_evictions.swap(0, Ordering::Relaxed), + i64 + ), + ); + } +} + +fn handle_connection( + mut uni_streams: IncomingUniStreams, + packet_sender: Sender, + remote_addr: SocketAddr, + last_update: Arc, + connection_table: Arc>, + stream_exit: Arc, + stats: Arc, +) { + tokio::spawn(async move { + debug!( + "quic new connection {} streams: {} connections: {}", + remote_addr, + stats.total_streams.load(Ordering::Relaxed), + stats.total_connections.load(Ordering::Relaxed), + ); + while !stream_exit.load(Ordering::Relaxed) { + match uni_streams.next().await { + Some(stream_result) => match stream_result { + Ok(mut stream) => { + stats.total_streams.fetch_add(1, Ordering::Relaxed); + stats.total_new_streams.fetch_add(1, Ordering::Relaxed); + let mut maybe_batch = None; + while !stream_exit.load(Ordering::Relaxed) { + if handle_chunk( + &stream.read_chunk(PACKET_DATA_SIZE, false).await, + &mut maybe_batch, + &remote_addr, + &packet_sender, + ) { + last_update.store(timing::timestamp(), Ordering::Relaxed); + break; + } + } + } + Err(e) => { + debug!("stream error: {:?}", e); + stats.total_streams.fetch_sub(1, Ordering::Relaxed); + break; + } + }, + None => { + stats.total_streams.fetch_sub(1, Ordering::Relaxed); + break; + } + } + } + connection_table + .lock() + .unwrap() + .remove_connection(&remote_addr); + stats.total_connections.fetch_sub(1, Ordering::Relaxed); + }); +} + pub fn spawn_server( sock: UdpSocket, keypair: &Keypair, gossip_host: IpAddr, packet_sender: Sender, exit: Arc, + max_connections_per_ip: usize, ) -> Result, QuicServerError> { let (config, _cert) = configure_server(keypair, gossip_host)?; @@ -206,8 +407,13 @@ pub fn spawn_server( .map_err(|_e| QuicServerError::EndpointFailed)? }; + let stats = Arc::new(StreamStats::default()); let handle = thread::spawn(move || { let handle = runtime.spawn(async move { + debug!("spawn quic server"); + let mut last_datapoint = Instant::now(); + let connection_table: Arc> = + Arc::new(Mutex::new(ConnectionTable::default())); while !exit.load(Ordering::Relaxed) { const WAIT_FOR_CONNECTION_TIMEOUT_MS: u64 = 1000; let timeout_connection = timeout( @@ -216,33 +422,49 @@ pub fn spawn_server( ) .await; + if last_datapoint.elapsed().as_secs() >= 5 { + stats.report(); + last_datapoint = Instant::now(); + } + if let Ok(Some(connection)) = timeout_connection { if let Ok(new_connection) = connection.await { - let exit = exit.clone(); + stats.total_connections.fetch_add(1, Ordering::Relaxed); + stats.total_new_connections.fetch_add(1, Ordering::Relaxed); let quinn::NewConnection { connection, - mut uni_streams, + uni_streams, .. } = new_connection; let remote_addr = connection.remote_address(); - let packet_sender = packet_sender.clone(); - tokio::spawn(async move { - debug!("new connection {}", remote_addr); - while let Some(Ok(mut stream)) = uni_streams.next().await { - let mut maybe_batch = None; - while !exit.load(Ordering::Relaxed) { - if handle_chunk( - &stream.read_chunk(PACKET_DATA_SIZE, false).await, - &mut maybe_batch, - &remote_addr, - &packet_sender, - ) { - break; - } - } - } - }); + + let mut connection_table_l = connection_table.lock().unwrap(); + const MAX_CONNECTION_TABLE_SIZE: usize = 5000; + let num_pruned = connection_table_l.prune_oldest(MAX_CONNECTION_TABLE_SIZE); + stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed); + + if let Some((last_update, stream_exit)) = connection_table_l + .try_add_connection( + &remote_addr, + timing::timestamp(), + max_connections_per_ip, + ) + { + drop(connection_table_l); + let packet_sender = packet_sender.clone(); + let stats = stats.clone(); + let connection_table1 = connection_table.clone(); + handle_connection( + uni_streams, + packet_sender, + remote_addr, + last_update, + connection_table1, + stream_exit, + stats, + ); + } } } } @@ -300,7 +522,7 @@ mod test { let (sender, _receiver) = unbounded(); let keypair = Keypair::new(); let ip = "127.0.0.1".parse().unwrap(); - let t = spawn_server(s, &keypair, ip, sender, exit.clone()).unwrap(); + let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap(); exit.store(true, Ordering::Relaxed); t.join().unwrap(); } @@ -316,6 +538,35 @@ mod test { .unwrap() } + #[test] + fn test_quic_server_block_multiple_connections() { + solana_logger::setup(); + let s = UdpSocket::bind("127.0.0.1:0").unwrap(); + let exit = Arc::new(AtomicBool::new(false)); + let (sender, _receiver) = unbounded(); + let keypair = Keypair::new(); + let ip = "127.0.0.1".parse().unwrap(); + let server_address = s.local_addr().unwrap(); + let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap(); + + let runtime = rt(); + let _rt_guard = runtime.enter(); + let conn1 = make_client_endpoint(&runtime, &server_address); + let conn2 = make_client_endpoint(&runtime, &server_address); + let handle = runtime.spawn(async move { + let mut s1 = conn1.connection.open_uni().await.unwrap(); + let mut s2 = conn2.connection.open_uni().await.unwrap(); + s1.write_all(&[0u8]).await.unwrap(); + s1.finish().await.unwrap(); + s2.write_all(&[0u8]) + .await + .expect_err("shouldn't be able to open 2 connections"); + }); + runtime.block_on(handle).unwrap(); + exit.store(true, Ordering::Relaxed); + t.join().unwrap(); + } + #[test] fn test_quic_server_multiple_streams() { solana_logger::setup(); @@ -325,7 +576,7 @@ mod test { let keypair = Keypair::new(); let ip = "127.0.0.1".parse().unwrap(); let server_address = s.local_addr().unwrap(); - let t = spawn_server(s, &keypair, ip, sender, exit.clone()).unwrap(); + let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 2).unwrap(); let runtime = rt(); let _rt_guard = runtime.enter(); @@ -380,7 +631,7 @@ mod test { let keypair = Keypair::new(); let ip = "127.0.0.1".parse().unwrap(); let server_address = s.local_addr().unwrap(); - let t = spawn_server(s, &keypair, ip, sender, exit.clone()).unwrap(); + let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap(); let runtime = rt(); let _rt_guard = runtime.enter(); @@ -420,4 +671,37 @@ mod test { exit.store(true, Ordering::Relaxed); t.join().unwrap(); } + + #[test] + fn test_prune_table() { + use std::net::Ipv4Addr; + solana_logger::setup(); + let mut table = ConnectionTable::default(); + let num_entries = 5; + let max_connections_per_ip = 10; + let sockets: Vec<_> = (0..num_entries) + .into_iter() + .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0)) + .collect(); + for (i, socket) in sockets.iter().enumerate() { + table + .try_add_connection(socket, i as u64, max_connections_per_ip) + .unwrap(); + } + let new_size = 3; + let pruned = table.prune_oldest(new_size); + assert_eq!(pruned, num_entries as usize - new_size); + for v in table.table.values() { + for x in v { + assert!(x.last_update() >= (num_entries as u64 - new_size as u64)); + } + } + assert_eq!(table.table.len(), new_size); + assert_eq!(table.total_size, new_size); + for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) { + table.remove_connection(socket); + } + info!("{:?}", table); + assert_eq!(table.total_size, 0); + } }