diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs index 42fadcc27d..1b55d958ec 100644 --- a/streamer/src/nonblocking/quic.rs +++ b/streamer/src/nonblocking/quic.rs @@ -80,10 +80,11 @@ pub async fn run_server( ) { debug!("spawn quic server"); let mut last_datapoint = Instant::now(); - let connection_table: Arc> = - Arc::new(Mutex::new(ConnectionTable::default())); + let unstaked_connection_table: Arc> = Arc::new(Mutex::new( + ConnectionTable::new(ConnectionPeerType::Unstaked), + )); let staked_connection_table: Arc> = - Arc::new(Mutex::new(ConnectionTable::default())); + Arc::new(Mutex::new(ConnectionTable::new(ConnectionPeerType::Staked))); while !exit.load(Ordering::Relaxed) { const WAIT_FOR_CONNECTION_TIMEOUT_MS: u64 = 1000; const WAIT_BETWEEN_NEW_CONNECTIONS_US: u64 = 1000; @@ -101,7 +102,7 @@ pub async fn run_server( if let Ok(Some(connection)) = timeout_connection { tokio::spawn(setup_connection( connection, - connection_table.clone(), + unstaked_connection_table.clone(), staked_connection_table.clone(), packet_sender.clone(), max_connections_per_ip, @@ -117,7 +118,7 @@ pub async fn run_server( async fn setup_connection( connection: Connecting, - connection_table: Arc>, + unstaked_connection_table: Arc>, staked_connection_table: Arc>, packet_sender: Sender, max_connections_per_ip: usize, @@ -162,7 +163,7 @@ async fn setup_connection( (connection_table_l, stake) } else { drop(staked_nodes); - let mut connection_table_l = connection_table.lock().unwrap(); + let mut connection_table_l = unstaked_connection_table.lock().unwrap(); if connection_table_l.total_size >= max_unstaked_connections { let max_connections = max_percentage_full.apply_to(max_unstaked_connections); let num_pruned = connection_table_l.prune_oldest(max_connections); @@ -182,15 +183,19 @@ async fn setup_connection( timing::timestamp(), max_connections_per_ip, ) { + let table_type = connection_table_l.peer_type; drop(connection_table_l); let stats = stats.clone(); - let connection_table1 = connection_table.clone(); + let connection_table = match table_type { + ConnectionPeerType::Unstaked => unstaked_connection_table.clone(), + ConnectionPeerType::Staked => staked_connection_table.clone(), + }; tokio::spawn(handle_connection( uni_streams, packet_sender, remote_addr, last_update, - connection_table1, + connection_table, stream_exit, stats, stake, @@ -284,10 +289,17 @@ async fn handle_connection( } } } - connection_table + if connection_table .lock() .unwrap() - .remove_connection(&remote_addr); + .remove_connection(&remote_addr) + { + stats.connection_removed.fetch_add(1, Ordering::Relaxed); + } else { + stats + .connection_remove_failed + .fetch_add(1, Ordering::Relaxed); + } stats.total_connections.fetch_sub(1, Ordering::Relaxed); } @@ -409,16 +421,30 @@ impl Drop for ConnectionEntry { } } +#[derive(Copy, Clone)] +enum ConnectionPeerType { + Unstaked, + Staked, +} + // Map of IP to list of connection entries -#[derive(Default, Debug)] struct ConnectionTable { table: HashMap>, total_size: usize, + peer_type: ConnectionPeerType, } // Prune the connection which has the oldest update // Return number pruned impl ConnectionTable { + fn new(peer_type: ConnectionPeerType) -> Self { + Self { + table: HashMap::default(), + total_size: 0, + peer_type, + } + } + fn prune_oldest(&mut self, max_size: usize) -> usize { let mut num_pruned = 0; while self.total_size > max_size { @@ -476,7 +502,7 @@ impl ConnectionTable { } } - fn remove_connection(&mut self, addr: &SocketAddr) { + fn remove_connection(&mut self, addr: &SocketAddr) -> bool { if let Entry::Occupied(mut e) = self.table.entry(addr.ip()) { let e_ref = e.get_mut(); let old_size = e_ref.len(); @@ -488,6 +514,9 @@ impl ConnectionTable { self.total_size = self .total_size .saturating_sub(old_size.saturating_sub(new_size)); + true + } else { + false } } } @@ -503,6 +532,7 @@ pub mod test { quic::{QUIC_KEEP_ALIVE_MS, QUIC_MAX_TIMEOUT_MS}, signature::Keypair, }, + std::net::Ipv4Addr, tokio::time::sleep, }; @@ -543,7 +573,9 @@ pub mod test { config } - fn setup_quic_server() -> ( + fn setup_quic_server( + option_staked_nodes: Option, + ) -> ( JoinHandle<()>, Arc, crossbeam_channel::Receiver, @@ -556,7 +588,7 @@ pub mod test { let keypair = Keypair::new(); let ip = "127.0.0.1".parse().unwrap(); let server_address = s.local_addr().unwrap(); - let staked_nodes = Arc::new(RwLock::new(StakedNodes::default())); + let staked_nodes = Arc::new(RwLock::new(option_staked_nodes.unwrap_or_default())); let stats = Arc::new(StreamStats::default()); let t = spawn_server( s, @@ -714,7 +746,7 @@ pub mod test { #[tokio::test] async fn test_quic_server_exit() { - let (t, exit, _receiver, _server_address, _stats) = setup_quic_server(); + let (t, exit, _receiver, _server_address, _stats) = setup_quic_server(None); exit.store(true, Ordering::Relaxed); t.await.unwrap(); } @@ -722,7 +754,7 @@ pub mod test { #[tokio::test] async fn test_quic_timeout() { solana_logger::setup(); - let (t, exit, receiver, server_address, _stats) = setup_quic_server(); + let (t, exit, receiver, server_address, _stats) = setup_quic_server(None); check_timeout(receiver, server_address).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); @@ -731,7 +763,7 @@ pub mod test { #[tokio::test] async fn test_quic_stream_timeout() { solana_logger::setup(); - let (t, exit, _receiver, server_address, stats) = setup_quic_server(); + let (t, exit, _receiver, server_address, stats) = setup_quic_server(None); let conn1 = make_client_endpoint(&server_address).await; assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0); @@ -763,7 +795,7 @@ pub mod test { #[tokio::test] async fn test_quic_server_block_multiple_connections() { solana_logger::setup(); - let (t, exit, _receiver, server_address, _stats) = setup_quic_server(); + let (t, exit, _receiver, server_address, _stats) = setup_quic_server(None); check_block_multiple_connections(server_address).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); @@ -772,12 +804,41 @@ pub mod test { #[tokio::test] async fn test_quic_server_multiple_writes() { solana_logger::setup(); - let (t, exit, receiver, server_address, _stats) = setup_quic_server(); + let (t, exit, receiver, server_address, _stats) = setup_quic_server(None); check_multiple_writes(receiver, server_address).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); } + #[tokio::test] + async fn test_quic_server_staked_connection_removal() { + solana_logger::setup(); + + let mut staked_nodes = StakedNodes::default(); + staked_nodes + .stake_map + .insert(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 100000); + staked_nodes.total_stake = 100000_f64; + + let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes)); + check_multiple_writes(receiver, server_address).await; + exit.store(true, Ordering::Relaxed); + t.await.unwrap(); + assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1); + assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0); + } + + #[tokio::test] + async fn test_quic_server_unstaked_connection_removal() { + solana_logger::setup(); + let (t, exit, receiver, server_address, stats) = setup_quic_server(None); + check_multiple_writes(receiver, server_address).await; + exit.store(true, Ordering::Relaxed); + t.await.unwrap(); + assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1); + assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0); + } + #[tokio::test] async fn test_quic_server_unstaked_node_connect_failure() { solana_logger::setup(); @@ -848,7 +909,7 @@ pub mod test { fn test_prune_table() { use std::net::Ipv4Addr; solana_logger::setup(); - let mut table = ConnectionTable::default(); + let mut table = ConnectionTable::new(ConnectionPeerType::Staked); let mut num_entries = 5; let max_connections_per_ip = 10; let sockets: Vec<_> = (0..num_entries) @@ -885,7 +946,7 @@ pub mod test { fn test_remove_connections() { use std::net::Ipv4Addr; solana_logger::setup(); - let mut table = ConnectionTable::default(); + let mut table = ConnectionTable::new(ConnectionPeerType::Staked); let num_ips = 5; let max_connections_per_ip = 10; let mut sockets: Vec<_> = (0..num_ips) diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs index 7f8d74c742..f1c17c1565 100644 --- a/streamer/src/quic.rs +++ b/streamer/src/quic.rs @@ -160,6 +160,8 @@ pub struct StreamStats { pub(crate) connection_add_failed: AtomicUsize, pub(crate) connection_add_failed_unstaked_node: AtomicUsize, pub(crate) connection_setup_timeout: AtomicUsize, + pub(crate) connection_removed: AtomicUsize, + pub(crate) connection_remove_failed: AtomicUsize, } impl StreamStats { @@ -202,6 +204,16 @@ impl StreamStats { .swap(0, Ordering::Relaxed), i64 ), + ( + "connection_removed", + self.connection_removed.swap(0, Ordering::Relaxed), + i64 + ), + ( + "connection_remove_failed", + self.connection_remove_failed.swap(0, Ordering::Relaxed), + i64 + ), ( "connection_setup_timeout", self.connection_setup_timeout.swap(0, Ordering::Relaxed),