From 3f9c974587bdef7fb89476ca177cb3769ab41eec Mon Sep 17 00:00:00 2001 From: Pankaj Garg Date: Wed, 25 Jan 2023 16:14:25 -0800 Subject: [PATCH] Remove the specific QUIC connection entry that disconnected (#29883) --- streamer/src/nonblocking/quic.rs | 120 +++++++++++++++++++++++++------ 1 file changed, 97 insertions(+), 23 deletions(-) diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs index 022f744c90..3a53850012 100644 --- a/streamer/src/nonblocking/quic.rs +++ b/streamer/src/nonblocking/quic.rs @@ -539,6 +539,7 @@ async fn handle_connection( stats.total_streams.load(Ordering::Relaxed), stats.total_connections.load(Ordering::Relaxed), ); + let stable_id = connection.stable_id(); stats.total_connections.fetch_add(1, Ordering::Relaxed); while !stream_exit.load(Ordering::Relaxed) { if let Ok(stream) = tokio::time::timeout( @@ -606,11 +607,15 @@ async fn handle_connection( } } - if connection_table.lock().unwrap().remove_connection( + let removed_connection_count = connection_table.lock().unwrap().remove_connection( ConnectionTableKey::new(remote_addr.ip(), remote_pubkey), remote_addr.port(), - ) { - stats.connection_removed.fetch_add(1, Ordering::Relaxed); + stable_id, + ); + if removed_connection_count > 0 { + stats + .connection_removed + .fetch_add(removed_connection_count, Ordering::Relaxed); } else { stats .connection_remove_failed @@ -905,21 +910,34 @@ impl ConnectionTable { } } - fn remove_connection(&mut self, key: ConnectionTableKey, port: u16) -> bool { + // Returns number of connections that were removed + fn remove_connection(&mut self, key: ConnectionTableKey, port: u16, stable_id: usize) -> usize { if let Entry::Occupied(mut e) = self.table.entry(key) { let e_ref = e.get_mut(); let old_size = e_ref.len(); - e_ref.retain(|connection| connection.port != port); + + e_ref.retain(|connection_entry| { + // Retain the connection entry if the port is different, or if the connection's + // stable_id doesn't match the provided stable_id. + // (Some unit tests do not fill in a valid connection in the table. To support that, + // if the connection is none, the stable_id check is ignored. i.e. if the port matches, + // the connection gets removed) + connection_entry.port != port + || connection_entry + .connection + .as_ref() + .and_then(|connection| (connection.stable_id() != stable_id).then_some(0)) + .is_some() + }); let new_size = e_ref.len(); if e_ref.is_empty() { e.remove_entry(); } - self.total_size = self - .total_size - .saturating_sub(old_size.saturating_sub(new_size)); - true + let connections_removed = old_size.saturating_sub(new_size); + self.total_size = self.total_size.saturating_sub(connections_removed); + connections_removed } else { - false + 0 } } } @@ -993,6 +1011,7 @@ pub mod test { fn setup_quic_server( option_staked_nodes: Option, + max_connections_per_peer: usize, ) -> ( JoinHandle<()>, Arc, @@ -1014,7 +1033,7 @@ pub mod test { ip, sender, exit.clone(), - 1, + max_connections_per_peer, staked_nodes, MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS, @@ -1176,7 +1195,7 @@ pub mod test { #[tokio::test] async fn test_quic_server_exit() { - let (t, exit, _receiver, _server_address, _stats) = setup_quic_server(None); + let (t, exit, _receiver, _server_address, _stats) = setup_quic_server(None, 1); exit.store(true, Ordering::Relaxed); t.await.unwrap(); } @@ -1184,7 +1203,7 @@ pub mod test { #[tokio::test] async fn test_quic_timeout() { solana_logger::setup(); - let (t, exit, receiver, server_address, _stats) = setup_quic_server(None); + let (t, exit, receiver, server_address, _stats) = setup_quic_server(None, 1); check_timeout(receiver, server_address).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); @@ -1193,7 +1212,7 @@ pub mod test { #[tokio::test] async fn test_quic_stream_timeout() { solana_logger::setup(); - let (t, exit, _receiver, server_address, stats) = setup_quic_server(None); + let (t, exit, _receiver, server_address, stats) = setup_quic_server(None, 1); let conn1 = make_client_endpoint(&server_address, None).await; assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0); @@ -1223,16 +1242,71 @@ pub mod test { #[tokio::test] async fn test_quic_server_block_multiple_connections() { solana_logger::setup(); - let (t, exit, _receiver, server_address, _stats) = setup_quic_server(None); + let (t, exit, _receiver, server_address, _stats) = setup_quic_server(None, 1); check_block_multiple_connections(server_address).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); } + #[tokio::test] + async fn test_quic_server_multiple_connections_on_single_client_endpoint() { + solana_logger::setup(); + let (t, exit, _receiver, server_address, stats) = setup_quic_server(None, 2); + + let client_socket = UdpSocket::bind("127.0.0.1:0").unwrap(); + let mut endpoint = + quinn::Endpoint::new(EndpointConfig::default(), None, client_socket, TokioRuntime) + .unwrap(); + let default_keypair = Keypair::new(); + endpoint.set_default_client_config(get_client_config(&default_keypair)); + let conn1 = endpoint + .connect(server_address, "localhost") + .expect("Failed in connecting") + .await + .expect("Failed in waiting"); + + let conn2 = endpoint + .connect(server_address, "localhost") + .expect("Failed in connecting") + .await + .expect("Failed in waiting"); + + let mut s1 = conn1.open_uni().await.unwrap(); + s1.write_all(&[0u8]).await.unwrap(); + s1.finish().await.unwrap(); + + let mut s2 = conn2.open_uni().await.unwrap(); + conn1.close( + CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(), + CONNECTION_CLOSE_REASON_DROPPED_ENTRY, + ); + // Wait long enough for the stream to timeout in receiving chunks + let sleep_time = (WAIT_FOR_STREAM_TIMEOUT_MS * 1000).min(1000); + sleep(Duration::from_millis(sleep_time)).await; + + assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1); + + s2.write_all(&[0u8]).await.unwrap(); + s2.finish().await.unwrap(); + + conn2.close( + CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(), + CONNECTION_CLOSE_REASON_DROPPED_ENTRY, + ); + // Wait long enough for the stream to timeout in receiving chunks + let sleep_time = (WAIT_FOR_STREAM_TIMEOUT_MS * 1000).min(1000); + sleep(Duration::from_millis(sleep_time)).await; + + assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 2); + + exit.store(true, Ordering::Relaxed); + t.await.unwrap(); + } + #[tokio::test] async fn test_quic_server_multiple_writes() { solana_logger::setup(); - let (t, exit, receiver, server_address, _stats) = setup_quic_server(None); + let (t, exit, receiver, server_address, _stats) = setup_quic_server(None, 1); check_multiple_writes(receiver, server_address, None).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); @@ -1249,7 +1323,7 @@ pub mod test { .insert(client_keypair.pubkey(), 100000); staked_nodes.total_stake = 100000; - let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes)); + let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes), 1); check_multiple_writes(receiver, server_address, Some(&client_keypair)).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); @@ -1276,7 +1350,7 @@ pub mod test { .insert(client_keypair.pubkey(), 0); staked_nodes.total_stake = 0; - let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes)); + let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes), 1); check_multiple_writes(receiver, server_address, Some(&client_keypair)).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); @@ -1294,7 +1368,7 @@ pub mod test { #[tokio::test] async fn test_quic_server_unstaked_connection_removal() { solana_logger::setup(); - let (t, exit, receiver, server_address, stats) = setup_quic_server(None); + let (t, exit, receiver, server_address, stats) = setup_quic_server(None, 1); check_multiple_writes(receiver, server_address, None).await; exit.store(true, Ordering::Relaxed); t.await.unwrap(); @@ -1422,7 +1496,7 @@ pub mod test { assert_eq!(table.table.len(), new_size); assert_eq!(table.total_size, new_size); for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) { - table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port()); + table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0); } assert_eq!(table.total_size, 0); } @@ -1457,7 +1531,7 @@ pub mod test { assert_eq!(table.table.len(), new_size); assert_eq!(table.total_size, new_size); for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) { - table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0); + table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0); } assert_eq!(table.total_size, 0); } @@ -1517,7 +1591,7 @@ pub mod test { assert!(table.table.len() <= new_max_size); assert!(table.total_size <= new_max_size); - table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0); + table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0); assert_eq!(table.total_size, 0); } @@ -1609,7 +1683,7 @@ pub mod test { sockets.push(zero_connection_addr); for socket in sockets.iter() { - table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port()); + table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0); } assert_eq!(table.total_size, 0); }