Use RwLock instead of Mutex in QUIC connection cache (#24857)

* Use RwLock instead of Mutex in QUIC connection cache

* replace LruCache with HashMap

* fix tests

* fix tests

* refactor

* add cache eviction of a random connection on reaching the upper bound

* cleanup
This commit is contained in:
Pankaj Garg 2022-05-02 13:02:49 -07:00 committed by GitHub
parent fd46c69a17
commit de027a895b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 119 additions and 60 deletions

View File

@ -5,17 +5,19 @@ use {
udp_client::UdpTpuConnection,
},
lazy_static::lazy_static,
lru::LruCache,
quinn_proto::ConnectionStats,
rand::{thread_rng, Rng},
solana_measure::measure::Measure,
solana_net_utils::VALIDATOR_PORT_RANGE,
solana_sdk::{
timing::AtomicInterval, transaction::VersionedTransaction, transport::TransportError,
},
std::{
collections::BTreeMap,
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
Arc, RwLock,
},
},
};
@ -158,7 +160,7 @@ impl ConnectionCacheStats {
}
struct ConnectionMap {
map: LruCache<SocketAddr, Connection>,
map: BTreeMap<SocketAddr, Connection>,
stats: Arc<ConnectionCacheStats>,
last_stats: AtomicInterval,
use_quic: bool,
@ -167,7 +169,7 @@ struct ConnectionMap {
impl ConnectionMap {
pub fn new() -> Self {
Self {
map: LruCache::new(MAX_CONNECTIONS),
map: BTreeMap::new(),
stats: Arc::new(ConnectionCacheStats::default()),
last_stats: AtomicInterval::default(),
use_quic: false,
@ -180,38 +182,44 @@ impl ConnectionMap {
}
lazy_static! {
static ref CONNECTION_MAP: Mutex<ConnectionMap> = Mutex::new(ConnectionMap::new());
static ref CONNECTION_MAP: RwLock<ConnectionMap> = RwLock::new(ConnectionMap::new());
}
pub fn set_use_quic(use_quic: bool) {
let mut map = (*CONNECTION_MAP).lock().unwrap();
let mut map = (*CONNECTION_MAP).write().unwrap();
map.set_use_quic(use_quic);
}
// TODO: see https://github.com/solana-labs/solana/issues/23661
// remove lazy_static and optimize and refactor this
fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>) {
let mut get_connection_measure = Measure::start("get_connection_measure");
/// Everything `get_or_add_connection` hands back to `get_connection`: the
/// connection itself plus the data needed to update the cache statistics
/// after the map lock guard taken in `get_or_add_connection` has been dropped.
struct GetConnectionResult {
/// Connection for the requested address: a clone of the cached entry, or the
/// entry that was just inserted into the map.
connection: Connection,
/// `true` if the address was already present in the map, `false` if a new
/// entry had to be inserted.
cache_hit: bool,
/// Result of `last_stats.should_update(CONNECTION_STAT_SUBMISSION_INTERVAL)`;
/// when `true` the caller invokes `report()` on the cache stats.
report_stats: bool,
/// Duration of the lookup/insert section as reported by `Measure::as_ms`.
/// On a miss this span also covers the write-lock wait and any eviction.
map_timing: u64,
/// Time spent acquiring the map lock (`Measure::as_ms`): the read lock, plus
/// the write lock on a miss, combined with a saturating add.
lock_timing: u64,
/// Clone of the map's shared `ConnectionCacheStats` handle.
connection_cache_stats: Arc<ConnectionCacheStats>,
/// On a cache hit for a `Connection::Quic` whose `stats()` returns `Some`:
/// the pair `(conn.base_stats(), stats)`. `None` otherwise, including on
/// every cache miss.
other_stats: Option<(Arc<ClientStats>, ConnectionStats)>,
}
fn get_or_add_connection(addr: &SocketAddr) -> GetConnectionResult {
let mut get_connection_map_lock_measure = Measure::start("get_connection_map_lock_measure");
let mut map = (*CONNECTION_MAP).lock().unwrap();
let map = (*CONNECTION_MAP).read().unwrap();
get_connection_map_lock_measure.stop();
if map
let mut lock_timing = get_connection_map_lock_measure.as_ms();
let report_stats = map
.last_stats
.should_update(CONNECTION_STAT_SUBMISSION_INTERVAL)
{
map.stats.report();
}
.should_update(CONNECTION_STAT_SUBMISSION_INTERVAL);
let mut get_connection_map_measure = Measure::start("get_connection_hit_measure");
let (connection, hit, maybe_stats) = match map.map.get(addr) {
let (connection, cache_hit, connection_cache_stats, maybe_stats) = match map.map.get(addr) {
Some(connection) => {
let mut stats = None;
// update connection stats
if let Connection::Quic(conn) = connection {
stats = conn.stats().map(|s| (conn.base_stats(), s));
}
(connection.clone(), true, stats)
(connection.clone(), true, map.stats.clone(), stats)
}
None => {
let (_, send_socket) = solana_net_utils::bind_in_range(
@ -219,25 +227,77 @@ fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>)
VALIDATOR_PORT_RANGE,
)
.unwrap();
let connection = if map.use_quic {
Connection::Quic(Arc::new(QuicTpuConnection::new(send_socket, *addr)))
} else {
Connection::Udp(Arc::new(UdpTpuConnection::new(send_socket, *addr)))
};
map.map.put(*addr, connection.clone());
(connection, false, None)
// Upgrade to write access by dropping the read lock and then acquiring the write lock
drop(map);
let mut get_connection_map_lock_measure =
Measure::start("get_connection_map_lock_measure");
let mut map = (*CONNECTION_MAP).write().unwrap();
get_connection_map_lock_measure.stop();
lock_timing = lock_timing.saturating_add(get_connection_map_lock_measure.as_ms());
// evict a connection if the map is reaching upper bounds
while map.map.len() >= MAX_CONNECTIONS {
let mut rng = thread_rng();
let n = rng.gen_range(0, MAX_CONNECTIONS);
if let Some((nth_addr, _)) = map.map.iter().nth(n) {
let nth_addr = *nth_addr;
map.map.remove(&nth_addr);
}
}
map.map.insert(*addr, connection.clone());
(connection, false, map.stats.clone(), None)
}
};
get_connection_map_measure.stop();
if let Some((connection_stats, new_stats)) = maybe_stats {
map.stats.total_client_stats.congestion_events.update_stat(
&connection_stats.congestion_events,
new_stats.path.congestion_events,
);
GetConnectionResult {
connection,
cache_hit,
report_stats,
map_timing: get_connection_map_measure.as_ms(),
lock_timing,
connection_cache_stats,
other_stats: maybe_stats,
}
}
map.stats
// TODO: see https://github.com/solana-labs/solana/issues/23661
// remove lazy_static and optimize and refactor this
fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>) {
let mut get_connection_measure = Measure::start("get_connection_measure");
let GetConnectionResult {
connection,
cache_hit,
report_stats,
map_timing,
lock_timing,
connection_cache_stats,
other_stats,
} = get_or_add_connection(addr);
if report_stats {
connection_cache_stats.report();
}
if let Some((connection_stats, new_stats)) = other_stats {
connection_cache_stats
.total_client_stats
.congestion_events
.update_stat(
&connection_stats.congestion_events,
new_stats.path.congestion_events,
);
connection_cache_stats
.total_client_stats
.tx_streams_blocked_uni
.update_stat(
@ -245,36 +305,44 @@ fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>)
new_stats.frame_tx.streams_blocked_uni,
);
map.stats.total_client_stats.tx_data_blocked.update_stat(
&connection_stats.tx_data_blocked,
new_stats.frame_tx.data_blocked,
);
connection_cache_stats
.total_client_stats
.tx_data_blocked
.update_stat(
&connection_stats.tx_data_blocked,
new_stats.frame_tx.data_blocked,
);
map.stats
connection_cache_stats
.total_client_stats
.tx_acks
.update_stat(&connection_stats.tx_acks, new_stats.frame_tx.acks);
}
if hit {
map.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
map.stats
if cache_hit {
connection_cache_stats
.cache_hits
.fetch_add(1, Ordering::Relaxed);
connection_cache_stats
.get_connection_hit_ms
.fetch_add(get_connection_map_measure.as_us(), Ordering::Relaxed);
.fetch_add(map_timing, Ordering::Relaxed);
} else {
map.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
map.stats
connection_cache_stats
.cache_misses
.fetch_add(1, Ordering::Relaxed);
connection_cache_stats
.get_connection_miss_ms
.fetch_add(get_connection_map_measure.as_us(), Ordering::Relaxed);
.fetch_add(map_timing, Ordering::Relaxed);
}
get_connection_measure.stop();
map.stats
connection_cache_stats
.get_connection_lock_ms
.fetch_add(get_connection_map_lock_measure.as_us(), Ordering::Relaxed);
map.stats
.fetch_add(lock_timing, Ordering::Relaxed);
connection_cache_stats
.get_connection_ms
.fetch_add(get_connection_measure.as_us(), Ordering::Relaxed);
(connection, map.stats.clone())
.fetch_add(get_connection_measure.as_ms(), Ordering::Relaxed);
(connection, connection_cache_stats)
}
// TODO: see https://github.com/solana-labs/solana/issues/23851
@ -417,8 +485,6 @@ mod tests {
// we can actually connect to those addresses - TPUConnection implementations should either
// be lazy and not connect until first use or handle connection errors somehow
// (without crashing, as would be required in a real practical validator)
let first_addr = get_addr(&mut rng);
assert!(ip(get_connection(&first_addr).0) == first_addr.ip());
let addrs = (0..MAX_CONNECTIONS)
.into_iter()
.map(|_| {
@ -428,26 +494,19 @@ mod tests {
})
.collect::<Vec<_>>();
{
let map = (*CONNECTION_MAP).lock().unwrap();
let map = (*CONNECTION_MAP).read().unwrap();
assert!(map.map.len() == MAX_CONNECTIONS);
addrs.iter().for_each(|a| {
let conn = map.map.peek(a).expect("Address not found");
let conn = map.map.get(a).expect("Address not found");
assert!(a.ip() == ip(conn.clone()));
});
assert!(map.map.peek(&first_addr).is_none());
}
// Test that get_connection updates which connection is next up for eviction
// when an existing connection is used. Initially, addrs[0] should be next up for eviction, since
// it was the earliest added. But we do get_connection(&addrs[0]), thereby using
// that connection, and bumping it back to the end of the queue. So addrs[1] should be
// the next up for eviction. So we add a new connection, and test that addrs[0] is not
// evicted but addrs[1] is.
get_connection(&addrs[0]);
get_connection(&get_addr(&mut rng));
let addr = get_addr(&mut rng);
get_connection(&addr);
let map = (*CONNECTION_MAP).lock().unwrap();
assert!(map.map.peek(&addrs[0]).is_some());
assert!(map.map.peek(&addrs[1]).is_none());
let map = (*CONNECTION_MAP).read().unwrap();
assert!(map.map.len() == MAX_CONNECTIONS);
let _conn = map.map.get(&addr).expect("Address not found");
}
}