Implement randomized pruning of QUIC connections from staked peers (#26299)

Pankaj Garg 2022-06-30 17:56:15 -07:00 committed by GitHub
parent 72a968fbe8
commit 94685e1222
7 changed files with 201 additions and 60 deletions
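In short: both connection tables used to be pruned oldest-first whenever they filled. After this change, a full staked table instead samples two random entries and evicts the lower-staked of the pair, but only if that stake is below the incoming peer's stake; if neither candidate can be evicted, the new connection falls back to the unstaked table (or is rejected). Below is a minimal, self-contained sketch of that policy, with hypothetical names and a plain `Vec` standing in for the `ConnectionTable` defined in the diff:

```rust
use rand::{thread_rng, Rng};

// Hypothetical stand-in for ConnectionTable::prune_random in the diff below:
// sample two slots, evict the lower-staked one, but only if its stake is
// below the new peer's stake (the threshold). Returns 0 or 1 entries pruned.
fn prune_random(stakes: &mut Vec<u64>, threshold_stake: u64) -> usize {
    let mut rng = thread_rng();
    // rand 0.7-style two-argument gen_range, as used in the diff;
    // the two candidates may coincide.
    let candidate1: usize = rng.gen_range(0, stakes.len());
    let candidate2: usize = rng.gen_range(0, stakes.len());
    if stakes[candidate1].min(stakes[candidate2]) >= threshold_stake {
        return 0; // both candidates outrank the new peer: reject the prune
    }
    let victim = if stakes[candidate1] < stakes[candidate2] {
        candidate1
    } else {
        candidate2
    };
    stakes.swap_remove(victim); // O(1) removal; order is not preserved
    1
}

fn main() {
    let mut stakes = vec![10, 50, 5, 100];
    let pruned = prune_random(&mut stakes, 75);
    println!("pruned {pruned} entries; {} remain", stakes.len());
}
```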

Cargo.lock (generated)

@@ -6156,6 +6156,7 @@ dependencies = [
  "crossbeam-channel",
  "futures-util",
  "histogram",
+ "indexmap",
  "itertools",
  "libc",
  "log",


@@ -42,7 +42,7 @@ impl StakedNodesUpdaterService {
                 &cluster_info,
             ) {
                 let mut shared = shared_staked_nodes.write().unwrap();
-                shared.total_stake = total_stake as f64;
+                shared.total_stake = total_stake;
                 shared.stake_map = new_ip_to_stake;
             }
         }


@@ -5446,6 +5446,7 @@ dependencies = [
  "crossbeam-channel",
  "futures-util",
  "histogram",
+ "indexmap",
  "itertools",
  "libc",
  "log",


@@ -13,6 +13,7 @@ edition = "2021"
 crossbeam-channel = "0.5"
 futures-util = "0.3.21"
 histogram = "0.6.9"
+indexmap = "1.8.1"
 itertools = "0.10.3"
 libc = "0.2.126"
 log = "0.4.17"


@@ -5,11 +5,13 @@ use {
     },
     crossbeam_channel::Sender,
     futures_util::stream::StreamExt,
+    indexmap::map::{Entry, IndexMap},
     percentage::Percentage,
     quinn::{
         Connecting, Connection, Endpoint, EndpointConfig, Incoming, IncomingUniStreams,
         NewConnection, VarInt,
     },
+    rand::{thread_rng, Rng},
     solana_perf::packet::PacketBatch,
     solana_sdk::{
         packet::{Packet, PACKET_DATA_SIZE},
@@ -18,11 +20,10 @@ use {
         timing,
     },
     std::{
-        collections::{hash_map::Entry, HashMap},
         net::{IpAddr, SocketAddr, UdpSocket},
         sync::{
             atomic::{AtomicBool, AtomicU64, Ordering},
-            Arc, Mutex, RwLock,
+            Arc, Mutex, MutexGuard, RwLock,
         },
         time::{Duration, Instant},
     },
@@ -116,6 +117,21 @@ pub async fn run_server(
         }
     }
 }

+fn prune_unstaked_connection_table(
+    unstaked_connection_table: &mut MutexGuard<ConnectionTable>,
+    max_unstaked_connections: usize,
+    stats: Arc<StreamStats>,
+) {
+    if unstaked_connection_table.total_size >= max_unstaked_connections {
+        const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
+        let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
+        let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
+        let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
+        stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
+    }
+}
+
 async fn setup_connection(
     connection: Connecting,
     unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
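The helper factored out above prunes the unstaked table down to a fixed fraction of capacity, rather than just below it, so a full table is not re-pruned on every new connection. A worked sketch of the target-size arithmetic, using the same `percentage` crate calls as the diff (the capacity value here is illustrative only):

```rust
use percentage::Percentage; // the same crate the server code uses

fn main() {
    const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
    // Illustrative capacity; the real limit comes from the server config.
    let max_unstaked_connections: usize = 500;
    let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
    let target = max_percentage_full.apply_to(max_unstaked_connections);
    // prune_oldest() then evicts entries until at most `target` remain.
    assert_eq!(target, 450);
}
```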
@@ -138,52 +154,79 @@ async fn setup_connection(
     let remote_addr = connection.remote_address();

-    let (mut connection_table_l, stake) = {
-        const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
-        let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
+    let table_and_stake = {
         let staked_nodes = staked_nodes.read().unwrap();
         if let Some(stake) = staked_nodes.stake_map.get(&remote_addr.ip()) {
             let stake = *stake;
-            let total_stake = staked_nodes.total_stake;
             drop(staked_nodes);
             let mut connection_table_l = staked_connection_table.lock().unwrap();
             if connection_table_l.total_size >= max_staked_connections {
-                let max_connections = max_percentage_full.apply_to(max_staked_connections);
-                let num_pruned = connection_table_l.prune_oldest(max_connections);
-                stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
-            }
-            connection.set_max_concurrent_uni_streams(
-                VarInt::from_u64(
-                    ((stake as f64 / total_stake as f64) * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS)
-                        as u64,
-                )
-                .unwrap(),
-            );
-            (connection_table_l, stake)
-        } else {
-            drop(staked_nodes);
-            let mut connection_table_l = unstaked_connection_table.lock().unwrap();
-            if connection_table_l.total_size >= max_unstaked_connections {
-                let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
-                let num_pruned = connection_table_l.prune_oldest(max_connections);
-                stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
-            }
-            connection.set_max_concurrent_uni_streams(
-                VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64).unwrap(),
-            );
-            (connection_table_l, 0)
+                let num_pruned = connection_table_l.prune_random(stake);
+                if num_pruned == 0 {
+                    if max_unstaked_connections > 0 {
+                        // If we couldn't prune a connection in the staked connection table, let's
+                        // put this connection in the unstaked connection table. If needed, prune a
+                        // connection from the unstaked connection table.
+                        connection_table_l = unstaked_connection_table.lock().unwrap();
+                        prune_unstaked_connection_table(
+                            &mut connection_table_l,
+                            max_unstaked_connections,
+                            stats.clone(),
+                        );
+                        Some((connection_table_l, stake))
+                    } else {
+                        stats
+                            .connection_add_failed_on_pruning
+                            .fetch_add(1, Ordering::Relaxed);
+                        None
+                    }
+                } else {
+                    stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
+                    Some((connection_table_l, stake))
+                }
+            } else {
+                Some((connection_table_l, stake))
+            }
+        } else if max_unstaked_connections > 0 {
+            drop(staked_nodes);
+            let mut connection_table_l = unstaked_connection_table.lock().unwrap();
+            prune_unstaked_connection_table(
+                &mut connection_table_l,
+                max_unstaked_connections,
+                stats.clone(),
+            );
+            Some((connection_table_l, 0))
+        } else {
+            None
         }
     };

-    if stake != 0 || max_unstaked_connections > 0 {
+    if let Some((mut connection_table_l, stake)) = table_and_stake {
+        let table_type = connection_table_l.peer_type;
+        let max_uni_streams = match table_type {
+            ConnectionPeerType::Unstaked => {
+                VarInt::from_u64(QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS as u64)
+            }
+            ConnectionPeerType::Staked => {
+                let staked_nodes = staked_nodes.read().unwrap();
+                VarInt::from_u64(
+                    ((stake as f64 / staked_nodes.total_stake as f64)
+                        * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS) as u64,
+                )
+            }
+        };
+
+        if let Ok(max_uni_streams) = max_uni_streams {
+            connection.set_max_concurrent_uni_streams(max_uni_streams);
         if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection(
             &remote_addr,
             Some(connection),
+            stake,
             timing::timestamp(),
             max_connections_per_ip,
         ) {
-            let table_type = connection_table_l.peer_type;
             drop(connection_table_l);
             let stats = stats.clone();
             let connection_table = match table_type {
@@ -203,6 +246,11 @@ async fn setup_connection(
             } else {
                 stats.connection_add_failed.fetch_add(1, Ordering::Relaxed);
             }
+        } else {
+            stats
+                .connection_add_failed_invalid_stream_count
+                .fetch_add(1, Ordering::Relaxed);
+        }
     } else {
         connection.close(0u32.into(), &[0u8]);
         stats
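With the restructured flow above, the stream budget is derived from the table's peer type: unstaked peers get the flat `QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS`, while staked peers get a share of `QUIC_TOTAL_STAKED_CONCURRENT_STREAMS` proportional to their stake, and the `VarInt::from_u64` result is now checked instead of unwrapped (feeding the new `connection_add_failed_invalid_stream_count` counter). A worked example of the proportional formula; the budget constant's value here is assumed purely for illustration:

```rust
// Mirrors the formula in setup_connection; the constant is a stand-in.
const QUIC_TOTAL_STAKED_CONCURRENT_STREAMS: f64 = 100_000.0; // assumed value

fn max_staked_uni_streams(stake: u64, total_stake: u64) -> u64 {
    ((stake as f64 / total_stake as f64) * QUIC_TOTAL_STAKED_CONCURRENT_STREAMS) as u64
}

fn main() {
    // A peer holding 1% of total stake gets 1% of the staked stream budget.
    assert_eq!(max_staked_uni_streams(10_000, 1_000_000), 1_000);
    // A tiny stake rounds down to zero concurrent streams.
    assert_eq!(max_staked_uni_streams(1, 1_000_000), 0);
}
```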
@@ -387,6 +435,7 @@ fn handle_chunk(
 #[derive(Debug)]
 struct ConnectionEntry {
     exit: Arc<AtomicBool>,
+    stake: u64,
     last_update: Arc<AtomicU64>,
     port: u16,
     connection: Option<Connection>,
@ -395,12 +444,14 @@ struct ConnectionEntry {
impl ConnectionEntry { impl ConnectionEntry {
fn new( fn new(
exit: Arc<AtomicBool>, exit: Arc<AtomicBool>,
stake: u64,
last_update: Arc<AtomicU64>, last_update: Arc<AtomicU64>,
port: u16, port: u16,
connection: Option<Connection>, connection: Option<Connection>,
) -> Self { ) -> Self {
Self { Self {
exit, exit,
stake,
last_update, last_update,
port, port,
connection, connection,
@@ -429,7 +480,7 @@ enum ConnectionPeerType {
 // Map of IP to list of connection entries
 struct ConnectionTable {
-    table: HashMap<IpAddr, Vec<ConnectionEntry>>,
+    table: IndexMap<IpAddr, Vec<ConnectionEntry>>,
     total_size: usize,
     peer_type: ConnectionPeerType,
 }
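The switch from `HashMap` to `IndexMap` is what makes random candidate selection cheap: `IndexMap` keeps its entries in a dense array, so positional lookup (`get_index`) and constant-time removal by position (`swap_remove_index`) are both available. A small sketch of the two calls `prune_random` relies on, assuming `indexmap` 1.x with illustrative keys:

```rust
use indexmap::IndexMap;

fn main() {
    let mut table: IndexMap<&str, u64> = IndexMap::new();
    table.insert("1.2.3.4", 100);
    table.insert("5.6.7.8", 50);

    // get_index: positional lookup, used to read a random candidate's stake.
    let (ip, stake) = table.get_index(1).unwrap();
    assert_eq!((*ip, *stake), ("5.6.7.8", 50));

    // swap_remove_index: O(1) eviction; the last entry fills the hole,
    // so insertion order is not preserved (pruning does not care).
    table.swap_remove_index(0);
    assert_eq!(table.len(), 1);
}
```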
@@ -439,7 +490,7 @@ struct ConnectionTable {
 impl ConnectionTable {
     fn new(peer_type: ConnectionPeerType) -> Self {
         Self {
-            table: HashMap::default(),
+            table: IndexMap::default(),
             total_size: 0,
             peer_type,
         }
@@ -473,10 +524,47 @@ impl ConnectionTable {
         num_pruned
     }

+    fn connection_stake(&self, index: usize) -> Option<u64> {
+        self.table
+            .get_index(index)
+            .and_then(|(_, connection_vec)| connection_vec.first())
+            .map(|connection| connection.stake)
+    }
+
+    // Randomly select two connections, and evict the one with lower stake. If the stakes of both
+    // the connections are higher than the threshold_stake, reject the pruning attempt, and return 0.
+    fn prune_random(&mut self, threshold_stake: u64) -> usize {
+        let mut num_pruned = 0;
+        let mut rng = thread_rng();
+        // The candidate1 and candidate2 could potentially be the same. If so, the stake of the candidate
+        // will be compared just against the threshold_stake.
+        let candidate1 = rng.gen_range(0, self.table.len());
+        let candidate2 = rng.gen_range(0, self.table.len());
+        let candidate1_stake = self.connection_stake(candidate1).unwrap_or(0);
+        let candidate2_stake = self.connection_stake(candidate2).unwrap_or(0);
+        if candidate1_stake < threshold_stake || candidate2_stake < threshold_stake {
+            let removed = if candidate1_stake < candidate2_stake {
+                self.table.swap_remove_index(candidate1)
+            } else {
+                self.table.swap_remove_index(candidate2)
+            };
+            if let Some((_, removed_value)) = removed {
+                self.total_size -= removed_value.len();
+                num_pruned += removed_value.len();
+            }
+        }
+        num_pruned
+    }
+
     fn try_add_connection(
         &mut self,
         addr: &SocketAddr,
         connection: Option<Connection>,
+        stake: u64,
         last_update: u64,
         max_connections_per_ip: usize,
     ) -> Option<(Arc<AtomicU64>, Arc<AtomicBool>)> {
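The two-candidate rule above is a power-of-two-choices eviction: no global scan or sort of the table, yet low-stake entries are removed far more often than high-stake ones. A quick hypothetical Monte Carlo sketch of that bias, not part of this change, over four entries whose stake equals index + 1:

```rust
use rand::{thread_rng, Rng};

fn main() {
    let mut rng = thread_rng();
    // Four entries with stakes 1..=4 (stake == index + 1).
    let mut evictions = [0u32; 4];
    for _ in 0..100_000 {
        // rand 0.7-style two-argument gen_range, as in the diff.
        let a: usize = rng.gen_range(0, 4);
        let b: usize = rng.gen_range(0, 4);
        // The lower-staked candidate of the sampled pair is evicted.
        evictions[a.min(b)] += 1;
    }
    // Expected shares: ~43.75%, 31.25%, 18.75%, 6.25% for stakes 1..=4.
    println!("evictions per stake bucket: {evictions:?}");
}
```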
@@ -491,6 +579,7 @@ impl ConnectionTable {
         let last_update = Arc::new(AtomicU64::new(last_update));
         connection_entry.push(ConnectionEntry::new(
             exit.clone(),
+            stake,
             last_update.clone(),
             addr.port(),
             connection,
@@ -818,7 +907,7 @@ pub mod test {
         staked_nodes
             .stake_map
             .insert(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 100000);
-        staked_nodes.total_stake = 100000_f64;
+        staked_nodes.total_stake = 100000;

         let (t, exit, receiver, server_address, stats) = setup_quic_server(Some(staked_nodes));
         check_multiple_writes(receiver, server_address).await;
@@ -918,12 +1007,12 @@ pub mod test {
             .collect();
         for (i, socket) in sockets.iter().enumerate() {
             table
-                .try_add_connection(socket, None, i as u64, max_connections_per_ip)
+                .try_add_connection(socket, None, 0, i as u64, max_connections_per_ip)
                 .unwrap();
         }
         num_entries += 1;
         table
-            .try_add_connection(&sockets[0], None, 5, max_connections_per_ip)
+            .try_add_connection(&sockets[0], None, 0, 5, max_connections_per_ip)
             .unwrap();

         let new_size = 3;
@@ -942,6 +1031,40 @@ pub mod test {
         assert_eq!(table.total_size, 0);
     }

+    #[test]
+    fn test_prune_table_random() {
+        use std::net::Ipv4Addr;
+        solana_logger::setup();
+        let mut table = ConnectionTable::new(ConnectionPeerType::Staked);
+        let num_entries = 5;
+        let max_connections_per_ip = 10;
+        let sockets: Vec<_> = (0..num_entries)
+            .into_iter()
+            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
+            .collect();
+        for (i, socket) in sockets.iter().enumerate() {
+            table
+                .try_add_connection(
+                    socket,
+                    None,
+                    (i + 1) as u64,
+                    i as u64,
+                    max_connections_per_ip,
+                )
+                .unwrap();
+        }
+
+        // Try pruning with a threshold stake less than all the entries in the table.
+        // It should fail to prune (i.e., return 0 pruned entries).
+        let pruned = table.prune_random(0);
+        assert_eq!(pruned, 0);
+
+        // Try pruning with a threshold stake higher than all the entries in the table.
+        // It should succeed in pruning (i.e., return 1 pruned entry).
+        let pruned = table.prune_random(num_entries as u64 + 1);
+        assert_eq!(pruned, 1);
+    }
+
     #[test]
     fn test_remove_connections() {
         use std::net::Ipv4Addr;
@@ -955,11 +1078,11 @@ pub mod test {
             .collect();
         for (i, socket) in sockets.iter().enumerate() {
             table
-                .try_add_connection(socket, None, (i * 2) as u64, max_connections_per_ip)
+                .try_add_connection(socket, None, 0, (i * 2) as u64, max_connections_per_ip)
                 .unwrap();

             table
-                .try_add_connection(socket, None, (i * 2 + 1) as u64, max_connections_per_ip)
+                .try_add_connection(socket, None, 0, (i * 2 + 1) as u64, max_connections_per_ip)
                 .unwrap();
         }
@@ -969,6 +1092,7 @@ pub mod test {
             .try_add_connection(
                 &single_connection_addr,
                 None,
+                0,
                 (num_ips * 2) as u64,
                 max_connections_per_ip,
             )


@@ -158,7 +158,9 @@ pub struct StreamStats {
     pub(crate) total_stream_read_timeouts: AtomicUsize,
     pub(crate) num_evictions: AtomicUsize,
     pub(crate) connection_add_failed: AtomicUsize,
+    pub(crate) connection_add_failed_invalid_stream_count: AtomicUsize,
     pub(crate) connection_add_failed_unstaked_node: AtomicUsize,
+    pub(crate) connection_add_failed_on_pruning: AtomicUsize,
     pub(crate) connection_setup_timeout: AtomicUsize,
     pub(crate) connection_removed: AtomicUsize,
     pub(crate) connection_remove_failed: AtomicUsize,
@@ -198,12 +200,24 @@ impl StreamStats {
                 self.connection_add_failed.swap(0, Ordering::Relaxed),
                 i64
             ),
+            (
+                "connection_add_failed_invalid_stream_count",
+                self.connection_add_failed_invalid_stream_count
+                    .swap(0, Ordering::Relaxed),
+                i64
+            ),
             (
                 "connection_add_failed_unstaked_node",
                 self.connection_add_failed_unstaked_node
                     .swap(0, Ordering::Relaxed),
                 i64
             ),
+            (
+                "connection_add_failed_on_pruning",
+                self.connection_add_failed_on_pruning
+                    .swap(0, Ordering::Relaxed),
+                i64
+            ),
             (
                 "connection_removed",
                 self.connection_removed.swap(0, Ordering::Relaxed),


@@ -27,7 +27,7 @@ use {
 // Total stake and nodes => stake map
 #[derive(Default)]
 pub struct StakedNodes {
-    pub total_stake: f64,
+    pub total_stake: u64,
     pub stake_map: HashMap<IpAddr, u64>,
 }
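A note on this last type change: total stake is a sum in lamports, so `u64` keeps it exact; the value is widened to `f64` only at the single call site that computes a stream-share ratio (in `setup_connection` above). A tiny sketch of that boundary:

```rust
// Exact integer total; float math happens only at the ratio, matching
// the pattern in setup_connection.
fn stream_share(stake: u64, total_stake: u64) -> f64 {
    stake as f64 / total_stake as f64
}

fn main() {
    let total_stake: u64 = 1_000_000; // exact lamport sum, no f64 drift
    assert!((stream_share(100_000, total_stake) - 0.1).abs() < f64::EPSILON);
}
```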