Remove the specific QUIC connection entry that disconnected (#29883)

This commit is contained in:
Pankaj Garg 2023-01-25 16:14:25 -08:00 committed by GitHub
parent b4d1769688
commit 3f9c974587
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 97 additions and 23 deletions

View File

@ -539,6 +539,7 @@ async fn handle_connection(
stats.total_streams.load(Ordering::Relaxed), stats.total_streams.load(Ordering::Relaxed),
stats.total_connections.load(Ordering::Relaxed), stats.total_connections.load(Ordering::Relaxed),
); );
let stable_id = connection.stable_id();
stats.total_connections.fetch_add(1, Ordering::Relaxed); stats.total_connections.fetch_add(1, Ordering::Relaxed);
while !stream_exit.load(Ordering::Relaxed) { while !stream_exit.load(Ordering::Relaxed) {
if let Ok(stream) = tokio::time::timeout( if let Ok(stream) = tokio::time::timeout(
@ -606,11 +607,15 @@ async fn handle_connection(
} }
} }
if connection_table.lock().unwrap().remove_connection( let removed_connection_count = connection_table.lock().unwrap().remove_connection(
ConnectionTableKey::new(remote_addr.ip(), remote_pubkey), ConnectionTableKey::new(remote_addr.ip(), remote_pubkey),
remote_addr.port(), remote_addr.port(),
) { stable_id,
stats.connection_removed.fetch_add(1, Ordering::Relaxed); );
if removed_connection_count > 0 {
stats
.connection_removed
.fetch_add(removed_connection_count, Ordering::Relaxed);
} else { } else {
stats stats
.connection_remove_failed .connection_remove_failed
@ -905,21 +910,34 @@ impl ConnectionTable {
} }
} }
fn remove_connection(&mut self, key: ConnectionTableKey, port: u16) -> bool { // Returns number of connections that were removed
fn remove_connection(&mut self, key: ConnectionTableKey, port: u16, stable_id: usize) -> usize {
if let Entry::Occupied(mut e) = self.table.entry(key) { if let Entry::Occupied(mut e) = self.table.entry(key) {
let e_ref = e.get_mut(); let e_ref = e.get_mut();
let old_size = e_ref.len(); let old_size = e_ref.len();
e_ref.retain(|connection| connection.port != port);
e_ref.retain(|connection_entry| {
// Retain the connection entry if the port is different, or if the connection's
// stable_id doesn't match the provided stable_id.
// (Some unit tests do not fill in a valid connection in the table. To support that,
// if the connection is none, the stable_id check is ignored. i.e. if the port matches,
// the connection gets removed)
connection_entry.port != port
|| connection_entry
.connection
.as_ref()
.and_then(|connection| (connection.stable_id() != stable_id).then_some(0))
.is_some()
});
let new_size = e_ref.len(); let new_size = e_ref.len();
if e_ref.is_empty() { if e_ref.is_empty() {
e.remove_entry(); e.remove_entry();
} }
self.total_size = self let connections_removed = old_size.saturating_sub(new_size);
.total_size self.total_size = self.total_size.saturating_sub(connections_removed);
.saturating_sub(old_size.saturating_sub(new_size)); connections_removed
true
} else { } else {
false 0
} }
} }
} }
@ -993,6 +1011,7 @@ pub mod test {
fn setup_quic_server( fn setup_quic_server(
option_staked_nodes: Option<StakedNodes>, option_staked_nodes: Option<StakedNodes>,
max_connections_per_peer: usize,
) -> ( ) -> (
JoinHandle<()>, JoinHandle<()>,
Arc<AtomicBool>, Arc<AtomicBool>,
@ -1014,7 +1033,7 @@ pub mod test {
ip, ip,
sender, sender,
exit.clone(), exit.clone(),
1, max_connections_per_peer,
staked_nodes, staked_nodes,
MAX_STAKED_CONNECTIONS, MAX_STAKED_CONNECTIONS,
MAX_UNSTAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS,
@ -1176,7 +1195,7 @@ pub mod test {
#[tokio::test] #[tokio::test]
async fn test_quic_server_exit() { async fn test_quic_server_exit() {
let (t, exit, _receiver, _server_address, _stats) = setup_quic_server(None); let (t, exit, _receiver, _server_address, _stats) = setup_quic_server(None, 1);
exit.store(true, Ordering::Relaxed); exit.store(true, Ordering::Relaxed);
t.await.unwrap(); t.await.unwrap();
} }
@ -1184,7 +1203,7 @@ pub mod test {
#[tokio::test] #[tokio::test]
async fn test_quic_timeout() { async fn test_quic_timeout() {
solana_logger::setup(); solana_logger::setup();
let (t, exit, receiver, server_address, _stats) = setup_quic_server(None); let (t, exit, receiver, server_address, _stats) = setup_quic_server(None, 1);
check_timeout(receiver, server_address).await; check_timeout(receiver, server_address).await;
exit.store(true, Ordering::Relaxed); exit.store(true, Ordering::Relaxed);
t.await.unwrap(); t.await.unwrap();
@ -1193,7 +1212,7 @@ pub mod test {
#[tokio::test] #[tokio::test]
async fn test_quic_stream_timeout() { async fn test_quic_stream_timeout() {
solana_logger::setup(); solana_logger::setup();
let (t, exit, _receiver, server_address, stats) = setup_quic_server(None); let (t, exit, _receiver, server_address, stats) = setup_quic_server(None, 1);
let conn1 = make_client_endpoint(&server_address, None).await; let conn1 = make_client_endpoint(&server_address, None).await;
assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0); assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
@ -1223,16 +1242,71 @@ pub mod test {
#[tokio::test] #[tokio::test]
async fn test_quic_server_block_multiple_connections() { async fn test_quic_server_block_multiple_connections() {
solana_logger::setup(); solana_logger::setup();
let (t, exit, _receiver, server_address, _stats) = setup_quic_server(None); let (t, exit, _receiver, server_address, _stats) = setup_quic_server(None, 1);
check_block_multiple_connections(server_address).await; check_block_multiple_connections(server_address).await;
exit.store(true, Ordering::Relaxed); exit.store(true, Ordering::Relaxed);
t.await.unwrap(); t.await.unwrap();
} }
#[tokio::test]
async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
solana_logger::setup();
let (t, exit, _receiver, server_address, stats) = setup_quic_server(None, 2);
let client_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let mut endpoint =
quinn::Endpoint::new(EndpointConfig::default(), None, client_socket, TokioRuntime)
.unwrap();
let default_keypair = Keypair::new();
endpoint.set_default_client_config(get_client_config(&default_keypair));
let conn1 = endpoint
.connect(server_address, "localhost")
.expect("Failed in connecting")
.await
.expect("Failed in waiting");
let conn2 = endpoint
.connect(server_address, "localhost")
.expect("Failed in connecting")
.await
.expect("Failed in waiting");
let mut s1 = conn1.open_uni().await.unwrap();
s1.write_all(&[0u8]).await.unwrap();
s1.finish().await.unwrap();
let mut s2 = conn2.open_uni().await.unwrap();
conn1.close(
CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
);
// Wait long enough for the stream to timeout in receiving chunks
let sleep_time = (WAIT_FOR_STREAM_TIMEOUT_MS * 1000).min(1000);
sleep(Duration::from_millis(sleep_time)).await;
assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
s2.write_all(&[0u8]).await.unwrap();
s2.finish().await.unwrap();
conn2.close(
CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
);
// Wait long enough for the stream to timeout in receiving chunks
let sleep_time = (WAIT_FOR_STREAM_TIMEOUT_MS * 1000).min(1000);
sleep(Duration::from_millis(sleep_time)).await;
assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 2);
exit.store(true, Ordering::Relaxed);
t.await.unwrap();
}
#[tokio::test] #[tokio::test]
async fn test_quic_server_multiple_writes() { async fn test_quic_server_multiple_writes() {
solana_logger::setup(); solana_logger::setup();
let (t, exit, receiver, server_address, _stats) = setup_quic_server(None); let (t, exit, receiver, server_address, _stats) = setup_quic_server(None, 1);
check_multiple_writes(receiver, server_address, None).await; check_multiple_writes(receiver, server_address, None).await;
exit.store(true, Ordering::Relaxed); exit.store(true, Ordering::Relaxed);
t.await.unwrap(); t.await.unwrap();
@ -1249,7 +1323,7 @@ pub mod test {
.insert(client_keypair.pubkey(), 100000); .insert(client_keypair.pubkey(), 100000);
staked_nodes.total_stake = 100000; staked_nodes.total_stake = 100000;
let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes)); let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes), 1);
check_multiple_writes(receiver, server_address, Some(&client_keypair)).await; check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
exit.store(true, Ordering::Relaxed); exit.store(true, Ordering::Relaxed);
t.await.unwrap(); t.await.unwrap();
@ -1276,7 +1350,7 @@ pub mod test {
.insert(client_keypair.pubkey(), 0); .insert(client_keypair.pubkey(), 0);
staked_nodes.total_stake = 0; staked_nodes.total_stake = 0;
let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes)); let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes), 1);
check_multiple_writes(receiver, server_address, Some(&client_keypair)).await; check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
exit.store(true, Ordering::Relaxed); exit.store(true, Ordering::Relaxed);
t.await.unwrap(); t.await.unwrap();
@ -1294,7 +1368,7 @@ pub mod test {
#[tokio::test] #[tokio::test]
async fn test_quic_server_unstaked_connection_removal() { async fn test_quic_server_unstaked_connection_removal() {
solana_logger::setup(); solana_logger::setup();
let (t, exit, receiver, server_address, stats) = setup_quic_server(None); let (t, exit, receiver, server_address, stats) = setup_quic_server(None, 1);
check_multiple_writes(receiver, server_address, None).await; check_multiple_writes(receiver, server_address, None).await;
exit.store(true, Ordering::Relaxed); exit.store(true, Ordering::Relaxed);
t.await.unwrap(); t.await.unwrap();
@ -1422,7 +1496,7 @@ pub mod test {
assert_eq!(table.table.len(), new_size); assert_eq!(table.table.len(), new_size);
assert_eq!(table.total_size, new_size); assert_eq!(table.total_size, new_size);
for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) { for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port()); table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
} }
assert_eq!(table.total_size, 0); assert_eq!(table.total_size, 0);
} }
@ -1457,7 +1531,7 @@ pub mod test {
assert_eq!(table.table.len(), new_size); assert_eq!(table.table.len(), new_size);
assert_eq!(table.total_size, new_size); assert_eq!(table.total_size, new_size);
for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) { for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0); table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
} }
assert_eq!(table.total_size, 0); assert_eq!(table.total_size, 0);
} }
@ -1517,7 +1591,7 @@ pub mod test {
assert!(table.table.len() <= new_max_size); assert!(table.table.len() <= new_max_size);
assert!(table.total_size <= new_max_size); assert!(table.total_size <= new_max_size);
table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0); table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
assert_eq!(table.total_size, 0); assert_eq!(table.total_size, 0);
} }
@ -1609,7 +1683,7 @@ pub mod test {
sockets.push(zero_connection_addr); sockets.push(zero_connection_addr);
for socket in sockets.iter() { for socket in sockets.iter() {
table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port()); table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
} }
assert_eq!(table.total_size, 0); assert_eq!(table.total_size, 0);
} }