Compute maximum parallel QUIC streams using client stake (#26802)

* Compute maximum parallel QUIC streams using client stake

* clippy fixes

* Add unit test
Pankaj Garg 2022-07-29 08:44:24 -07:00 committed by GitHub
parent 9d31b1d290
commit fb922f613c
5 changed files with 147 additions and 43 deletions
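
In short: the streamer already computes a stake-weighted cap on unidirectional QUIC streams for each incoming peer; this change exports that computation and reuses it on the client side, so a staked client chunks its outgoing transaction batches across as many parallel streams as its stake entitles it to, instead of always assuming the unstaked cap. A minimal, self-contained sketch of the allocation rule follows; the constant values are illustrative placeholders, not the real QUIC_* constants from solana_sdk:

    // Standalone sketch of the stake-weighted stream allocation. Constants are
    // placeholders; the real values are the QUIC_* constants in solana_sdk.
    const MAX_UNSTAKED_STREAMS: usize = 128; // stands in for QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
    const MIN_STAKED_STREAMS: usize = 128; // stands in for QUIC_MIN_STAKED_CONCURRENT_STREAMS
    const TOTAL_STAKED_STREAMS: usize = 100_000; // stands in for QUIC_TOTAL_STAKED_CONCURRENT_STREAMS

    fn max_streams(staked: bool, stake: u64, total_stake: u64) -> usize {
        if stake == 0 || !staked {
            // Zero stake is treated as unstaked regardless of peer type.
            MAX_UNSTAKED_STREAMS
        } else if total_stake == 0 {
            // There is no checked math for f64, so guard division by zero explicitly.
            MIN_STAKED_STREAMS
        } else {
            // Stake-proportional share of the staked stream budget, floored at the minimum.
            (((stake as f64 / total_stake as f64) * TOTAL_STAKED_STREAMS as f64) as usize)
                .max(MIN_STAKED_STREAMS)
        }
    }

    fn main() {
        // Holding 10% of total stake yields 10% of the staked budget.
        assert_eq!(max_streams(true, 1_000, 10_000), TOTAL_STAKED_STREAMS / 10);
        // A tiny stake is still guaranteed the staked floor.
        assert_eq!(max_streams(true, 1, 10_000), MIN_STAKED_STREAMS);
        // Unstaked clients always get the unstaked cap.
        assert_eq!(max_streams(false, 0, 10_000), MAX_UNSTAKED_STREAMS);
    }

The per-file diffs below thread exactly this computation from the validator's identity keypair and stake map down into QuicClient.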

@@ -9,8 +9,14 @@ use {
     indexmap::map::{Entry, IndexMap},
     rand::{thread_rng, Rng},
     solana_measure::measure::Measure,
-    solana_sdk::{quic::QUIC_PORT_OFFSET, signature::Keypair, timing::AtomicInterval},
-    solana_streamer::tls_certificates::new_self_signed_tls_certificate_chain,
+    solana_sdk::{
+        pubkey::Pubkey, quic::QUIC_PORT_OFFSET, signature::Keypair, timing::AtomicInterval,
+    },
+    solana_streamer::{
+        nonblocking::quic::{compute_max_allowed_uni_streams, ConnectionPeerType},
+        streamer::StakedNodes,
+        tls_certificates::new_self_signed_tls_certificate_chain,
+    },
     std::{
         error::Error,
         net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
@@ -228,6 +234,8 @@ pub struct ConnectionCache {
     tpu_udp_socket: Arc<UdpSocket>,
     client_certificate: Arc<QuicClientCertificate>,
     use_quic: bool,
+    maybe_staked_nodes: Option<Arc<RwLock<StakedNodes>>>,
+    maybe_client_pubkey: Option<Pubkey>,
 }

 /// Models the pool of connections
@@ -279,6 +287,15 @@ impl ConnectionCache {
         Ok(())
     }

+    pub fn set_staked_nodes(
+        &mut self,
+        staked_nodes: &Arc<RwLock<StakedNodes>>,
+        client_pubkey: &Pubkey,
+    ) {
+        self.maybe_staked_nodes = Some(staked_nodes.clone());
+        self.maybe_client_pubkey = Some(*client_pubkey);
+    }
+
     pub fn with_udp(connection_pool_size: usize) -> Self {
         // The minimum pool size is 1.
         let connection_pool_size = 1.max(connection_pool_size);
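
Callers are expected to install the stake table and their own identity before creating QUIC connections. A sketch of the wiring, mirroring the Validator change later in this commit (crate paths are assumed from this era of the codebase):

    use {
        solana_client::connection_cache::ConnectionCache,
        solana_sdk::signature::{Keypair, Signer},
        solana_streamer::streamer::StakedNodes,
        std::sync::{Arc, RwLock},
    };

    fn build_cache(identity: &Keypair) -> (ConnectionCache, Arc<RwLock<StakedNodes>>) {
        // StakedNodes sits behind a shared RwLock so a background service
        // (StakedNodesUpdaterService) can refresh stakes while the cache reads them.
        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
        let mut cache = ConnectionCache::new(4); // pool size of 4 is illustrative
        cache.set_staked_nodes(&staked_nodes, &identity.pubkey());
        (cache, staked_nodes)
    }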
@@ -303,6 +320,24 @@ impl ConnectionCache {
         }
     }

+    fn compute_max_parallel_chunks(&self) -> usize {
+        let (client_type, stake, total_stake) =
+            self.maybe_client_pubkey
+                .map_or((ConnectionPeerType::Unstaked, 0, 0), |pubkey| {
+                    self.maybe_staked_nodes.as_ref().map_or(
+                        (ConnectionPeerType::Unstaked, 0, 0),
+                        |stakes| {
+                            let rstakes = stakes.read().unwrap();
+                            rstakes.pubkey_stake_map.get(&pubkey).map_or(
+                                (ConnectionPeerType::Unstaked, 0, rstakes.total_stake),
+                                |stake| (ConnectionPeerType::Staked, *stake, rstakes.total_stake),
+                            )
+                        },
+                    )
+                });
+        compute_max_allowed_uni_streams(client_type, stake, total_stake)
+    }
+
     /// Create a lazy connection object under the exclusive lock of the cache map if there is not
     /// enough used connections in the connection pool for the specified address.
     /// Returns CreateConnectionResult.
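
The nested map_or chain above can be hard to read; here is a logically equivalent flattened sketch (same types as the file above, not the committed code):

    fn compute_max_parallel_chunks_flat(
        maybe_client_pubkey: Option<Pubkey>,
        maybe_staked_nodes: Option<&Arc<RwLock<StakedNodes>>>,
    ) -> usize {
        let (client_type, stake, total_stake) = match (maybe_client_pubkey, maybe_staked_nodes) {
            (Some(pubkey), Some(stakes)) => {
                let rstakes = stakes.read().unwrap();
                match rstakes.pubkey_stake_map.get(&pubkey) {
                    // Known pubkey: staked, with its own stake and the cluster total.
                    Some(stake) => (ConnectionPeerType::Staked, *stake, rstakes.total_stake),
                    // Unknown pubkey: unstaked, but the cluster total is still known.
                    None => (ConnectionPeerType::Unstaked, 0, rstakes.total_stake),
                }
            }
            // No pubkey or no stake table configured: unstaked with zero totals.
            _ => (ConnectionPeerType::Unstaked, 0, 0),
        };
        compute_max_allowed_uni_streams(client_type, stake, total_stake)
    }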
@@ -335,6 +370,7 @@ impl ConnectionCache {
                 BaseTpuConnection::Quic(Arc::new(QuicClient::new(
                     endpoint.as_ref().unwrap().clone(),
                     *addr,
+                    self.compute_max_parallel_chunks(),
                 )))
             };
@@ -534,6 +570,8 @@ impl Default for ConnectionCache {
                 key: priv_key,
             }),
             use_quic: DEFAULT_TPU_USE_QUIC,
+            maybe_staked_nodes: None,
+            maybe_client_pubkey: None,
         }
     }
 }
@@ -604,8 +642,18 @@ mod tests {
         },
         rand::{Rng, SeedableRng},
         rand_chacha::ChaChaRng,
-        solana_sdk::quic::QUIC_PORT_OFFSET,
-        std::net::{IpAddr, Ipv4Addr, SocketAddr},
+        solana_sdk::{
+            pubkey::Pubkey,
+            quic::{
+                QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS, QUIC_MIN_STAKED_CONCURRENT_STREAMS,
+                QUIC_PORT_OFFSET,
+            },
+        },
+        solana_streamer::streamer::StakedNodes,
+        std::{
+            net::{IpAddr, Ipv4Addr, SocketAddr},
+            sync::{Arc, RwLock},
+        },
     };

     fn get_addr(rng: &mut ChaChaRng) -> SocketAddr {
@@ -661,6 +709,55 @@ mod tests {
         let _conn = map.get(&addr).expect("Address not found");
     }

+    #[test]
+    fn test_connection_cache_max_parallel_chunks() {
+        solana_logger::setup();
+        let mut connection_cache = ConnectionCache::default();
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
+        );
+
+        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+        let pubkey = Pubkey::new_unique();
+        connection_cache.set_staked_nodes(&staked_nodes, &pubkey);
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
+        );
+
+        staked_nodes.write().unwrap().total_stake = 10000;
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
+        );
+
+        staked_nodes
+            .write()
+            .unwrap()
+            .pubkey_stake_map
+            .insert(pubkey, 1);
+        assert_eq!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MIN_STAKED_CONCURRENT_STREAMS
+        );
+
+        staked_nodes
+            .write()
+            .unwrap()
+            .pubkey_stake_map
+            .remove(&pubkey);
+        staked_nodes
+            .write()
+            .unwrap()
+            .pubkey_stake_map
+            .insert(pubkey, 1000);
+        assert_ne!(
+            connection_cache.compute_max_parallel_chunks(),
+            QUIC_MIN_STAKED_CONCURRENT_STREAMS
+        );
+    }
+
     // Test that we can get_connection with a connection cache configured for quic
     // on an address with a port that, if QUIC_PORT_OFFSET were added to it, it would overflow to
     // an invalid port.
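The closing assert_ne! in test_connection_cache_max_parallel_chunks is deliberate: with a stake of 1000 out of 10000, the stake-proportional share of the staked stream budget exceeds QUIC_MIN_STAKED_CONCURRENT_STREAMS, so the computed limit must move off the floor.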

@@ -263,15 +263,21 @@ pub struct QuicClient {
     connection: Arc<Mutex<Option<QuicNewConnection>>>,
     addr: SocketAddr,
     stats: Arc<ClientStats>,
+    num_chunks: usize,
 }

 impl QuicClient {
-    pub fn new(endpoint: Arc<QuicLazyInitializedEndpoint>, addr: SocketAddr) -> Self {
+    pub fn new(
+        endpoint: Arc<QuicLazyInitializedEndpoint>,
+        addr: SocketAddr,
+        num_chunks: usize,
+    ) -> Self {
         Self {
             endpoint,
             connection: Arc::new(Mutex::new(None)),
             addr,
             stats: Arc::new(ClientStats::default()),
+            num_chunks,
         }
     }
@@ -439,7 +445,7 @@ impl QuicClient {
     fn compute_chunk_length(num_buffers_to_chunk: usize, num_chunks: usize) -> usize {
         // The function is equivalent to checked div_ceil()
-        // Also, if num_chunks == 0 || num_buffers_per_chunk == 0, return 1
+        // Also, if num_chunks == 0 || num_buffers_to_chunk == 0, return 1
         num_buffers_to_chunk
             .checked_div(num_chunks)
             .map_or(1, |value| {
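
As the fixed comment says, compute_chunk_length is a checked div_ceil with a floor of 1. A behavioral sketch of the same rule (the committed body differs in detail):

    // Checked div_ceil with a floor of 1, guarded against a zero divisor.
    fn chunk_length(num_buffers_to_chunk: usize, num_chunks: usize) -> usize {
        num_buffers_to_chunk
            .checked_div(num_chunks) // None when num_chunks == 0 -> fall back to 1
            .map_or(1, |quotient| {
                if num_buffers_to_chunk % num_chunks == 0 {
                    // Exact division; still return at least 1 for zero buffers.
                    quotient.max(1)
                } else {
                    // Remainder left over: round the chunk length up.
                    quotient + 1
                }
            })
    }

    // E.g. splitting 250 buffers across 128 streams gives chunks of 2,
    // while 100 buffers across 128 streams gives chunks of 1.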
@@ -483,8 +489,7 @@ impl QuicClient {
         // by just getting a reference to the NewConnection once
         let connection_ref: &NewConnection = &connection;
-        let chunk_len =
-            Self::compute_chunk_length(buffers.len() - 1, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS);
+        let chunk_len = Self::compute_chunk_length(buffers.len() - 1, self.num_chunks);

         let chunks = buffers[1..buffers.len()].iter().chunks(chunk_len);

         let futures: Vec<_> = chunks
@@ -528,7 +533,11 @@ impl QuicTpuConnection {
         addr: SocketAddr,
         connection_stats: Arc<ConnectionCacheStats>,
     ) -> Self {
-        let client = Arc::new(QuicClient::new(endpoint, addr));
+        let client = Arc::new(QuicClient::new(
+            endpoint,
+            addr,
+            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
+        ));
         Self::new_with_client(client, connection_stats)
     }
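Note that this direct QuicTpuConnection::new path still hard-codes QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS: only clients built through ConnectionCache, which knows the identity and stake table, receive the stake-scaled budget from compute_max_parallel_chunks.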

@@ -97,6 +97,7 @@ impl Tpu {
         keypair: &Keypair,
         log_messages_bytes_limit: Option<usize>,
         enable_quic_servers: bool,
+        staked_nodes: &Arc<RwLock<StakedNodes>>,
     ) -> Self {
         let TpuSockets {
             transactions: transactions_sockets,
@@ -124,7 +125,6 @@ impl Tpu {
             Some(bank_forks.read().unwrap().get_vote_only_mode_signal()),
         );

-        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
         let staked_nodes_updater_service = StakedNodesUpdaterService::new(
             exit.clone(),
             cluster_info.clone(),
@@ -178,7 +178,7 @@ impl Tpu {
             forwarded_packet_sender,
             exit.clone(),
             MAX_QUIC_CONNECTIONS_PER_PEER,
-            staked_nodes,
+            staked_nodes.clone(),
             MAX_STAKED_CONNECTIONS.saturating_add(MAX_UNSTAKED_CONNECTIONS),
             0, // Prevent unstaked nodes from forwarding transactions
             stats,
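The net effect in the TPU is an ownership move: StakedNodes is no longer constructed inside Tpu::new but passed in, so the QUIC servers and the stake updater service now share the very table the validator also hands to its ConnectionCache.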

@@ -98,7 +98,7 @@ use {
         timing::timestamp,
     },
     solana_send_transaction_service::send_transaction_service,
-    solana_streamer::socket::SocketAddrSpace,
+    solana_streamer::{socket::SocketAddrSpace, streamer::StakedNodes},
     solana_vote_program::vote_state::VoteState,
     std::{
         collections::{HashMap, HashSet},
@@ -757,12 +757,15 @@ impl Validator {
         };
         let poh_recorder = Arc::new(RwLock::new(poh_recorder));

+        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+
         let connection_cache = match use_quic {
             true => {
                 let mut connection_cache = ConnectionCache::new(tpu_connection_pool_size);
                 connection_cache
                     .update_client_certificate(&identity_keypair, node.info.gossip.ip())
                     .expect("Failed to update QUIC client certificates");
+                connection_cache.set_staked_nodes(&staked_nodes, &identity_keypair.pubkey());
                 Arc::new(connection_cache)
             }
             false => Arc::new(ConnectionCache::with_udp(tpu_connection_pool_size)),
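The validator registers its own identity pubkey here, so compute_max_parallel_chunks looks up the validator's own stake when sizing outgoing connections.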
@@ -1025,6 +1028,7 @@ impl Validator {
             &identity_keypair,
             config.runtime_config.log_messages_bytes_limit,
             config.enable_quic_servers,
+            &staked_nodes,
         );

         datapoint_info!(

@@ -159,10 +159,10 @@ fn get_connection_stake(
     })
 }

-fn compute_max_allowed_uni_streams(
+pub fn compute_max_allowed_uni_streams(
     peer_type: ConnectionPeerType,
     peer_stake: u64,
-    staked_nodes: Arc<RwLock<StakedNodes>>,
+    total_stake: u64,
 ) -> usize {
     if peer_stake == 0 {
         // Treat stake = 0 as unstaked
@@ -170,13 +170,11 @@ fn compute_max_allowed_uni_streams(
     } else {
         match peer_type {
             ConnectionPeerType::Staked => {
-                let staked_nodes = staked_nodes.read().unwrap();
-
                 // No checked math for f64 type. So let's explicitly check for 0 here
-                if staked_nodes.total_stake == 0 {
+                if total_stake == 0 {
                     QUIC_MIN_STAKED_CONCURRENT_STREAMS
                 } else {
-                    (((peer_stake as f64 / staked_nodes.total_stake as f64)
+                    (((peer_stake as f64 / total_stake as f64)
                         * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS as f64)
                         as usize)
                         .max(QUIC_MIN_STAKED_CONCURRENT_STREAMS)
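Concretely: a peer holding 1000 of 10000 total stake receives (1000 / 10000) * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS, a tenth of the staked stream budget, which is exactly what the reworked test at the bottom of this diff asserts.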
@@ -264,17 +262,19 @@ async fn setup_connection(
     if let Some((mut connection_table_l, stake)) = table_and_stake {
         let table_type = connection_table_l.peer_type;

-        let max_uni_streams = VarInt::from_u64(compute_max_allowed_uni_streams(
-            table_type,
-            stake,
-            staked_nodes.clone(),
-        ) as u64);
+        let total_stake = staked_nodes.read().map_or(0, |stakes| stakes.total_stake);
+        drop(staked_nodes);
+
+        let max_uni_streams =
+            VarInt::from_u64(
+                compute_max_allowed_uni_streams(table_type, stake, total_stake) as u64,
+            );

         debug!(
             "Peer type: {:?}, stake {}, total stake {}, max streams {}",
             table_type,
             stake,
-            staked_nodes.read().unwrap().total_stake,
+            total_stake,
             max_uni_streams.unwrap().into_inner()
         );
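Reading total_stake once via map_or (which also treats a poisoned lock as zero stake) and dropping the Arc immediately afterwards keeps the RwLock out of compute_max_allowed_uni_streams entirely; that is what allows the function to take a plain u64 and be reused from the client crate.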
@@ -558,7 +558,7 @@ impl Drop for ConnectionEntry {
 }

 #[derive(Copy, Clone, Debug)]
-enum ConnectionPeerType {
+pub enum ConnectionPeerType {
     Unstaked,
     Staked,
 }
@@ -1406,58 +1406,52 @@ pub mod test {
     #[test]
     fn test_max_allowed_uni_streams() {
-        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, 0),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 10, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 10, 0),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, 0),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, 0),
             QUIC_MIN_STAKED_CONCURRENT_STREAMS
         );
-        staked_nodes.write().unwrap().total_stake = 10000;
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1000, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1000, 10000),
             (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS / (10_f64)) as usize
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 100, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 100, 10000),
             (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS / (100_f64)) as usize
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 10, 10000),
             QUIC_MIN_STAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 1, 10000),
             QUIC_MIN_STAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Staked, 0, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(
-                ConnectionPeerType::Unstaked,
-                1000,
-                staked_nodes.clone()
-            ),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 1000, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 1, staked_nodes.clone()),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 1, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
         assert_eq!(
-            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, staked_nodes),
+            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0, 10000),
             QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
         );
     }