From 94685e1222b3289859a447d62fadea20898241e0 Mon Sep 17 00:00:00 2001
From: Pankaj Garg
Date: Thu, 30 Jun 2022 17:56:15 -0700
Subject: [PATCH] Implement randomized pruning of QUIC connection from staked
 peers (#26299)

---
 Cargo.lock                               |   1 +
 core/src/staked_nodes_updater_service.rs |   2 +-
 programs/bpf/Cargo.lock                  |   1 +
 streamer/Cargo.toml                      |   1 +
 streamer/src/nonblocking/quic.rs         | 240 +++++++++++++++++------
 streamer/src/quic.rs                     |  14 ++
 streamer/src/streamer.rs                 |   2 +-
 7 files changed, 201 insertions(+), 60 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index df37853f86..a40406c770 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6156,6 +6156,7 @@ dependencies = [
  "crossbeam-channel",
  "futures-util",
  "histogram",
+ "indexmap",
  "itertools",
  "libc",
  "log",
diff --git a/core/src/staked_nodes_updater_service.rs b/core/src/staked_nodes_updater_service.rs
index e26740a3c7..23a3587b0d 100644
--- a/core/src/staked_nodes_updater_service.rs
+++ b/core/src/staked_nodes_updater_service.rs
@@ -42,7 +42,7 @@ impl StakedNodesUpdaterService {
                     &cluster_info,
                 ) {
                     let mut shared = shared_staked_nodes.write().unwrap();
-                    shared.total_stake = total_stake as f64;
+                    shared.total_stake = total_stake;
                     shared.stake_map = new_ip_to_stake;
                 }
             }
diff --git a/programs/bpf/Cargo.lock b/programs/bpf/Cargo.lock
index 7d5cb0c362..cf966d1b29 100644
--- a/programs/bpf/Cargo.lock
+++ b/programs/bpf/Cargo.lock
@@ -5446,6 +5446,7 @@ dependencies = [
  "crossbeam-channel",
  "futures-util",
  "histogram",
+ "indexmap",
  "itertools",
  "libc",
  "log",
diff --git a/streamer/Cargo.toml b/streamer/Cargo.toml
index 3dcde3f090..7a89d07731 100644
--- a/streamer/Cargo.toml
+++ b/streamer/Cargo.toml
@@ -13,6 +13,7 @@ edition = "2021"
 crossbeam-channel = "0.5"
 futures-util = "0.3.21"
 histogram = "0.6.9"
+indexmap = "1.8.1"
 itertools = "0.10.3"
 libc = "0.2.126"
 log = "0.4.17"
diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs
index 1b55d958ec..133f26ca63 100644
--- a/streamer/src/nonblocking/quic.rs
+++ b/streamer/src/nonblocking/quic.rs
@@ -5,11 +5,13 @@ use {
     },
     crossbeam_channel::Sender,
     futures_util::stream::StreamExt,
+    indexmap::map::{Entry, IndexMap},
     percentage::Percentage,
     quinn::{
         Connecting, Connection, Endpoint, EndpointConfig, Incoming, IncomingUniStreams,
         NewConnection, VarInt,
     },
+    rand::{thread_rng, Rng},
     solana_perf::packet::PacketBatch,
     solana_sdk::{
         packet::{Packet, PACKET_DATA_SIZE},
@@ -18,11 +20,10 @@
         timing,
     },
     std::{
-        collections::{hash_map::Entry, HashMap},
         net::{IpAddr, SocketAddr, UdpSocket},
         sync::{
             atomic::{AtomicBool, AtomicU64, Ordering},
-            Arc, Mutex, RwLock,
+            Arc, Mutex, MutexGuard, RwLock,
         },
         time::{Duration, Instant},
     },
@@ -116,6 +117,21 @@ pub async fn run_server(
     }
 }
 
+fn prune_unstaked_connection_table(
+    unstaked_connection_table: &mut MutexGuard<ConnectionTable>,
+    max_unstaked_connections: usize,
+    stats: Arc<StreamStats>,
+) {
+    if unstaked_connection_table.total_size >= max_unstaked_connections {
+        const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
+        let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
+
+        let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
+        let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
+        stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
+    }
+}
+
 async fn setup_connection(
     connection: Connecting,
     unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
@@ -138,70 +154,102 @@
 
     let remote_addr = connection.remote_address();
 
-    let (mut connection_table_l, stake) = {
-        const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
-        let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
-
+    let table_and_stake = {
         let staked_nodes = staked_nodes.read().unwrap();
         if let Some(stake) = staked_nodes.stake_map.get(&remote_addr.ip()) {
             let stake = *stake;
-            let total_stake = staked_nodes.total_stake;
             drop(staked_nodes);
+            let mut connection_table_l = staked_connection_table.lock().unwrap();
             if connection_table_l.total_size >= max_staked_connections {
-                let max_connections = max_percentage_full.apply_to(max_staked_connections);
-                let num_pruned = connection_table_l.prune_oldest(max_connections);
-                stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
+                let num_pruned = connection_table_l.prune_random(stake);
+                if num_pruned == 0 {
+                    if max_unstaked_connections > 0 {
+                        // If we couldn't prune a connection in the staked connection table, let's
+                        // put this connection in the unstaked connection table. If needed, prune a
+                        // connection from the unstaked connection table.
+                        connection_table_l = unstaked_connection_table.lock().unwrap();
+                        prune_unstaked_connection_table(
+                            &mut connection_table_l,
+                            max_unstaked_connections,
+                            stats.clone(),
+                        );
+                        Some((connection_table_l, stake))
+                    } else {
+                        stats
+                            .connection_add_failed_on_pruning
+                            .fetch_add(1, Ordering::Relaxed);
+                        None
+                    }
+                } else {
+                    stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
+                    Some((connection_table_l, stake))
+                }
+            } else {
+                Some((connection_table_l, stake))
             }
-            connection.set_max_concurrent_uni_streams(
-                VarInt::from_u64(
-                    ((stake as f64 / total_stake as f64) * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS)
-                        as u64,
-                )
-                .unwrap(),
-            );
-            (connection_table_l, stake)
-        } else {
+        } else if max_unstaked_connections > 0 {
             drop(staked_nodes);
             let mut connection_table_l = unstaked_connection_table.lock().unwrap();
-            if connection_table_l.total_size >= max_unstaked_connections {
-                let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
-                let num_pruned = connection_table_l.prune_oldest(max_connections);
-                stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
-            }
-            connection.set_max_concurrent_uni_streams(
-                VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64).unwrap(),
+            prune_unstaked_connection_table(
+                &mut connection_table_l,
+                max_unstaked_connections,
+                stats.clone(),
             );
-            (connection_table_l, 0)
+            Some((connection_table_l, 0))
+        } else {
+            None
         }
     };
 
-    if stake != 0 || max_unstaked_connections > 0 {
-        if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection(
-            &remote_addr,
-            Some(connection),
-            timing::timestamp(),
-            max_connections_per_ip,
-        ) {
-            let table_type = connection_table_l.peer_type;
-            drop(connection_table_l);
-            let stats = stats.clone();
-            let connection_table = match table_type {
-                ConnectionPeerType::Unstaked => unstaked_connection_table.clone(),
-                ConnectionPeerType::Staked => staked_connection_table.clone(),
-            };
-            tokio::spawn(handle_connection(
-                uni_streams,
-                packet_sender,
-                remote_addr,
-                last_update,
-                connection_table,
-                stream_exit,
-                stats,
+    if let Some((mut connection_table_l, stake)) = table_and_stake {
+        let table_type = connection_table_l.peer_type;
+        let max_uni_streams = match table_type {
+            ConnectionPeerType::Unstaked => {
+                VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64)
+            }
+            ConnectionPeerType::Staked => {
+                let staked_nodes = staked_nodes.read().unwrap();
+                VarInt::from_u64(
+                    ((stake as f64 / staked_nodes.total_stake as f64)
+                        * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS) as u64,
+                )
+            }
+        };
+
+        if let Ok(max_uni_streams) = max_uni_streams {
+            connection.set_max_concurrent_uni_streams(max_uni_streams);
+
+            if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection(
+                &remote_addr,
+                Some(connection),
                 stake,
-            ));
+                timing::timestamp(),
+                max_connections_per_ip,
+            ) {
+                drop(connection_table_l);
+                let stats = stats.clone();
+                let connection_table = match table_type {
+                    ConnectionPeerType::Unstaked => unstaked_connection_table.clone(),
+                    ConnectionPeerType::Staked => staked_connection_table.clone(),
+                };
+                tokio::spawn(handle_connection(
+                    uni_streams,
+                    packet_sender,
+                    remote_addr,
+                    last_update,
+                    connection_table,
+                    stream_exit,
+                    stats,
+                    stake,
+                ));
+            } else {
+                stats.connection_add_failed.fetch_add(1, Ordering::Relaxed);
+            }
         } else {
-            stats.connection_add_failed.fetch_add(1, Ordering::Relaxed);
+            stats
+                .connection_add_failed_invalid_stream_count
+                .fetch_add(1, Ordering::Relaxed);
         }
     } else {
         connection.close(0u32.into(), &[0u8]);
@@ -387,6 +435,7 @@ fn handle_chunk(
 #[derive(Debug)]
 struct ConnectionEntry {
     exit: Arc<AtomicBool>,
+    stake: u64,
     last_update: Arc<AtomicU64>,
     port: u16,
     connection: Option<Connection>,
@@ -395,12 +444,14 @@ struct ConnectionEntry {
 impl ConnectionEntry {
     fn new(
         exit: Arc<AtomicBool>,
+        stake: u64,
         last_update: Arc<AtomicU64>,
         port: u16,
         connection: Option<Connection>,
     ) -> Self {
         Self {
             exit,
+            stake,
             last_update,
             port,
             connection,
@@ -429,7 +480,7 @@ enum ConnectionPeerType {
 
 // Map of IP to list of connection entries
 struct ConnectionTable {
-    table: HashMap<IpAddr, Vec<ConnectionEntry>>,
+    table: IndexMap<IpAddr, Vec<ConnectionEntry>>,
     total_size: usize,
     peer_type: ConnectionPeerType,
 }
@@ -439,7 +490,7 @@ struct ConnectionTable {
 impl ConnectionTable {
     fn new(peer_type: ConnectionPeerType) -> Self {
         Self {
-            table: HashMap::default(),
+            table: IndexMap::default(),
             total_size: 0,
             peer_type,
         }
@@ -473,10 +524,47 @@ impl ConnectionTable {
         num_pruned
     }
 
+    fn connection_stake(&self, index: usize) -> Option<u64> {
+        self.table
+            .get_index(index)
+            .and_then(|(_, connection_vec)| connection_vec.first())
+            .map(|connection| connection.stake)
+    }
+
+    // Randomly select two connections, and evict the one with the lower stake. If the stakes of
+    // both connections are at or above threshold_stake, reject the pruning attempt and return 0.
+    fn prune_random(&mut self, threshold_stake: u64) -> usize {
+        let mut num_pruned = 0;
+        let mut rng = thread_rng();
+        // candidate1 and candidate2 may be the same entry. In that case, the stake of that single
+        // candidate is simply compared against threshold_stake.
+        let candidate1 = rng.gen_range(0, self.table.len());
+        let candidate2 = rng.gen_range(0, self.table.len());
+
+        let candidate1_stake = self.connection_stake(candidate1).unwrap_or(0);
+        let candidate2_stake = self.connection_stake(candidate2).unwrap_or(0);
+
+        if candidate1_stake < threshold_stake || candidate2_stake < threshold_stake {
+            let removed = if candidate1_stake < candidate2_stake {
+                self.table.swap_remove_index(candidate1)
+            } else {
+                self.table.swap_remove_index(candidate2)
+            };
+
+            if let Some((_, removed_value)) = removed {
+                self.total_size -= removed_value.len();
+                num_pruned += removed_value.len();
+            }
+        }
+
+        num_pruned
+    }
+
     fn try_add_connection(
         &mut self,
         addr: &SocketAddr,
         connection: Option<Connection>,
+        stake: u64,
         last_update: u64,
         max_connections_per_ip: usize,
     ) -> Option<(Arc<AtomicU64>, Arc<AtomicBool>)> {
@@ -491,6 +579,7 @@ impl ConnectionTable {
             let last_update = Arc::new(AtomicU64::new(last_update));
             connection_entry.push(ConnectionEntry::new(
                 exit.clone(),
+                stake,
                 last_update.clone(),
                 addr.port(),
                 connection,
@@ -818,7 +907,7 @@ pub mod test {
         staked_nodes
            .stake_map
            .insert(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 100000);
-        staked_nodes.total_stake = 100000_f64;
+        staked_nodes.total_stake = 100000;
 
         let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes));
         check_multiple_writes(receiver, server_address).await;
@@ -918,12 +1007,12 @@
             .collect();
         for (i, socket) in sockets.iter().enumerate() {
             table
-                .try_add_connection(socket, None, i as u64, max_connections_per_ip)
+                .try_add_connection(socket, None, 0, i as u64, max_connections_per_ip)
                 .unwrap();
         }
         num_entries += 1;
         table
-            .try_add_connection(&sockets[0], None, 5, max_connections_per_ip)
+            .try_add_connection(&sockets[0], None, 0, 5, max_connections_per_ip)
             .unwrap();
 
         let new_size = 3;
@@ -942,6 +1031,40 @@
         assert_eq!(table.total_size, 0);
     }
 
+    #[test]
+    fn test_prune_table_random() {
+        use std::net::Ipv4Addr;
+        solana_logger::setup();
+        let mut table = ConnectionTable::new(ConnectionPeerType::Staked);
+        let num_entries = 5;
+        let max_connections_per_ip = 10;
+        let sockets: Vec<_> = (0..num_entries)
+            .into_iter()
+            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
+            .collect();
+        for (i, socket) in sockets.iter().enumerate() {
+            table
+                .try_add_connection(
+                    socket,
+                    None,
+                    (i + 1) as u64,
+                    i as u64,
+                    max_connections_per_ip,
+                )
+                .unwrap();
+        }
+
+        // Try pruning with a threshold stake lower than all the entries in the table.
+        // It should fail to prune (i.e. return 0 pruned entries).
+        let pruned = table.prune_random(0);
+        assert_eq!(pruned, 0);
+
+        // Try pruning with a threshold stake higher than all the entries in the table.
+        // It should succeed (i.e. return 1 pruned entry).
+        let pruned = table.prune_random(num_entries as u64 + 1);
+        assert_eq!(pruned, 1);
+    }
+
     #[test]
     fn test_remove_connections() {
         use std::net::Ipv4Addr;
@@ -955,11 +1078,11 @@
             .collect();
         for (i, socket) in sockets.iter().enumerate() {
             table
-                .try_add_connection(socket, None, (i * 2) as u64, max_connections_per_ip)
+                .try_add_connection(socket, None, 0, (i * 2) as u64, max_connections_per_ip)
                .unwrap();
 
             table
-                .try_add_connection(socket, None, (i * 2 + 1) as u64, max_connections_per_ip)
+                .try_add_connection(socket, None, 0, (i * 2 + 1) as u64, max_connections_per_ip)
                .unwrap();
         }
 
@@ -969,6 +1092,7 @@
             .try_add_connection(
                 &single_connection_addr,
                 None,
+                0,
                 (num_ips * 2) as u64,
                 max_connections_per_ip,
             )
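
The eviction policy that prune_random implements is a two-random-choices sample: draw two table slots, take the lower-staked of the pair as the victim, and evict it only if its stake is below that of the incoming peer. The standalone sketch below is not part of the patch; it models the table as a plain Vec of (peer, stake) pairs and substitutes a small hand-rolled generator for rand::thread_rng so it runs with the standard library alone, but the decision rule mirrors the code above.

    use std::time::{SystemTime, UNIX_EPOCH};

    struct Lcg(u64);

    impl Lcg {
        fn next_index(&mut self, bound: usize) -> usize {
            // Numerical Recipes LCG constants; statistical quality is irrelevant here.
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 33) as usize % bound
        }
    }

    // Returns 1 if one of the two sampled entries had stake below threshold_stake
    // and was evicted, else 0 -- the same contract as prune_random in the patch.
    fn prune_random(table: &mut Vec<(String, u64)>, threshold_stake: u64, rng: &mut Lcg) -> usize {
        if table.is_empty() {
            return 0;
        }
        // The two candidates may collide; in that case the single entry's stake is
        // simply compared against threshold_stake, as in the patch.
        let candidate1 = rng.next_index(table.len());
        let candidate2 = rng.next_index(table.len());
        // The lower-staked candidate is the potential victim.
        let victim = if table[candidate1].1 < table[candidate2].1 {
            candidate1
        } else {
            candidate2
        };
        if table[victim].1 < threshold_stake {
            // O(1) removal, analogous to IndexMap::swap_remove_index.
            table.swap_remove(victim);
            1
        } else {
            0
        }
    }

    fn main() {
        let seed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .subsec_nanos() as u64;
        let mut rng = Lcg(seed | 1);
        let mut table: Vec<(String, u64)> = (1u64..=5).map(|i| (format!("peer{}", i), i)).collect();
        // A peer with stake 0 can never evict anyone,
        assert_eq!(prune_random(&mut table, 0, &mut rng), 0);
        // while a peer staked above every entry always evicts exactly one.
        assert_eq!(prune_random(&mut table, 6, &mut rng), 1);
        println!("{} entries remain", table.len());
    }

Evicting the lower of two random picks, rather than scanning for the global minimum, keeps each pruning attempt O(1) while still making low-stake entries the most likely victims.
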
diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs
index f1c17c1565..2b49f05c5b 100644
--- a/streamer/src/quic.rs
+++ b/streamer/src/quic.rs
@@ -158,7 +158,9 @@ pub struct StreamStats {
     pub(crate) total_stream_read_timeouts: AtomicUsize,
     pub(crate) num_evictions: AtomicUsize,
     pub(crate) connection_add_failed: AtomicUsize,
+    pub(crate) connection_add_failed_invalid_stream_count: AtomicUsize,
     pub(crate) connection_add_failed_unstaked_node: AtomicUsize,
+    pub(crate) connection_add_failed_on_pruning: AtomicUsize,
     pub(crate) connection_setup_timeout: AtomicUsize,
     pub(crate) connection_removed: AtomicUsize,
     pub(crate) connection_remove_failed: AtomicUsize,
@@ -198,12 +200,24 @@ impl StreamStats {
                 self.connection_add_failed.swap(0, Ordering::Relaxed),
                 i64
             ),
+            (
+                "connection_add_failed_invalid_stream_count",
+                self.connection_add_failed_invalid_stream_count
+                    .swap(0, Ordering::Relaxed),
+                i64
+            ),
             (
                 "connection_add_failed_unstaked_node",
                 self.connection_add_failed_unstaked_node
                     .swap(0, Ordering::Relaxed),
                 i64
             ),
+            (
+                "connection_add_failed_on_pruning",
+                self.connection_add_failed_on_pruning
+                    .swap(0, Ordering::Relaxed),
+                i64
+            ),
             (
                 "connection_removed",
                 self.connection_removed.swap(0, Ordering::Relaxed),
diff --git a/streamer/src/streamer.rs b/streamer/src/streamer.rs
index 1838318112..4dad613535 100644
--- a/streamer/src/streamer.rs
+++ b/streamer/src/streamer.rs
@@ -27,7 +27,7 @@ use {
 
 // Total stake and nodes => stake map
 #[derive(Default)]
 pub struct StakedNodes {
-    pub total_stake: f64,
+    pub total_stake: u64,
     pub stake_map: HashMap<IpAddr, u64>,
 }
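
For reference, the per-peer stream budget that setup_connection now computes inside the match on table_type reduces to the arithmetic below. This is a minimal sketch: the two constants are placeholder values standing in for QUIC_TOTAL_STAKED_CONCURRENT_STREAMS and QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS in the streamer crate, and the checked VarInt conversion (whose failure the patch now counts under connection_add_failed_invalid_stream_count instead of unwrapping) is elided.

    // Placeholder values; the real constants live in solana-streamer.
    const TOTAL_STAKED_CONCURRENT_STREAMS: f64 = 100_000.0;
    const MAX_UNSTAKED_CONCURRENT_STREAMS: u64 = 128;

    // A staked peer receives a share of the staked stream budget proportional to
    // its share of total stake; an unstaked peer gets the flat per-connection cap.
    fn max_uni_streams(stake: u64, total_stake: u64) -> u64 {
        if stake == 0 || total_stake == 0 {
            MAX_UNSTAKED_CONCURRENT_STREAMS
        } else {
            ((stake as f64 / total_stake as f64) * TOTAL_STAKED_CONCURRENT_STREAMS) as u64
        }
    }

    fn main() {
        // A peer holding half the total stake gets half the staked stream budget.
        assert_eq!(max_uni_streams(50_000, 100_000), 50_000);
        // Unstaked peers fall back to the flat limit.
        assert_eq!(max_uni_streams(0, 100_000), 128);
        println!("ok");
    }

Because total_stake is now a u64 on StakedNodes (rather than f64), the cast to f64 happens only at this computation, which is also why the patch drops the `as f64` conversions elsewhere.
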