Implement randomized pruning of QUIC connections from staked peers (#26299)

Pankaj Garg 2022-06-30 17:56:15 -07:00 committed by GitHub
parent 72a968fbe8
commit 94685e1222
7 changed files with 201 additions and 60 deletions

Cargo.lock

@@ -6156,6 +6156,7 @@ dependencies = [
"crossbeam-channel",
"futures-util",
"histogram",
"indexmap",
"itertools",
"libc",
"log",


@@ -42,7 +42,7 @@ impl StakedNodesUpdaterService {
&cluster_info,
) {
let mut shared = shared_staked_nodes.write().unwrap();
shared.total_stake = total_stake as f64;
shared.total_stake = total_stake;
shared.stake_map = new_ip_to_stake;
}
}


@@ -5446,6 +5446,7 @@ dependencies = [
"crossbeam-channel",
"futures-util",
"histogram",
"indexmap",
"itertools",
"libc",
"log",


@@ -13,6 +13,7 @@ edition = "2021"
crossbeam-channel = "0.5"
futures-util = "0.3.21"
histogram = "0.6.9"
indexmap = "1.8.1"
itertools = "0.10.3"
libc = "0.2.126"
log = "0.4.17"


@@ -5,11 +5,13 @@ use {
},
crossbeam_channel::Sender,
futures_util::stream::StreamExt,
indexmap::map::{Entry, IndexMap},
percentage::Percentage,
quinn::{
Connecting, Connection, Endpoint, EndpointConfig, Incoming, IncomingUniStreams,
NewConnection, VarInt,
},
rand::{thread_rng, Rng},
solana_perf::packet::PacketBatch,
solana_sdk::{
packet::{Packet, PACKET_DATA_SIZE},
@@ -18,11 +20,10 @@ use {
timing,
},
std::{
collections::{hash_map::Entry, HashMap},
net::{IpAddr, SocketAddr, UdpSocket},
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex, RwLock,
Arc, Mutex, MutexGuard, RwLock,
},
time::{Duration, Instant},
},
@@ -116,6 +117,21 @@ pub async fn run_server(
}
}
fn prune_unstaked_connection_table(
unstaked_connection_table: &mut MutexGuard<ConnectionTable>,
max_unstaked_connections: usize,
stats: Arc<StreamStats>,
) {
if unstaked_connection_table.total_size >= max_unstaked_connections {
const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
}
}
async fn setup_connection(
connection: Connecting,
unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
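The new `prune_unstaked_connection_table` helper above trims the unstaked table back down to a fixed fraction of its cap before admitting a new connection. A minimal standalone sketch of that 90% target computation, using the same `percentage` crate as the diff (the connection cap below is an illustrative value, not the server's real default):

```rust
use percentage::Percentage;

fn main() {
    const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
    // Illustrative cap; the real limit is supplied by server configuration.
    let max_unstaked_connections = 500usize;
    let target = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE).apply_to(max_unstaked_connections);
    // prune_oldest then evicts least-recently-updated peers until the table
    // holds at most 90% of the cap, leaving headroom for new connections.
    assert_eq!(target, 450);
}
```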
@@ -138,70 +154,102 @@ async fn setup_connection(
let remote_addr = connection.remote_address();
let (mut connection_table_l, stake) = {
const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
let table_and_stake = {
let staked_nodes = staked_nodes.read().unwrap();
if let Some(stake) = staked_nodes.stake_map.get(&remote_addr.ip()) {
let stake = *stake;
let total_stake = staked_nodes.total_stake;
drop(staked_nodes);
let mut connection_table_l = staked_connection_table.lock().unwrap();
if connection_table_l.total_size >= max_staked_connections {
let max_connections = max_percentage_full.apply_to(max_staked_connections);
let num_pruned = connection_table_l.prune_oldest(max_connections);
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
let num_pruned = connection_table_l.prune_random(stake);
if num_pruned == 0 {
if max_unstaked_connections > 0 {
// If we couldn't prune a connection in the staked connection table, let's
// put this connection in the unstaked connection table. If needed, prune a
// connection from the unstaked connection table.
connection_table_l = unstaked_connection_table.lock().unwrap();
prune_unstaked_connection_table(
&mut connection_table_l,
max_unstaked_connections,
stats.clone(),
);
Some((connection_table_l, stake))
} else {
stats
.connection_add_failed_on_pruning
.fetch_add(1, Ordering::Relaxed);
None
}
} else {
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
Some((connection_table_l, stake))
}
} else {
Some((connection_table_l, stake))
}
connection.set_max_concurrent_uni_streams(
VarInt::from_u64(
((stake as f64 / total_stake as f64) * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS)
as u64,
)
.unwrap(),
);
(connection_table_l, stake)
} else {
} else if max_unstaked_connections > 0 {
drop(staked_nodes);
let mut connection_table_l = unstaked_connection_table.lock().unwrap();
if connection_table_l.total_size >= max_unstaked_connections {
let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
let num_pruned = connection_table_l.prune_oldest(max_connections);
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
}
connection.set_max_concurrent_uni_streams(
VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64).unwrap(),
prune_unstaked_connection_table(
&mut connection_table_l,
max_unstaked_connections,
stats.clone(),
);
(connection_table_l, 0)
Some((connection_table_l, 0))
} else {
None
}
};
if stake != 0 || max_unstaked_connections > 0 {
if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection(
&remote_addr,
Some(connection),
timing::timestamp(),
max_connections_per_ip,
) {
let table_type = connection_table_l.peer_type;
drop(connection_table_l);
let stats = stats.clone();
let connection_table = match table_type {
ConnectionPeerType::Unstaked => unstaked_connection_table.clone(),
ConnectionPeerType::Staked => staked_connection_table.clone(),
};
tokio::spawn(handle_connection(
uni_streams,
packet_sender,
remote_addr,
last_update,
connection_table,
stream_exit,
stats,
if let Some((mut connection_table_l, stake)) = table_and_stake {
let table_type = connection_table_l.peer_type;
let max_uni_streams = match table_type {
ConnectionPeerType::Unstaked => {
VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64)
}
ConnectionPeerType::Staked => {
let staked_nodes = staked_nodes.read().unwrap();
VarInt::from_u64(
((stake as f64 / staked_nodes.total_stake as f64)
* QUIC_TOTAL_STAKED_CONCURRENT_STREAMS) as u64,
)
}
};
if let Ok(max_uni_streams) = max_uni_streams {
connection.set_max_concurrent_uni_streams(max_uni_streams);
if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection(
&remote_addr,
Some(connection),
stake,
));
timing::timestamp(),
max_connections_per_ip,
) {
drop(connection_table_l);
let stats = stats.clone();
let connection_table = match table_type {
ConnectionPeerType::Unstaked => unstaked_connection_table.clone(),
ConnectionPeerType::Staked => staked_connection_table.clone(),
};
tokio::spawn(handle_connection(
uni_streams,
packet_sender,
remote_addr,
last_update,
connection_table,
stream_exit,
stats,
stake,
));
} else {
stats.connection_add_failed.fetch_add(1, Ordering::Relaxed);
}
} else {
stats.connection_add_failed.fetch_add(1, Ordering::Relaxed);
stats
.connection_add_failed_invalid_stream_count
.fetch_add(1, Ordering::Relaxed);
}
} else {
connection.close(0u32.into(), &[0u8]);
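After the table is chosen, staked peers are granted uni-stream capacity in proportion to their share of total stake, while unstaked peers get a small fixed allowance. A worked sketch of the proportional formula in the `Staked` arm above; the budget constants here are illustrative assumptions, not the values of `QUIC_TOTAL_STAKED_CONCURRENT_STREAMS` and `QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS` defined elsewhere in the codebase:

```rust
// Assumed illustrative budgets, standing in for the real constants.
const TOTAL_STAKED_STREAMS: f64 = 100_000.0;
const UNSTAKED_STREAMS: u64 = 128;

fn max_uni_streams(stake: u64, total_stake: u64) -> u64 {
    if stake == 0 {
        UNSTAKED_STREAMS
    } else {
        // Same proportional formula as the Staked arm of the match above.
        ((stake as f64 / total_stake as f64) * TOTAL_STAKED_STREAMS) as u64
    }
}

fn main() {
    // A peer holding 1% of the total stake receives 1% of the staked budget.
    assert_eq!(max_uni_streams(1_000, 100_000), 1_000);
    assert_eq!(max_uni_streams(0, 100_000), 128);
}
```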
@@ -387,6 +435,7 @@ fn handle_chunk(
#[derive(Debug)]
struct ConnectionEntry {
exit: Arc<AtomicBool>,
stake: u64,
last_update: Arc<AtomicU64>,
port: u16,
connection: Option<Connection>,
@@ -395,12 +444,14 @@ struct ConnectionEntry {
impl ConnectionEntry {
fn new(
exit: Arc<AtomicBool>,
stake: u64,
last_update: Arc<AtomicU64>,
port: u16,
connection: Option<Connection>,
) -> Self {
Self {
exit,
stake,
last_update,
port,
connection,
@@ -429,7 +480,7 @@ enum ConnectionPeerType {
// Map of IP to list of connection entries
struct ConnectionTable {
table: HashMap<IpAddr, Vec<ConnectionEntry>>,
table: IndexMap<IpAddr, Vec<ConnectionEntry>>,
total_size: usize,
peer_type: ConnectionPeerType,
}
@@ -439,7 +490,7 @@ struct ConnectionTable {
impl ConnectionTable {
fn new(peer_type: ConnectionPeerType) -> Self {
Self {
table: HashMap::default(),
table: IndexMap::default(),
total_size: 0,
peer_type,
}
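The switch from `HashMap` to `IndexMap` is what makes randomized pruning cheap: `IndexMap` supports O(1) positional lookup (`get_index`) and O(1) positional removal (`swap_remove_index`), which `connection_stake` and `prune_random` below rely on. A small self-contained illustration of those two indexmap 1.x methods:

```rust
use indexmap::IndexMap;

fn main() {
    let mut table: IndexMap<&str, u64> = IndexMap::new();
    table.insert("a", 1);
    table.insert("b", 2);
    table.insert("c", 3);

    // Positional lookup by insertion order, as connection_stake() does.
    assert_eq!(table.get_index(1), Some((&"b", &2)));

    // swap_remove_index removes in O(1) by swapping the last entry into the
    // vacated slot, at the cost of perturbing the remaining order.
    assert_eq!(table.swap_remove_index(0), Some(("a", 1)));
    assert_eq!(table.get_index(0), Some((&"c", &3)));
}
```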
@@ -473,10 +524,47 @@ impl ConnectionTable {
num_pruned
}
fn connection_stake(&self, index: usize) -> Option<u64> {
self.table
.get_index(index)
.and_then(|(_, connection_vec)| connection_vec.first())
.map(|connection| connection.stake)
}
// Randomly select two connections and evict the one with the lower stake. If both
// candidates' stakes are at or above the threshold_stake, reject the pruning attempt and return 0.
fn prune_random(&mut self, threshold_stake: u64) -> usize {
let mut num_pruned = 0;
let mut rng = thread_rng();
// candidate1 and candidate2 could potentially be the same index. If so, that candidate's
// stake is compared only against the threshold_stake.
let candidate1 = rng.gen_range(0, self.table.len());
let candidate2 = rng.gen_range(0, self.table.len());
let candidate1_stake = self.connection_stake(candidate1).unwrap_or(0);
let candidate2_stake = self.connection_stake(candidate2).unwrap_or(0);
if candidate1_stake < threshold_stake || candidate2_stake < threshold_stake {
let removed = if candidate1_stake < candidate2_stake {
self.table.swap_remove_index(candidate1)
} else {
self.table.swap_remove_index(candidate2)
};
if let Some((_, removed_value)) = removed {
self.total_size -= removed_value.len();
num_pruned += removed_value.len();
}
}
num_pruned
}
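`prune_random` is a power-of-two-choices eviction: sampling two random entries and evicting the lower-staked one biases eviction toward low-stake peers without scanning the table, and the threshold check guarantees a new peer can only displace an entry staked below its own stake. A simplified, self-contained model of the same logic, where a plain `Vec` of stakes stands in for the connection table (note the sketch uses rand 0.8's range syntax, whereas the diff uses rand's older two-argument `gen_range`):

```rust
use rand::{thread_rng, Rng};

// stakes[i] stands in for the stake of the i-th connection table entry.
fn prune_random(stakes: &mut Vec<u64>, threshold_stake: u64) -> usize {
    let mut rng = thread_rng();
    let candidate1 = rng.gen_range(0..stakes.len());
    let candidate2 = rng.gen_range(0..stakes.len());
    // Evict only if at least one candidate is staked below the new peer.
    if stakes[candidate1] < threshold_stake || stakes[candidate2] < threshold_stake {
        let victim = if stakes[candidate1] < stakes[candidate2] {
            candidate1
        } else {
            candidate2
        };
        // O(1) removal, mirroring IndexMap::swap_remove_index above.
        stakes.swap_remove(victim);
        1
    } else {
        0
    }
}

fn main() {
    let mut stakes = vec![10, 20, 30];
    // Threshold below every entry: pruning is always rejected.
    assert_eq!(prune_random(&mut stakes, 5), 0);
    // Threshold above every entry: one victim is always evicted.
    assert_eq!(prune_random(&mut stakes, 100), 1);
    assert_eq!(stakes.len(), 2);
}
```

Because the evicted candidate is always the lower-staked of the two samples, the victim's stake is below the threshold whenever the guard passes, so an incoming staked peer never evicts a peer with more stake than itself.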
fn try_add_connection(
&mut self,
addr: &SocketAddr,
connection: Option<Connection>,
stake: u64,
last_update: u64,
max_connections_per_ip: usize,
) -> Option<(Arc<AtomicU64>, Arc<AtomicBool>)> {
@@ -491,6 +579,7 @@ impl ConnectionTable {
let last_update = Arc::new(AtomicU64::new(last_update));
connection_entry.push(ConnectionEntry::new(
exit.clone(),
stake,
last_update.clone(),
addr.port(),
connection,
@@ -818,7 +907,7 @@ pub mod test {
staked_nodes
.stake_map
.insert(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 100000);
staked_nodes.total_stake = 100000_f64;
staked_nodes.total_stake = 100000;
let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes));
check_multiple_writes(receiver, server_address).await;
@@ -918,12 +1007,12 @@ pub mod test {
.collect();
for (i, socket) in sockets.iter().enumerate() {
table
.try_add_connection(socket, None, i as u64, max_connections_per_ip)
.try_add_connection(socket, None, 0, i as u64, max_connections_per_ip)
.unwrap();
}
num_entries += 1;
table
.try_add_connection(&sockets[0], None, 5, max_connections_per_ip)
.try_add_connection(&sockets[0], None, 0, 5, max_connections_per_ip)
.unwrap();
let new_size = 3;
@@ -942,6 +1031,40 @@ pub mod test {
assert_eq!(table.total_size, 0);
}
#[test]
fn test_prune_table_random() {
use std::net::Ipv4Addr;
solana_logger::setup();
let mut table = ConnectionTable::new(ConnectionPeerType::Staked);
let num_entries = 5;
let max_connections_per_ip = 10;
let sockets: Vec<_> = (0..num_entries)
.into_iter()
.map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
.collect();
for (i, socket) in sockets.iter().enumerate() {
table
.try_add_connection(
socket,
None,
(i + 1) as u64,
i as u64,
max_connections_per_ip,
)
.unwrap();
}
// Try pruning with a threshold stake lower than all the entries in the table.
// It should fail to prune (i.e. return 0 pruned entries).
let pruned = table.prune_random(0);
assert_eq!(pruned, 0);
// Try pruning with a threshold stake higher than all the entries in the table.
// It should succeed in pruning (i.e. return 1 pruned entry).
let pruned = table.prune_random(num_entries as u64 + 1);
assert_eq!(pruned, 1);
}
#[test]
fn test_remove_connections() {
use std::net::Ipv4Addr;
@@ -955,11 +1078,11 @@ pub mod test {
.collect();
for (i, socket) in sockets.iter().enumerate() {
table
.try_add_connection(socket, None, (i * 2) as u64, max_connections_per_ip)
.try_add_connection(socket, None, 0, (i * 2) as u64, max_connections_per_ip)
.unwrap();
table
.try_add_connection(socket, None, (i * 2 + 1) as u64, max_connections_per_ip)
.try_add_connection(socket, None, 0, (i * 2 + 1) as u64, max_connections_per_ip)
.unwrap();
}
@@ -969,6 +1092,7 @@ pub mod test {
.try_add_connection(
&single_connection_addr,
None,
0,
(num_ips * 2) as u64,
max_connections_per_ip,
)


@@ -158,7 +158,9 @@ pub struct StreamStats {
pub(crate) total_stream_read_timeouts: AtomicUsize,
pub(crate) num_evictions: AtomicUsize,
pub(crate) connection_add_failed: AtomicUsize,
pub(crate) connection_add_failed_invalid_stream_count: AtomicUsize,
pub(crate) connection_add_failed_unstaked_node: AtomicUsize,
pub(crate) connection_add_failed_on_pruning: AtomicUsize,
pub(crate) connection_setup_timeout: AtomicUsize,
pub(crate) connection_removed: AtomicUsize,
pub(crate) connection_remove_failed: AtomicUsize,
@@ -198,12 +200,24 @@ impl StreamStats {
self.connection_add_failed.swap(0, Ordering::Relaxed),
i64
),
(
"connection_add_failed_invalid_stream_count",
self.connection_add_failed_invalid_stream_count
.swap(0, Ordering::Relaxed),
i64
),
(
"connection_add_failed_unstaked_node",
self.connection_add_failed_unstaked_node
.swap(0, Ordering::Relaxed),
i64
),
(
"connection_add_failed_on_pruning",
self.connection_add_failed_on_pruning
.swap(0, Ordering::Relaxed),
i64
),
(
"connection_removed",
self.connection_removed.swap(0, Ordering::Relaxed),
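The counters above follow a swap-and-reset reporting pattern: each report atomically swaps the counter to zero and emits the previous value, so every eviction or failure is counted in exactly one reporting interval. A minimal sketch of that pattern, independent of the solana-metrics datapoint macros:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

struct Stats {
    num_evictions: AtomicUsize,
}

impl Stats {
    fn report(&self) {
        // swap(0) returns the value accumulated since the last report and
        // resets the counter in one atomic step, so no increment is lost
        // or double-counted across reporting intervals.
        let evictions = self.num_evictions.swap(0, Ordering::Relaxed);
        println!("num_evictions={evictions}");
    }
}

fn main() {
    let stats = Stats { num_evictions: AtomicUsize::new(0) };
    stats.num_evictions.fetch_add(3, Ordering::Relaxed);
    stats.report(); // prints num_evictions=3
    stats.report(); // prints num_evictions=0
}
```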


@@ -27,7 +27,7 @@ use {
// Total stake and nodes => stake map
#[derive(Default)]
pub struct StakedNodes {
pub total_stake: f64,
pub total_stake: u64,
pub stake_map: HashMap<IpAddr, u64>,
}
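The commit also changes `total_stake` from f64 to u64 throughout, so the total stays an exact integer sum of lamports and floating-point conversion happens only at the stream-allocation division. A minimal sketch of the struct after the change, with a hypothetical helper (not part of the diff) showing how a u64 total composes:

```rust
use std::{collections::HashMap, net::IpAddr};

#[derive(Default)]
pub struct StakedNodes {
    pub total_stake: u64,
    pub stake_map: HashMap<IpAddr, u64>,
}

impl StakedNodes {
    // Hypothetical helper for illustration only: the u64 total is an exact
    // sum of the per-node stakes, with no floating-point rounding.
    pub fn recompute_total(&mut self) {
        self.total_stake = self.stake_map.values().sum();
    }
}

fn main() {
    let mut nodes = StakedNodes::default();
    nodes.stake_map.insert("127.0.0.1".parse().unwrap(), 100_000);
    nodes.recompute_total();
    assert_eq!(nodes.total_stake, 100_000);
}
```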