Implement randomized pruning of QUIC connection from staked peers (#26299)
This commit is contained in:
parent
72a968fbe8
commit
94685e1222
|
@ -6156,6 +6156,7 @@ dependencies = [
|
|||
"crossbeam-channel",
|
||||
"futures-util",
|
||||
"histogram",
|
||||
"indexmap",
|
||||
"itertools",
|
||||
"libc",
|
||||
"log",
|
||||
|
|
|
@ -42,7 +42,7 @@ impl StakedNodesUpdaterService {
|
|||
&cluster_info,
|
||||
) {
|
||||
let mut shared = shared_staked_nodes.write().unwrap();
|
||||
shared.total_stake = total_stake as f64;
|
||||
shared.total_stake = total_stake;
|
||||
shared.stake_map = new_ip_to_stake;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5446,6 +5446,7 @@ dependencies = [
|
|||
"crossbeam-channel",
|
||||
"futures-util",
|
||||
"histogram",
|
||||
"indexmap",
|
||||
"itertools",
|
||||
"libc",
|
||||
"log",
|
||||
|
|
|
@ -13,6 +13,7 @@ edition = "2021"
|
|||
crossbeam-channel = "0.5"
|
||||
futures-util = "0.3.21"
|
||||
histogram = "0.6.9"
|
||||
indexmap = "1.8.1"
|
||||
itertools = "0.10.3"
|
||||
libc = "0.2.126"
|
||||
log = "0.4.17"
|
||||
|
|
|
@ -5,11 +5,13 @@ use {
|
|||
},
|
||||
crossbeam_channel::Sender,
|
||||
futures_util::stream::StreamExt,
|
||||
indexmap::map::{Entry, IndexMap},
|
||||
percentage::Percentage,
|
||||
quinn::{
|
||||
Connecting, Connection, Endpoint, EndpointConfig, Incoming, IncomingUniStreams,
|
||||
NewConnection, VarInt,
|
||||
},
|
||||
rand::{thread_rng, Rng},
|
||||
solana_perf::packet::PacketBatch,
|
||||
solana_sdk::{
|
||||
packet::{Packet, PACKET_DATA_SIZE},
|
||||
|
@ -18,11 +20,10 @@ use {
|
|||
timing,
|
||||
},
|
||||
std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
net::{IpAddr, SocketAddr, UdpSocket},
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
Arc, Mutex, RwLock,
|
||||
Arc, Mutex, MutexGuard, RwLock,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
},
|
||||
|
@ -116,6 +117,21 @@ pub async fn run_server(
|
|||
}
|
||||
}
|
||||
|
||||
fn prune_unstaked_connection_table(
|
||||
unstaked_connection_table: &mut MutexGuard<ConnectionTable>,
|
||||
max_unstaked_connections: usize,
|
||||
stats: Arc<StreamStats>,
|
||||
) {
|
||||
if unstaked_connection_table.total_size >= max_unstaked_connections {
|
||||
const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
|
||||
let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
|
||||
|
||||
let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
|
||||
let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
|
||||
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup_connection(
|
||||
connection: Connecting,
|
||||
unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
|
||||
|
@ -138,70 +154,102 @@ async fn setup_connection(
|
|||
|
||||
let remote_addr = connection.remote_address();
|
||||
|
||||
let (mut connection_table_l, stake) = {
|
||||
const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
|
||||
let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
|
||||
|
||||
let table_and_stake = {
|
||||
let staked_nodes = staked_nodes.read().unwrap();
|
||||
if let Some(stake) = staked_nodes.stake_map.get(&remote_addr.ip()) {
|
||||
let stake = *stake;
|
||||
let total_stake = staked_nodes.total_stake;
|
||||
drop(staked_nodes);
|
||||
|
||||
let mut connection_table_l = staked_connection_table.lock().unwrap();
|
||||
if connection_table_l.total_size >= max_staked_connections {
|
||||
let max_connections = max_percentage_full.apply_to(max_staked_connections);
|
||||
let num_pruned = connection_table_l.prune_oldest(max_connections);
|
||||
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
|
||||
let num_pruned = connection_table_l.prune_random(stake);
|
||||
if num_pruned == 0 {
|
||||
if max_unstaked_connections > 0 {
|
||||
// If we couldn't prune a connection in the staked connection table, let's
|
||||
// put this connection in the unstaked connection table. If needed, prune a
|
||||
// connection from the unstaked connection table.
|
||||
connection_table_l = unstaked_connection_table.lock().unwrap();
|
||||
prune_unstaked_connection_table(
|
||||
&mut connection_table_l,
|
||||
max_unstaked_connections,
|
||||
stats.clone(),
|
||||
);
|
||||
Some((connection_table_l, stake))
|
||||
} else {
|
||||
stats
|
||||
.connection_add_failed_on_pruning
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
None
|
||||
}
|
||||
} else {
|
||||
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
|
||||
Some((connection_table_l, stake))
|
||||
}
|
||||
} else {
|
||||
Some((connection_table_l, stake))
|
||||
}
|
||||
connection.set_max_concurrent_uni_streams(
|
||||
VarInt::from_u64(
|
||||
((stake as f64 / total_stake as f64) * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS)
|
||||
as u64,
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
(connection_table_l, stake)
|
||||
} else {
|
||||
} else if max_unstaked_connections > 0 {
|
||||
drop(staked_nodes);
|
||||
let mut connection_table_l = unstaked_connection_table.lock().unwrap();
|
||||
if connection_table_l.total_size >= max_unstaked_connections {
|
||||
let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
|
||||
let num_pruned = connection_table_l.prune_oldest(max_connections);
|
||||
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
|
||||
}
|
||||
connection.set_max_concurrent_uni_streams(
|
||||
VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64).unwrap(),
|
||||
prune_unstaked_connection_table(
|
||||
&mut connection_table_l,
|
||||
max_unstaked_connections,
|
||||
stats.clone(),
|
||||
);
|
||||
(connection_table_l, 0)
|
||||
Some((connection_table_l, 0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if stake != 0 || max_unstaked_connections > 0 {
|
||||
if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection(
|
||||
&remote_addr,
|
||||
Some(connection),
|
||||
timing::timestamp(),
|
||||
max_connections_per_ip,
|
||||
) {
|
||||
let table_type = connection_table_l.peer_type;
|
||||
drop(connection_table_l);
|
||||
let stats = stats.clone();
|
||||
let connection_table = match table_type {
|
||||
ConnectionPeerType::Unstaked => unstaked_connection_table.clone(),
|
||||
ConnectionPeerType::Staked => staked_connection_table.clone(),
|
||||
};
|
||||
tokio::spawn(handle_connection(
|
||||
uni_streams,
|
||||
packet_sender,
|
||||
remote_addr,
|
||||
last_update,
|
||||
connection_table,
|
||||
stream_exit,
|
||||
stats,
|
||||
if let Some((mut connection_table_l, stake)) = table_and_stake {
|
||||
let table_type = connection_table_l.peer_type;
|
||||
let max_uni_streams = match table_type {
|
||||
ConnectionPeerType::Unstaked => {
|
||||
VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64)
|
||||
}
|
||||
ConnectionPeerType::Staked => {
|
||||
let staked_nodes = staked_nodes.read().unwrap();
|
||||
VarInt::from_u64(
|
||||
((stake as f64 / staked_nodes.total_stake as f64)
|
||||
* QUIC_TOTAL_STAKED_CONCURRENT_STREAMS) as u64,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
if let Ok(max_uni_streams) = max_uni_streams {
|
||||
connection.set_max_concurrent_uni_streams(max_uni_streams);
|
||||
|
||||
if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection(
|
||||
&remote_addr,
|
||||
Some(connection),
|
||||
stake,
|
||||
));
|
||||
timing::timestamp(),
|
||||
max_connections_per_ip,
|
||||
) {
|
||||
drop(connection_table_l);
|
||||
let stats = stats.clone();
|
||||
let connection_table = match table_type {
|
||||
ConnectionPeerType::Unstaked => unstaked_connection_table.clone(),
|
||||
ConnectionPeerType::Staked => staked_connection_table.clone(),
|
||||
};
|
||||
tokio::spawn(handle_connection(
|
||||
uni_streams,
|
||||
packet_sender,
|
||||
remote_addr,
|
||||
last_update,
|
||||
connection_table,
|
||||
stream_exit,
|
||||
stats,
|
||||
stake,
|
||||
));
|
||||
} else {
|
||||
stats.connection_add_failed.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
} else {
|
||||
stats.connection_add_failed.fetch_add(1, Ordering::Relaxed);
|
||||
stats
|
||||
.connection_add_failed_invalid_stream_count
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
} else {
|
||||
connection.close(0u32.into(), &[0u8]);
|
||||
|
@ -387,6 +435,7 @@ fn handle_chunk(
|
|||
#[derive(Debug)]
|
||||
struct ConnectionEntry {
|
||||
exit: Arc<AtomicBool>,
|
||||
stake: u64,
|
||||
last_update: Arc<AtomicU64>,
|
||||
port: u16,
|
||||
connection: Option<Connection>,
|
||||
|
@ -395,12 +444,14 @@ struct ConnectionEntry {
|
|||
impl ConnectionEntry {
|
||||
fn new(
|
||||
exit: Arc<AtomicBool>,
|
||||
stake: u64,
|
||||
last_update: Arc<AtomicU64>,
|
||||
port: u16,
|
||||
connection: Option<Connection>,
|
||||
) -> Self {
|
||||
Self {
|
||||
exit,
|
||||
stake,
|
||||
last_update,
|
||||
port,
|
||||
connection,
|
||||
|
@ -429,7 +480,7 @@ enum ConnectionPeerType {
|
|||
|
||||
// Map of IP to list of connection entries
|
||||
struct ConnectionTable {
|
||||
table: HashMap<IpAddr, Vec<ConnectionEntry>>,
|
||||
table: IndexMap<IpAddr, Vec<ConnectionEntry>>,
|
||||
total_size: usize,
|
||||
peer_type: ConnectionPeerType,
|
||||
}
|
||||
|
@ -439,7 +490,7 @@ struct ConnectionTable {
|
|||
impl ConnectionTable {
|
||||
fn new(peer_type: ConnectionPeerType) -> Self {
|
||||
Self {
|
||||
table: HashMap::default(),
|
||||
table: IndexMap::default(),
|
||||
total_size: 0,
|
||||
peer_type,
|
||||
}
|
||||
|
@ -473,10 +524,47 @@ impl ConnectionTable {
|
|||
num_pruned
|
||||
}
|
||||
|
||||
fn connection_stake(&self, index: usize) -> Option<u64> {
|
||||
self.table
|
||||
.get_index(index)
|
||||
.and_then(|(_, connection_vec)| connection_vec.first())
|
||||
.map(|connection| connection.stake)
|
||||
}
|
||||
|
||||
// Randomly select two connections, and evict the one with lower stake. If the stakes of both
|
||||
// the connections are higher than the threshold_stake, reject the pruning attempt, and return 0.
|
||||
fn prune_random(&mut self, threshold_stake: u64) -> usize {
|
||||
let mut num_pruned = 0;
|
||||
let mut rng = thread_rng();
|
||||
// The candidate1 and candidate2 could potentially be the same. If so, the stake of the candidate
|
||||
// will be compared just against the threshold_stake.
|
||||
let candidate1 = rng.gen_range(0, self.table.len());
|
||||
let candidate2 = rng.gen_range(0, self.table.len());
|
||||
|
||||
let candidate1_stake = self.connection_stake(candidate1).unwrap_or(0);
|
||||
let candidate2_stake = self.connection_stake(candidate2).unwrap_or(0);
|
||||
|
||||
if candidate1_stake < threshold_stake || candidate2_stake < threshold_stake {
|
||||
let removed = if candidate1_stake < candidate2_stake {
|
||||
self.table.swap_remove_index(candidate1)
|
||||
} else {
|
||||
self.table.swap_remove_index(candidate2)
|
||||
};
|
||||
|
||||
if let Some((_, removed_value)) = removed {
|
||||
self.total_size -= removed_value.len();
|
||||
num_pruned += removed_value.len();
|
||||
}
|
||||
}
|
||||
|
||||
num_pruned
|
||||
}
|
||||
|
||||
fn try_add_connection(
|
||||
&mut self,
|
||||
addr: &SocketAddr,
|
||||
connection: Option<Connection>,
|
||||
stake: u64,
|
||||
last_update: u64,
|
||||
max_connections_per_ip: usize,
|
||||
) -> Option<(Arc<AtomicU64>, Arc<AtomicBool>)> {
|
||||
|
@ -491,6 +579,7 @@ impl ConnectionTable {
|
|||
let last_update = Arc::new(AtomicU64::new(last_update));
|
||||
connection_entry.push(ConnectionEntry::new(
|
||||
exit.clone(),
|
||||
stake,
|
||||
last_update.clone(),
|
||||
addr.port(),
|
||||
connection,
|
||||
|
@ -818,7 +907,7 @@ pub mod test {
|
|||
staked_nodes
|
||||
.stake_map
|
||||
.insert(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 100000);
|
||||
staked_nodes.total_stake = 100000_f64;
|
||||
staked_nodes.total_stake = 100000;
|
||||
|
||||
let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes));
|
||||
check_multiple_writes(receiver, server_address).await;
|
||||
|
@ -918,12 +1007,12 @@ pub mod test {
|
|||
.collect();
|
||||
for (i, socket) in sockets.iter().enumerate() {
|
||||
table
|
||||
.try_add_connection(socket, None, i as u64, max_connections_per_ip)
|
||||
.try_add_connection(socket, None, 0, i as u64, max_connections_per_ip)
|
||||
.unwrap();
|
||||
}
|
||||
num_entries += 1;
|
||||
table
|
||||
.try_add_connection(&sockets[0], None, 5, max_connections_per_ip)
|
||||
.try_add_connection(&sockets[0], None, 0, 5, max_connections_per_ip)
|
||||
.unwrap();
|
||||
|
||||
let new_size = 3;
|
||||
|
@ -942,6 +1031,40 @@ pub mod test {
|
|||
assert_eq!(table.total_size, 0);
|
||||
}
|
||||
|
||||
// Checks `prune_random` against a staked table holding one connection per IP, with stakes
// 1 through `num_entries`.
#[test]
fn test_prune_table_random() {
    use std::net::Ipv4Addr;
    solana_logger::setup();

    let num_entries = 5;
    let max_connections_per_ip = 10;
    let mut table = ConnectionTable::new(ConnectionPeerType::Staked);

    // Entry `i` lives at IP `i.0.0.0`, carries stake `i + 1` and last-update value `i`.
    for i in 0..num_entries {
        let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0);
        table
            .try_add_connection(
                &socket,
                None,
                u64::from(i) + 1,
                u64::from(i),
                max_connections_per_ip,
            )
            .unwrap();
    }

    // A threshold stake below every entry in the table must not prune anything
    // (i.e. zero pruned entries are reported).
    let pruned_low_threshold = table.prune_random(0);
    assert_eq!(pruned_low_threshold, 0);

    // A threshold stake above every entry in the table must prune exactly one entry.
    let pruned_high_threshold = table.prune_random(num_entries as u64 + 1);
    assert_eq!(pruned_high_threshold, 1);
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_connections() {
|
||||
use std::net::Ipv4Addr;
|
||||
|
@ -955,11 +1078,11 @@ pub mod test {
|
|||
.collect();
|
||||
for (i, socket) in sockets.iter().enumerate() {
|
||||
table
|
||||
.try_add_connection(socket, None, (i * 2) as u64, max_connections_per_ip)
|
||||
.try_add_connection(socket, None, 0, (i * 2) as u64, max_connections_per_ip)
|
||||
.unwrap();
|
||||
|
||||
table
|
||||
.try_add_connection(socket, None, (i * 2 + 1) as u64, max_connections_per_ip)
|
||||
.try_add_connection(socket, None, 0, (i * 2 + 1) as u64, max_connections_per_ip)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
@ -969,6 +1092,7 @@ pub mod test {
|
|||
.try_add_connection(
|
||||
&single_connection_addr,
|
||||
None,
|
||||
0,
|
||||
(num_ips * 2) as u64,
|
||||
max_connections_per_ip,
|
||||
)
|
||||
|
|
|
@ -158,7 +158,9 @@ pub struct StreamStats {
|
|||
pub(crate) total_stream_read_timeouts: AtomicUsize,
|
||||
pub(crate) num_evictions: AtomicUsize,
|
||||
pub(crate) connection_add_failed: AtomicUsize,
|
||||
pub(crate) connection_add_failed_invalid_stream_count: AtomicUsize,
|
||||
pub(crate) connection_add_failed_unstaked_node: AtomicUsize,
|
||||
pub(crate) connection_add_failed_on_pruning: AtomicUsize,
|
||||
pub(crate) connection_setup_timeout: AtomicUsize,
|
||||
pub(crate) connection_removed: AtomicUsize,
|
||||
pub(crate) connection_remove_failed: AtomicUsize,
|
||||
|
@ -198,12 +200,24 @@ impl StreamStats {
|
|||
self.connection_add_failed.swap(0, Ordering::Relaxed),
|
||||
i64
|
||||
),
|
||||
(
|
||||
"connection_add_failed_invalid_stream_count",
|
||||
self.connection_add_failed_invalid_stream_count
|
||||
.swap(0, Ordering::Relaxed),
|
||||
i64
|
||||
),
|
||||
(
|
||||
"connection_add_failed_unstaked_node",
|
||||
self.connection_add_failed_unstaked_node
|
||||
.swap(0, Ordering::Relaxed),
|
||||
i64
|
||||
),
|
||||
(
|
||||
"connection_add_failed_on_pruning",
|
||||
self.connection_add_failed_on_pruning
|
||||
.swap(0, Ordering::Relaxed),
|
||||
i64
|
||||
),
|
||||
(
|
||||
"connection_removed",
|
||||
self.connection_removed.swap(0, Ordering::Relaxed),
|
||||
|
|
|
@ -27,7 +27,7 @@ use {
|
|||
// Total stake and nodes => stake map
|
||||
#[derive(Default)]
|
||||
pub struct StakedNodes {
|
||||
pub total_stake: f64,
|
||||
pub total_stake: u64,
|
||||
pub stake_map: HashMap<IpAddr, u64>,
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue