diff --git a/core/src/lib.rs b/core/src/lib.rs index be682be77..2af5f786e 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -59,6 +59,7 @@ pub mod sigverify; pub mod sigverify_shreds; pub mod sigverify_stage; pub mod snapshot_packager_service; +pub mod staked_nodes_updater_service; pub mod stats_reporter_service; pub mod system_monitor_service; mod tower1_7_14; diff --git a/core/src/staked_nodes_updater_service.rs b/core/src/staked_nodes_updater_service.rs new file mode 100644 index 000000000..410213f1c --- /dev/null +++ b/core/src/staked_nodes_updater_service.rs @@ -0,0 +1,76 @@ +use { + solana_gossip::cluster_info::ClusterInfo, + solana_runtime::bank_forks::BankForks, + std::{ + collections::HashMap, + net::IpAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, RwLock, + }, + thread::{self, sleep, Builder, JoinHandle}, + time::{Duration, Instant}, + }, +}; + +const IP_TO_STAKE_REFRESH_DURATION: Duration = Duration::from_secs(5); + +pub struct StakedNodesUpdaterService { + thread_hdl: JoinHandle<()>, +} + +impl StakedNodesUpdaterService { + pub fn new( + exit: Arc, + cluster_info: Arc, + bank_forks: Arc>, + shared_staked_nodes: Arc>>, + ) -> Self { + let thread_hdl = Builder::new() + .name("sol-sn-updater".to_string()) + .spawn(move || { + let mut last_stakes = Instant::now(); + while !exit.load(Ordering::Relaxed) { + let mut new_ip_to_stake = HashMap::new(); + Self::try_refresh_ip_to_stake( + &mut last_stakes, + &mut new_ip_to_stake, + &bank_forks, + &cluster_info, + ); + let mut shared = shared_staked_nodes.write().unwrap(); + *shared = new_ip_to_stake; + } + }) + .unwrap(); + + Self { thread_hdl } + } + + fn try_refresh_ip_to_stake( + last_stakes: &mut Instant, + ip_to_stake: &mut HashMap, + bank_forks: &RwLock, + cluster_info: &ClusterInfo, + ) { + if last_stakes.elapsed() > IP_TO_STAKE_REFRESH_DURATION { + let root_bank = bank_forks.read().unwrap().root_bank(); + let staked_nodes = root_bank.staked_nodes(); + *ip_to_stake = cluster_info + .tvu_peers() + .into_iter() + .filter_map(|node| { + let stake = staked_nodes.get(&node.id)?; + Some((node.tvu.ip(), *stake)) + }) + .collect(); + *last_stakes = Instant::now(); + } else { + sleep(Duration::from_millis(1)); + } + } + + pub fn join(self) -> thread::Result<()> { + self.thread_hdl.join() + } +} diff --git a/core/src/tpu.rs b/core/src/tpu.rs index d85ec9dc2..e12712817 100644 --- a/core/src/tpu.rs +++ b/core/src/tpu.rs @@ -13,6 +13,7 @@ use { find_packet_sender_stake_stage::FindPacketSenderStakeStage, sigverify::TransactionSigVerifier, sigverify_stage::SigVerifyStage, + staked_nodes_updater_service::StakedNodesUpdaterService, }, crossbeam_channel::{bounded, unbounded, Receiver, RecvTimeoutError}, solana_gossip::cluster_info::ClusterInfo, @@ -28,7 +29,9 @@ use { vote_sender_types::{ReplayVoteReceiver, ReplayVoteSender}, }, solana_sdk::signature::Keypair, + solana_streamer::quic::{spawn_server, MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS}, std::{ + collections::HashMap, net::UdpSocket, sync::{atomic::AtomicBool, Arc, Mutex, RwLock}, thread, @@ -62,6 +65,7 @@ pub struct Tpu { tpu_quic_t: thread::JoinHandle<()>, find_packet_sender_stake_stage: FindPacketSenderStakeStage, vote_find_packet_sender_stake_stage: FindPacketSenderStakeStage, + staked_nodes_updater_service: StakedNodesUpdaterService, } impl Tpu { @@ -132,13 +136,23 @@ impl Tpu { let (verified_sender, verified_receiver) = unbounded(); - let tpu_quic_t = solana_streamer::quic::spawn_server( + let staked_nodes = Arc::new(RwLock::new(HashMap::new())); + let staked_nodes_updater_service = StakedNodesUpdaterService::new( + exit.clone(), + cluster_info.clone(), + bank_forks.clone(), + staked_nodes.clone(), + ); + let tpu_quic_t = spawn_server( transactions_quic_sockets, keypair, cluster_info.my_contact_info().tpu.ip(), packet_sender, exit.clone(), MAX_QUIC_CONNECTIONS_PER_IP, + staked_nodes, + MAX_STAKED_CONNECTIONS, + MAX_UNSTAKED_CONNECTIONS, ) .unwrap(); @@ -208,6 +222,7 @@ impl Tpu { tpu_quic_t, find_packet_sender_stake_stage, vote_find_packet_sender_stake_stage, + staked_nodes_updater_service, } } @@ -236,6 +251,7 @@ impl Tpu { self.banking_stage.join(), self.find_packet_sender_stake_stage.join(), self.vote_find_packet_sender_stake_stage.join(), + self.staked_nodes_updater_service.join(), ]; self.tpu_quic_t.join()?; let broadcast_result = self.broadcast_stage.join(); diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs index 3dd2e986f..2ae43655c 100644 --- a/streamer/src/quic.rs +++ b/streamer/src/quic.rs @@ -18,7 +18,7 @@ use { net::{IpAddr, SocketAddr, UdpSocket}, sync::{ atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, - Arc, Mutex, + Arc, Mutex, RwLock, }, thread, time::{Duration, Instant}, @@ -29,6 +29,9 @@ use { }, }; +pub const MAX_STAKED_CONNECTIONS: usize = 2000; +pub const MAX_UNSTAKED_CONNECTIONS: usize = 500; + /// Returns default server configuration along with its PEM certificate chain. #[allow(clippy::field_reassign_with_default)] // https://github.com/rust-lang/rust-clippy/issues/6527 fn configure_server( @@ -401,6 +404,9 @@ pub fn spawn_server( packet_sender: Sender, exit: Arc, max_connections_per_ip: usize, + staked_nodes: Arc>>, + max_staked_connections: usize, + max_unstaked_connections: usize, ) -> Result, QuicServerError> { let (config, _cert) = configure_server(keypair, gossip_host)?; @@ -418,6 +424,8 @@ pub fn spawn_server( let mut last_datapoint = Instant::now(); let connection_table: Arc> = Arc::new(Mutex::new(ConnectionTable::default())); + let staked_connection_table: Arc> = + Arc::new(Mutex::new(ConnectionTable::default())); while !exit.load(Ordering::Relaxed) { const WAIT_FOR_CONNECTION_TIMEOUT_MS: u64 = 1000; let timeout_connection = timeout( @@ -443,10 +451,21 @@ pub fn spawn_server( let remote_addr = connection.remote_address(); - let mut connection_table_l = connection_table.lock().unwrap(); - const MAX_CONNECTION_TABLE_SIZE: usize = 5000; - let num_pruned = connection_table_l.prune_oldest(MAX_CONNECTION_TABLE_SIZE); - stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed); + let mut connection_table_l = + if staked_nodes.read().unwrap().contains_key(&remote_addr.ip()) { + let mut connection_table_l = + staked_connection_table.lock().unwrap(); + let num_pruned = + connection_table_l.prune_oldest(max_staked_connections); + stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed); + connection_table_l + } else { + let mut connection_table_l = connection_table.lock().unwrap(); + let num_pruned = + connection_table_l.prune_oldest(max_unstaked_connections); + stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed); + connection_table_l + }; if let Some((last_update, stream_exit)) = connection_table_l .try_add_connection( @@ -529,12 +548,8 @@ mod test { #[test] fn test_quic_server_exit() { - let s = UdpSocket::bind("127.0.0.1:0").unwrap(); - let exit = Arc::new(AtomicBool::new(false)); - let (sender, _receiver) = unbounded(); - let keypair = Keypair::new(); - let ip = "127.0.0.1".parse().unwrap(); - let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap(); + let (t, exit, _receiver, _server_address) = setup_quic_server(); + exit.store(true, Ordering::Relaxed); t.join().unwrap(); } @@ -592,13 +607,7 @@ mod test { #[test] fn test_quic_server_block_multiple_connections() { solana_logger::setup(); - let s = UdpSocket::bind("127.0.0.1:0").unwrap(); - let exit = Arc::new(AtomicBool::new(false)); - let (sender, _receiver) = unbounded(); - let keypair = Keypair::new(); - let ip = "127.0.0.1".parse().unwrap(); - let server_address = s.local_addr().unwrap(); - let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap(); + let (t, exit, _receiver, server_address) = setup_quic_server(); let runtime = rt(); let _rt_guard = runtime.enter(); @@ -627,7 +636,19 @@ mod test { let keypair = Keypair::new(); let ip = "127.0.0.1".parse().unwrap(); let server_address = s.local_addr().unwrap(); - let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 2).unwrap(); + let staked_nodes = Arc::new(RwLock::new(HashMap::new())); + let t = spawn_server( + s, + &keypair, + ip, + sender, + exit.clone(), + 2, + staked_nodes, + 10, + 10, + ) + .unwrap(); let runtime = rt(); let _rt_guard = runtime.enter(); @@ -673,16 +694,38 @@ mod test { t.join().unwrap(); } - #[test] - fn test_quic_server_multiple_writes() { - solana_logger::setup(); + fn setup_quic_server() -> ( + std::thread::JoinHandle<()>, + Arc, + crossbeam_channel::Receiver, + SocketAddr, + ) { let s = UdpSocket::bind("127.0.0.1:0").unwrap(); let exit = Arc::new(AtomicBool::new(false)); let (sender, receiver) = unbounded(); let keypair = Keypair::new(); let ip = "127.0.0.1".parse().unwrap(); let server_address = s.local_addr().unwrap(); - let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap(); + let staked_nodes = Arc::new(RwLock::new(HashMap::new())); + let t = spawn_server( + s, + &keypair, + ip, + sender, + exit.clone(), + 1, + staked_nodes, + MAX_STAKED_CONNECTIONS, + MAX_UNSTAKED_CONNECTIONS, + ) + .unwrap(); + (t, exit, receiver, server_address) + } + + #[test] + fn test_quic_server_multiple_writes() { + solana_logger::setup(); + let (t, exit, receiver, server_address) = setup_quic_server(); let runtime = rt(); let _rt_guard = runtime.enter();