Quic limit connections (#23283)

* quic server limit connections

* bump per_ip

* Review comments

* Make the connections per port
This commit is contained in:
sakridge 2022-03-09 10:52:31 +01:00 committed by GitHub
parent 8a4b019ded
commit 7a9884c831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 316 additions and 28 deletions

View File

@ -36,6 +36,9 @@ use {
pub const DEFAULT_TPU_COALESCE_MS: u64 = 5;
// allow multiple connections for NAT and any open/close overlap
pub const MAX_QUIC_CONNECTIONS_PER_IP: usize = 8;
pub struct TpuSockets {
pub transactions: Vec<UdpSocket>,
pub transaction_forwards: Vec<UdpSocket>,
@ -108,6 +111,7 @@ impl Tpu {
cluster_info.my_contact_info().tpu.ip(),
packet_sender,
exit.clone(),
MAX_QUIC_CONNECTIONS_PER_IP,
)
.unwrap();

View File

@ -3,22 +3,24 @@ use {
futures_util::stream::StreamExt,
pem::Pem,
pkcs8::{der::Document, AlgorithmIdentifier, ObjectIdentifier},
quinn::{Endpoint, EndpointConfig, ServerConfig},
quinn::{Endpoint, EndpointConfig, IncomingUniStreams, ServerConfig},
rcgen::{CertificateParams, DistinguishedName, DnType, SanType},
solana_perf::packet::PacketBatch,
solana_sdk::{
packet::{Packet, PACKET_DATA_SIZE},
signature::Keypair,
timing,
},
std::{
collections::{hash_map::Entry, HashMap},
error::Error,
net::{IpAddr, SocketAddr, UdpSocket},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
Arc, Mutex,
},
thread,
time::Duration,
time::{Duration, Instant},
},
tokio::{
runtime::{Builder, Runtime},
@ -120,8 +122,12 @@ fn new_cert_params(identity_keypair: &Keypair, san: IpAddr) -> CertificateParams
cert_params
}
pub fn rt() -> Runtime {
Builder::new_current_thread().enable_all().build().unwrap()
fn rt() -> Runtime {
Builder::new_multi_thread()
.worker_threads(1)
.enable_all()
.build()
.unwrap()
}
#[derive(thiserror::Error, Debug)]
@ -190,12 +196,207 @@ fn handle_chunk(
false
}
#[derive(Debug)]
struct ConnectionEntry {
exit: Arc<AtomicBool>,
last_update: Arc<AtomicU64>,
port: u16,
}
impl ConnectionEntry {
fn new(exit: Arc<AtomicBool>, last_update: Arc<AtomicU64>, port: u16) -> Self {
Self {
exit,
last_update,
port,
}
}
fn last_update(&self) -> u64 {
self.last_update.load(Ordering::Relaxed)
}
}
impl Drop for ConnectionEntry {
fn drop(&mut self) {
self.exit.store(true, Ordering::Relaxed);
}
}
// Map of IP to list of connection entries
#[derive(Default, Debug)]
struct ConnectionTable {
table: HashMap<IpAddr, Vec<ConnectionEntry>>,
total_size: usize,
}
// Prune the connection which has the oldest update
// Return number pruned
impl ConnectionTable {
fn prune_oldest(&mut self, max_size: usize) -> usize {
let mut num_pruned = 0;
while self.total_size > max_size {
let mut oldest = std::u64::MAX;
let mut oldest_ip = None;
for (ip, connections) in self.table.iter() {
for entry in connections {
let last_update = entry.last_update();
if last_update < oldest {
oldest = last_update;
oldest_ip = Some(*ip);
}
}
}
self.table.remove(&oldest_ip.unwrap());
self.total_size -= 1;
num_pruned += 1;
}
num_pruned
}
fn try_add_connection(
&mut self,
addr: &SocketAddr,
last_update: u64,
max_connections_per_ip: usize,
) -> Option<(Arc<AtomicU64>, Arc<AtomicBool>)> {
let connection_entry = self.table.entry(addr.ip()).or_insert_with(Vec::new);
let has_connection_capacity = connection_entry
.len()
.checked_add(1)
.map(|c| c <= max_connections_per_ip)
.unwrap_or(false);
if has_connection_capacity {
let exit = Arc::new(AtomicBool::new(false));
let last_update = Arc::new(AtomicU64::new(last_update));
connection_entry.push(ConnectionEntry::new(
exit.clone(),
last_update.clone(),
addr.port(),
));
self.total_size += 1;
Some((last_update, exit))
} else {
None
}
}
fn remove_connection(&mut self, addr: &SocketAddr) {
if let Entry::Occupied(mut e) = self.table.entry(addr.ip()) {
let e_ref = e.get_mut();
e_ref.retain(|connection| connection.port != addr.port());
if e_ref.is_empty() {
e.remove_entry();
}
self.total_size -= 1;
}
}
}
#[derive(Default)]
struct StreamStats {
total_connections: AtomicUsize,
total_new_connections: AtomicUsize,
total_streams: AtomicUsize,
total_new_streams: AtomicUsize,
num_evictions: AtomicUsize,
}
impl StreamStats {
fn report(&self) {
datapoint_info!(
"quic-connections",
(
"active_connections",
self.total_connections.load(Ordering::Relaxed),
i64
),
(
"active_streams",
self.total_streams.load(Ordering::Relaxed),
i64
),
(
"new_connections",
self.total_new_connections.swap(0, Ordering::Relaxed),
i64
),
(
"new_streams",
self.total_new_streams.swap(0, Ordering::Relaxed),
i64
),
(
"evictions",
self.num_evictions.swap(0, Ordering::Relaxed),
i64
),
);
}
}
fn handle_connection(
mut uni_streams: IncomingUniStreams,
packet_sender: Sender<PacketBatch>,
remote_addr: SocketAddr,
last_update: Arc<AtomicU64>,
connection_table: Arc<Mutex<ConnectionTable>>,
stream_exit: Arc<AtomicBool>,
stats: Arc<StreamStats>,
) {
tokio::spawn(async move {
debug!(
"quic new connection {} streams: {} connections: {}",
remote_addr,
stats.total_streams.load(Ordering::Relaxed),
stats.total_connections.load(Ordering::Relaxed),
);
while !stream_exit.load(Ordering::Relaxed) {
match uni_streams.next().await {
Some(stream_result) => match stream_result {
Ok(mut stream) => {
stats.total_streams.fetch_add(1, Ordering::Relaxed);
stats.total_new_streams.fetch_add(1, Ordering::Relaxed);
let mut maybe_batch = None;
while !stream_exit.load(Ordering::Relaxed) {
if handle_chunk(
&stream.read_chunk(PACKET_DATA_SIZE, false).await,
&mut maybe_batch,
&remote_addr,
&packet_sender,
) {
last_update.store(timing::timestamp(), Ordering::Relaxed);
break;
}
}
}
Err(e) => {
debug!("stream error: {:?}", e);
stats.total_streams.fetch_sub(1, Ordering::Relaxed);
break;
}
},
None => {
stats.total_streams.fetch_sub(1, Ordering::Relaxed);
break;
}
}
}
connection_table
.lock()
.unwrap()
.remove_connection(&remote_addr);
stats.total_connections.fetch_sub(1, Ordering::Relaxed);
});
}
pub fn spawn_server(
sock: UdpSocket,
keypair: &Keypair,
gossip_host: IpAddr,
packet_sender: Sender<PacketBatch>,
exit: Arc<AtomicBool>,
max_connections_per_ip: usize,
) -> Result<thread::JoinHandle<()>, QuicServerError> {
let (config, _cert) = configure_server(keypair, gossip_host)?;
@ -206,8 +407,13 @@ pub fn spawn_server(
.map_err(|_e| QuicServerError::EndpointFailed)?
};
let stats = Arc::new(StreamStats::default());
let handle = thread::spawn(move || {
let handle = runtime.spawn(async move {
debug!("spawn quic server");
let mut last_datapoint = Instant::now();
let connection_table: Arc<Mutex<ConnectionTable>> =
Arc::new(Mutex::new(ConnectionTable::default()));
while !exit.load(Ordering::Relaxed) {
const WAIT_FOR_CONNECTION_TIMEOUT_MS: u64 = 1000;
let timeout_connection = timeout(
@ -216,33 +422,49 @@ pub fn spawn_server(
)
.await;
if last_datapoint.elapsed().as_secs() >= 5 {
stats.report();
last_datapoint = Instant::now();
}
if let Ok(Some(connection)) = timeout_connection {
if let Ok(new_connection) = connection.await {
let exit = exit.clone();
stats.total_connections.fetch_add(1, Ordering::Relaxed);
stats.total_new_connections.fetch_add(1, Ordering::Relaxed);
let quinn::NewConnection {
connection,
mut uni_streams,
uni_streams,
..
} = new_connection;
let remote_addr = connection.remote_address();
let packet_sender = packet_sender.clone();
tokio::spawn(async move {
debug!("new connection {}", remote_addr);
while let Some(Ok(mut stream)) = uni_streams.next().await {
let mut maybe_batch = None;
while !exit.load(Ordering::Relaxed) {
if handle_chunk(
&stream.read_chunk(PACKET_DATA_SIZE, false).await,
&mut maybe_batch,
&remote_addr,
&packet_sender,
) {
break;
}
}
}
});
let mut connection_table_l = connection_table.lock().unwrap();
const MAX_CONNECTION_TABLE_SIZE: usize = 5000;
let num_pruned = connection_table_l.prune_oldest(MAX_CONNECTION_TABLE_SIZE);
stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
if let Some((last_update, stream_exit)) = connection_table_l
.try_add_connection(
&remote_addr,
timing::timestamp(),
max_connections_per_ip,
)
{
drop(connection_table_l);
let packet_sender = packet_sender.clone();
let stats = stats.clone();
let connection_table1 = connection_table.clone();
handle_connection(
uni_streams,
packet_sender,
remote_addr,
last_update,
connection_table1,
stream_exit,
stats,
);
}
}
}
}
@ -300,7 +522,7 @@ mod test {
let (sender, _receiver) = unbounded();
let keypair = Keypair::new();
let ip = "127.0.0.1".parse().unwrap();
let t = spawn_server(s, &keypair, ip, sender, exit.clone()).unwrap();
let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap();
exit.store(true, Ordering::Relaxed);
t.join().unwrap();
}
@ -316,6 +538,35 @@ mod test {
.unwrap()
}
#[test]
fn test_quic_server_block_multiple_connections() {
solana_logger::setup();
let s = UdpSocket::bind("127.0.0.1:0").unwrap();
let exit = Arc::new(AtomicBool::new(false));
let (sender, _receiver) = unbounded();
let keypair = Keypair::new();
let ip = "127.0.0.1".parse().unwrap();
let server_address = s.local_addr().unwrap();
let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap();
let runtime = rt();
let _rt_guard = runtime.enter();
let conn1 = make_client_endpoint(&runtime, &server_address);
let conn2 = make_client_endpoint(&runtime, &server_address);
let handle = runtime.spawn(async move {
let mut s1 = conn1.connection.open_uni().await.unwrap();
let mut s2 = conn2.connection.open_uni().await.unwrap();
s1.write_all(&[0u8]).await.unwrap();
s1.finish().await.unwrap();
s2.write_all(&[0u8])
.await
.expect_err("shouldn't be able to open 2 connections");
});
runtime.block_on(handle).unwrap();
exit.store(true, Ordering::Relaxed);
t.join().unwrap();
}
#[test]
fn test_quic_server_multiple_streams() {
solana_logger::setup();
@ -325,7 +576,7 @@ mod test {
let keypair = Keypair::new();
let ip = "127.0.0.1".parse().unwrap();
let server_address = s.local_addr().unwrap();
let t = spawn_server(s, &keypair, ip, sender, exit.clone()).unwrap();
let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 2).unwrap();
let runtime = rt();
let _rt_guard = runtime.enter();
@ -380,7 +631,7 @@ mod test {
let keypair = Keypair::new();
let ip = "127.0.0.1".parse().unwrap();
let server_address = s.local_addr().unwrap();
let t = spawn_server(s, &keypair, ip, sender, exit.clone()).unwrap();
let t = spawn_server(s, &keypair, ip, sender, exit.clone(), 1).unwrap();
let runtime = rt();
let _rt_guard = runtime.enter();
@ -420,4 +671,37 @@ mod test {
exit.store(true, Ordering::Relaxed);
t.join().unwrap();
}
#[test]
fn test_prune_table() {
use std::net::Ipv4Addr;
solana_logger::setup();
let mut table = ConnectionTable::default();
let num_entries = 5;
let max_connections_per_ip = 10;
let sockets: Vec<_> = (0..num_entries)
.into_iter()
.map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
.collect();
for (i, socket) in sockets.iter().enumerate() {
table
.try_add_connection(socket, i as u64, max_connections_per_ip)
.unwrap();
}
let new_size = 3;
let pruned = table.prune_oldest(new_size);
assert_eq!(pruned, num_entries as usize - new_size);
for v in table.table.values() {
for x in v {
assert!(x.last_update() >= (num_entries as u64 - new_size as u64));
}
}
assert_eq!(table.table.len(), new_size);
assert_eq!(table.total_size, new_size);
for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
table.remove_connection(socket);
}
info!("{:?}", table);
assert_eq!(table.total_size, 0);
}
}