From de027a895b4a72895f916802d2203cedd8d4d2fa Mon Sep 17 00:00:00 2001
From: Pankaj Garg
Date: Mon, 2 May 2022 13:02:49 -0700
Subject: [PATCH] Use RwLock instead of Mutex in QUIC connection cache (#24857)

* Use RwLock instead of Mutex in QUIC connection cache

* replace LruCache with HashMap

* fix tests

* fix tests

* refactor

* add cache eviction for a random connection on reaching upperbound

* cleanup
---
 client/src/connection_cache.rs | 179 ++++++++++++++++++++++-----------
 1 file changed, 119 insertions(+), 60 deletions(-)

diff --git a/client/src/connection_cache.rs b/client/src/connection_cache.rs
index d77ff18eeb..246b1339d5 100644
--- a/client/src/connection_cache.rs
+++ b/client/src/connection_cache.rs
@@ -5,17 +5,19 @@ use {
         udp_client::UdpTpuConnection,
     },
     lazy_static::lazy_static,
-    lru::LruCache,
+    quinn_proto::ConnectionStats,
+    rand::{thread_rng, Rng},
     solana_measure::measure::Measure,
     solana_net_utils::VALIDATOR_PORT_RANGE,
     solana_sdk::{
         timing::AtomicInterval, transaction::VersionedTransaction, transport::TransportError,
     },
     std::{
+        collections::BTreeMap,
         net::{IpAddr, Ipv4Addr, SocketAddr},
         sync::{
             atomic::{AtomicU64, Ordering},
-            Arc, Mutex,
+            Arc, RwLock,
         },
     },
 };
@@ -158,7 +160,7 @@ impl ConnectionCacheStats {
 }
 
 struct ConnectionMap {
-    map: LruCache<SocketAddr, Connection>,
+    map: BTreeMap<SocketAddr, Connection>,
     stats: Arc<ConnectionCacheStats>,
     last_stats: AtomicInterval,
     use_quic: bool,
@@ -167,7 +169,7 @@ struct ConnectionMap {
 impl ConnectionMap {
     pub fn new() -> Self {
         Self {
-            map: LruCache::new(MAX_CONNECTIONS),
+            map: BTreeMap::new(),
             stats: Arc::new(ConnectionCacheStats::default()),
             last_stats: AtomicInterval::default(),
             use_quic: false,
@@ -180,38 +182,44 @@ impl ConnectionMap {
 }
 
 lazy_static! {
-    static ref CONNECTION_MAP: Mutex<ConnectionMap> = Mutex::new(ConnectionMap::new());
+    static ref CONNECTION_MAP: RwLock<ConnectionMap> = RwLock::new(ConnectionMap::new());
 }
 
 pub fn set_use_quic(use_quic: bool) {
-    let mut map = (*CONNECTION_MAP).lock().unwrap();
+    let mut map = (*CONNECTION_MAP).write().unwrap();
     map.set_use_quic(use_quic);
 }
 
-// TODO: see https://github.com/solana-labs/solana/issues/23661
-// remove lazy_static and optimize and refactor this
-fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>) {
-    let mut get_connection_measure = Measure::start("get_connection_measure");
+struct GetConnectionResult {
+    connection: Connection,
+    cache_hit: bool,
+    report_stats: bool,
+    map_timing: u64,
+    lock_timing: u64,
+    connection_cache_stats: Arc<ConnectionCacheStats>,
+    other_stats: Option<(Arc<ClientStats>, ConnectionStats)>,
+}
+
+fn get_or_add_connection(addr: &SocketAddr) -> GetConnectionResult {
     let mut get_connection_map_lock_measure = Measure::start("get_connection_map_lock_measure");
-    let mut map = (*CONNECTION_MAP).lock().unwrap();
+    let map = (*CONNECTION_MAP).read().unwrap();
     get_connection_map_lock_measure.stop();
 
-    if map
+    let mut lock_timing = get_connection_map_lock_measure.as_ms();
+
+    let report_stats = map
         .last_stats
-        .should_update(CONNECTION_STAT_SUBMISSION_INTERVAL)
-    {
-        map.stats.report();
-    }
+        .should_update(CONNECTION_STAT_SUBMISSION_INTERVAL);
 
     let mut get_connection_map_measure = Measure::start("get_connection_hit_measure");
-    let (connection, hit, maybe_stats) = match map.map.get(addr) {
+    let (connection, cache_hit, connection_cache_stats, maybe_stats) = match map.map.get(addr) {
         Some(connection) => {
             let mut stats = None;
             // update connection stats
             if let Connection::Quic(conn) = connection {
                 stats = conn.stats().map(|s| (conn.base_stats(), s));
             }
-            (connection.clone(), true, stats)
+            (connection.clone(), true, map.stats.clone(), stats)
         }
         None => {
            let (_, send_socket) = solana_net_utils::bind_in_range(
@@ -219,25 +227,77 @@ fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>)
                IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
                 VALIDATOR_PORT_RANGE,
             )
             .unwrap();
+
             let connection = if map.use_quic {
                 Connection::Quic(Arc::new(QuicTpuConnection::new(send_socket, *addr)))
             } else {
                 Connection::Udp(Arc::new(UdpTpuConnection::new(send_socket, *addr)))
             };
-            map.map.put(*addr, connection.clone());
-            (connection, false, None)
+            // Upgrade to write access by dropping read lock and acquire write lock
+            drop(map);
+            let mut get_connection_map_lock_measure =
+                Measure::start("get_connection_map_lock_measure");
+            let mut map = (*CONNECTION_MAP).write().unwrap();
+            get_connection_map_lock_measure.stop();
+
+            lock_timing = lock_timing.saturating_add(get_connection_map_lock_measure.as_ms());
+
+            // evict a connection if the map is reaching upper bounds
+            while map.map.len() >= MAX_CONNECTIONS {
+                let mut rng = thread_rng();
+                let n = rng.gen_range(0, MAX_CONNECTIONS);
+                if let Some((nth_addr, _)) = map.map.iter().nth(n) {
+                    let nth_addr = *nth_addr;
+                    map.map.remove(&nth_addr);
+                }
+            }
+
+            map.map.insert(*addr, connection.clone());
+            (connection, false, map.stats.clone(), None)
         }
     };
     get_connection_map_measure.stop();
 
-    if let Some((connection_stats, new_stats)) = maybe_stats {
-        map.stats.total_client_stats.congestion_events.update_stat(
-            &connection_stats.congestion_events,
-            new_stats.path.congestion_events,
-        );
+    GetConnectionResult {
+        connection,
+        cache_hit,
+        report_stats,
+        map_timing: get_connection_map_measure.as_ms(),
+        lock_timing,
+        connection_cache_stats,
+        other_stats: maybe_stats,
+    }
+}
 
-        map.stats
+// TODO: see https://github.com/solana-labs/solana/issues/23661
+// remove lazy_static and optimize and refactor this
+fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>) {
+    let mut get_connection_measure = Measure::start("get_connection_measure");
+    let GetConnectionResult {
+        connection,
+        cache_hit,
+        report_stats,
+        map_timing,
+        lock_timing,
+        connection_cache_stats,
+        other_stats,
+    } = get_or_add_connection(addr);
+
+    if report_stats {
+        connection_cache_stats.report();
+    }
+
+    if let Some((connection_stats, new_stats)) = other_stats {
+        connection_cache_stats
+            .total_client_stats
+            .congestion_events
+            .update_stat(
+                &connection_stats.congestion_events,
+                new_stats.path.congestion_events,
+            );
+
+        connection_cache_stats
             .total_client_stats
             .tx_streams_blocked_uni
             .update_stat(
@@ -245,36 +305,44 @@ fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>)
                 new_stats.frame_tx.streams_blocked_uni,
             );
 
-        map.stats.total_client_stats.tx_data_blocked.update_stat(
-            &connection_stats.tx_data_blocked,
-            new_stats.frame_tx.data_blocked,
-        );
+        connection_cache_stats
+            .total_client_stats
+            .tx_data_blocked
+            .update_stat(
+                &connection_stats.tx_data_blocked,
+                new_stats.frame_tx.data_blocked,
+            );
 
-        map.stats
+        connection_cache_stats
             .total_client_stats
             .tx_acks
             .update_stat(&connection_stats.tx_acks, new_stats.frame_tx.acks);
     }
 
-    if hit {
-        map.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
-        map.stats
+    if cache_hit {
+        connection_cache_stats
+            .cache_hits
+            .fetch_add(1, Ordering::Relaxed);
+        connection_cache_stats
             .get_connection_hit_ms
-            .fetch_add(get_connection_map_measure.as_us(), Ordering::Relaxed);
+            .fetch_add(map_timing, Ordering::Relaxed);
     } else {
-        map.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
-        map.stats
+        connection_cache_stats
+            .cache_misses
+            .fetch_add(1, Ordering::Relaxed);
+        connection_cache_stats
            .get_connection_miss_ms
-            .fetch_add(get_connection_map_measure.as_us(), Ordering::Relaxed);
+            .fetch_add(map_timing, Ordering::Relaxed);
    }
+
     get_connection_measure.stop();
-    map.stats
+    connection_cache_stats
         .get_connection_lock_ms
-        .fetch_add(get_connection_map_lock_measure.as_us(), Ordering::Relaxed);
-    map.stats
+        .fetch_add(lock_timing, Ordering::Relaxed);
+    connection_cache_stats
         .get_connection_ms
-        .fetch_add(get_connection_measure.as_us(), Ordering::Relaxed);
-    (connection, map.stats.clone())
+        .fetch_add(get_connection_measure.as_ms(), Ordering::Relaxed);
+    (connection, connection_cache_stats)
 }
 
 // TODO: see https://github.com/solana-labs/solana/issues/23851
@@ -417,8 +485,6 @@ mod tests {
         // we can actually connect to those addresses - TPUConnection implementations should either
         // be lazy and not connect until first use or handle connection errors somehow
        // (without crashing, as would be required in a real practical validator)
-        let first_addr = get_addr(&mut rng);
-        assert!(ip(get_connection(&first_addr).0) == first_addr.ip());
         let addrs = (0..MAX_CONNECTIONS)
             .into_iter()
             .map(|_| {
@@ -428,26 +494,19 @@ mod tests {
             })
             .collect::<Vec<_>>();
         {
-            let map = (*CONNECTION_MAP).lock().unwrap();
+            let map = (*CONNECTION_MAP).read().unwrap();
+            assert!(map.map.len() == MAX_CONNECTIONS);
             addrs.iter().for_each(|a| {
-                let conn = map.map.peek(a).expect("Address not found");
+                let conn = map.map.get(a).expect("Address not found");
                 assert!(a.ip() == ip(conn.clone()));
             });
-
-            assert!(map.map.peek(&first_addr).is_none());
         }
 
-        // Test that get_connection updates which connection is next up for eviction
-        // when an existing connection is used. Initially, addrs[0] should be next up for eviction, since
-        // it was the earliest added. But we do get_connection(&addrs[0]), thereby using
-        // that connection, and bumping it back to the end of the queue. So addrs[1] should be
-        // the next up for eviction. So we add a new connection, and test that addrs[0] is not
-        // evicted but addrs[1] is.
-        get_connection(&addrs[0]);
-        get_connection(&get_addr(&mut rng));
+        let addr = get_addr(&mut rng);
+        get_connection(&addr);
 
-        let map = (*CONNECTION_MAP).lock().unwrap();
-        assert!(map.map.peek(&addrs[0]).is_some());
-        assert!(map.map.peek(&addrs[1]).is_none());
+        let map = (*CONNECTION_MAP).read().unwrap();
+        assert!(map.map.len() == MAX_CONNECTIONS);
+        let _conn = map.map.get(&addr).expect("Address not found");
     }
 }
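
Note (not part of the patch): a minimal, standalone sketch of the locking pattern the diff introduces, shown outside the surrounding connection-cache code. The names CACHE, MAX_ENTRIES, and get_or_add are illustrative only and do not appear in the patch; the two-argument gen_range call follows the rand 0.7 API that the patch itself uses.

use {
    lazy_static::lazy_static,
    rand::{thread_rng, Rng},
    std::{collections::BTreeMap, net::SocketAddr, sync::RwLock},
};

// Illustrative bound; the patch uses MAX_CONNECTIONS.
const MAX_ENTRIES: usize = 1024;

lazy_static! {
    static ref CACHE: RwLock<BTreeMap<SocketAddr, String>> = RwLock::new(BTreeMap::new());
}

fn get_or_add(addr: &SocketAddr) -> String {
    // Fast path: a shared read lock lets many lookups proceed concurrently.
    {
        let cache = CACHE.read().unwrap();
        if let Some(value) = cache.get(addr) {
            return value.clone();
        }
    } // read guard dropped here, mirroring `drop(map)` in the patch

    // Slow path: re-acquire the lock exclusively to insert. Another thread may
    // have inserted the same key in the meantime; the extra insert is tolerated.
    let mut cache = CACHE.write().unwrap();

    // Evict a randomly chosen entry while the map is at its upper bound,
    // as the patch does when MAX_CONNECTIONS is reached.
    while cache.len() >= MAX_ENTRIES {
        let n = thread_rng().gen_range(0, cache.len());
        if let Some(victim) = cache.keys().nth(n).copied() {
            cache.remove(&victim);
        }
    }

    cache.entry(*addr).or_insert_with(|| addr.to_string()).clone()
}

The trade-off illustrated here is the same one the patch accepts: cache hits never block each other, while misses pay for a second lock acquisition and must re-check or tolerate concurrent inserts after the read guard is released.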