separates out routing repair requests from establishing connections (#33742)

Currently each outgoing repair request will attempt to establish a
connection if one does not already exist. This is very wasteful and
consumes many tokio tasks if the remote node is down or unresponsive.

The commit decouples routing packets from establishing connections by
adding a buffering channel for each remote address. Outgoing packets are
always sent down this channel to be processed once the connection is
established. If connecting attempt fails, all packets already pushed to
the channel are dropped at once, reducing the number of attempts to make
a connection if the remote node is down or unresponsive.
This commit is contained in:
behzad nouri 2023-10-19 13:25:53 +00:00 committed by GitHub
parent c1353e172c
commit 7aa0faea96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 168 additions and 112 deletions

View File

@ -21,16 +21,15 @@ use {
collections::{hash_map::Entry, HashMap},
io::{Cursor, Error as IoError},
net::{IpAddr, SocketAddr, UdpSocket},
ops::Deref,
sync::Arc,
time::Duration,
},
thiserror::Error,
tokio::{
sync::{
mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender},
mpsc::{error::TrySendError, Receiver as AsyncReceiver, Sender as AsyncSender},
oneshot::Sender as OneShotSender,
RwLock,
Mutex, RwLock as AsyncRwLock,
},
task::JoinHandle,
},
@ -39,7 +38,8 @@ use {
const ALPN_REPAIR_PROTOCOL_ID: &[u8] = b"solana-repair";
const CONNECT_SERVER_NAME: &str = "solana-repair";
const CLIENT_CHANNEL_CAPACITY: usize = 1 << 14;
const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const CONNECTION_CACHE_CAPACITY: usize = 4096;
const MAX_CONCURRENT_BIDI_STREAMS: VarInt = VarInt::from_u32(512);
@ -54,7 +54,6 @@ const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
pub(crate) type AsyncTryJoinHandle = TryJoin<JoinHandle<()>, JoinHandle<()>>;
type ConnectionCache = HashMap<(SocketAddr, Option<Pubkey>), Arc<RwLock<Option<Connection>>>>;
// Outgoing local requests.
pub struct LocalRequest {
@ -125,17 +124,20 @@ pub(crate) fn new_quic_endpoint(
)?
};
endpoint.set_default_client_config(client_config);
let cache = Arc::<RwLock<ConnectionCache>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_CAPACITY);
let cache = Arc::<Mutex<HashMap<Pubkey, Connection>>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
let router = Arc::<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>::default();
let server_task = runtime.spawn(run_server(
endpoint.clone(),
remote_request_sender.clone(),
router.clone(),
cache.clone(),
));
let client_task = runtime.spawn(run_client(
endpoint.clone(),
client_receiver,
remote_request_sender,
router,
cache,
));
let task = futures::future::try_join(server_task, client_task);
@ -187,13 +189,15 @@ fn new_transport_config() -> TransportConfig {
async fn run_server(
endpoint: Endpoint,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
while let Some(connecting) = endpoint.accept().await {
tokio::task::spawn(handle_connecting_error(
endpoint.clone(),
connecting,
remote_request_sender.clone(),
router.clone(),
cache.clone(),
));
}
@ -203,26 +207,68 @@ async fn run_client(
endpoint: Endpoint,
mut receiver: AsyncReceiver<LocalRequest>,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
while let Some(request) = receiver.recv().await {
tokio::task::spawn(send_request_task(
let Some(request) = try_route_request(request, &*router.read().await) else {
continue;
};
let remote_address = request.remote_address;
let receiver = {
let mut router = router.write().await;
let Some(request) = try_route_request(request, &router) else {
continue;
};
let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
sender.try_send(request).unwrap();
router.insert(remote_address, sender);
receiver
};
tokio::task::spawn(make_connection_task(
endpoint.clone(),
request,
remote_address,
remote_request_sender.clone(),
receiver,
router.clone(),
cache.clone(),
));
}
close_quic_endpoint(&endpoint);
// Drop sender channels to unblock threads waiting on the receiving end.
router.write().await.clear();
}
// Routes the local request to respective channel. Drops the request if the
// channel is full. Bounces the request back if the channel is closed or does
// not exist.
fn try_route_request(
request: LocalRequest,
router: &HashMap<SocketAddr, AsyncSender<LocalRequest>>,
) -> Option<LocalRequest> {
match router.get(&request.remote_address) {
None => Some(request),
Some(sender) => match sender.try_send(request) {
Ok(()) => None,
Err(TrySendError::Full(request)) => {
error!("TrySendError::Full {}", request.remote_address);
None
}
Err(TrySendError::Closed(request)) => Some(request),
},
}
}
async fn handle_connecting_error(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) = handle_connecting(endpoint, connecting, remote_request_sender, cache).await {
if let Err(err) =
handle_connecting(endpoint, connecting, remote_request_sender, router, cache).await
{
error!("handle_connecting: {err:?}");
}
}
@ -231,52 +277,75 @@ async fn handle_connecting(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
let connection = connecting.await?;
let remote_address = connection.remote_address();
let remote_pubkey = get_remote_pubkey(&connection)?;
handle_connection_error(
let receiver = {
let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
router.write().await.insert(remote_address, sender);
receiver
};
handle_connection(
endpoint,
remote_address,
remote_pubkey,
connection,
remote_request_sender,
receiver,
router,
cache,
)
.await;
Ok(())
}
async fn handle_connection_error(
async fn handle_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
receiver: AsyncReceiver<LocalRequest>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
cache_connection(remote_address, remote_pubkey, connection.clone(), &cache).await;
if let Err(err) = handle_connection(
&endpoint,
cache_connection(remote_pubkey, connection.clone(), &cache).await;
let send_requests_task = tokio::task::spawn(send_requests_task(
endpoint.clone(),
connection.clone(),
receiver,
));
let recv_requests_task = tokio::task::spawn(recv_requests_task(
endpoint,
remote_address,
remote_pubkey,
&connection,
&remote_request_sender,
)
.await
{
drop_connection(remote_address, remote_pubkey, &connection, &cache).await;
error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}");
connection.clone(),
remote_request_sender,
));
match futures::future::try_join(send_requests_task, recv_requests_task).await {
Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"),
Ok(((), Err(ref err))) => {
error!("recv_requests_task: {remote_pubkey}, {remote_address}, {err:?}");
}
Ok(((), Ok(()))) => (),
}
drop_connection(remote_pubkey, &connection, &cache).await;
if let Entry::Occupied(entry) = router.write().await.entry(remote_address) {
if entry.get().is_closed() {
entry.remove();
}
}
}
async fn handle_connection(
endpoint: &Endpoint,
async fn recv_requests_task(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: &Connection,
remote_request_sender: &Sender<RemoteRequest>,
connection: Connection,
remote_request_sender: Sender<RemoteRequest>,
) -> Result<(), Error> {
loop {
let (send_stream, recv_stream) = connection.accept_bi().await?;
@ -352,32 +421,39 @@ async fn handle_streams(
send_stream.finish().await.map_err(Error::from)
}
async fn send_request_task(
async fn send_requests_task(
endpoint: Endpoint,
request: LocalRequest,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
connection: Connection,
mut receiver: AsyncReceiver<LocalRequest>,
) {
if let Err(err) = send_request(&endpoint, request, remote_request_sender, cache).await {
error!("send_request_task: {err:?}");
while let Some(request) = receiver.recv().await {
tokio::task::spawn(send_request_task(
endpoint.clone(),
connection.clone(),
request,
));
}
}
async fn send_request_task(endpoint: Endpoint, connection: Connection, request: LocalRequest) {
if let Err(err) = send_request(endpoint, connection, request).await {
error!("send_request: {err:?}")
}
}
async fn send_request(
endpoint: &Endpoint,
endpoint: Endpoint,
connection: Connection,
LocalRequest {
remote_address,
remote_address: _,
bytes,
num_expected_responses,
response_sender,
}: LocalRequest,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
) -> Result<(), Error> {
// Assert that send won't block.
debug_assert_eq!(response_sender.capacity(), None);
const READ_TIMEOUT_DURATION: Duration = Duration::from_secs(10);
let connection = get_connection(endpoint, remote_address, remote_request_sender, cache).await?;
let (mut send_stream, mut recv_stream) = connection.open_bi().await?;
send_stream.write_all(&bytes).await?;
send_stream.finish().await?;
@ -405,50 +481,57 @@ async fn send_request(
response_sender
.send((remote_address, chunk))
.map_err(|err| {
close_quic_endpoint(endpoint);
close_quic_endpoint(&endpoint);
Error::from(err)
})
})
}
async fn get_connection(
endpoint: &Endpoint,
async fn make_connection_task(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
cache: Arc<RwLock<ConnectionCache>>,
) -> Result<Connection, Error> {
let entry = get_cache_entry(remote_address, &cache).await;
receiver: AsyncReceiver<LocalRequest>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) = make_connection(
endpoint,
remote_address,
remote_request_sender,
receiver,
router,
cache,
)
.await
{
let connection: Option<Connection> = entry.read().await.clone();
if let Some(connection) = connection {
if connection.close_reason().is_none() {
return Ok(connection);
}
}
error!("make_connection: {remote_address}, {err:?}");
}
let connection = {
// Need to write lock here so that only one task initiates
// a new connection to the same remote_address.
let mut entry = entry.write().await;
if let Some(connection) = entry.deref() {
if connection.close_reason().is_none() {
return Ok(connection.clone());
}
}
let connection = endpoint
.connect(remote_address, CONNECT_SERVER_NAME)?
.await?;
entry.insert(connection).clone()
};
tokio::task::spawn(handle_connection_error(
endpoint.clone(),
}
async fn make_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
let connection = endpoint
.connect(remote_address, CONNECT_SERVER_NAME)?
.await?;
handle_connection(
endpoint,
connection.remote_address(),
get_remote_pubkey(&connection)?,
connection.clone(),
connection,
remote_request_sender,
receiver,
router,
cache,
));
Ok(connection)
)
.await;
Ok(())
}
fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
@ -464,27 +547,13 @@ fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
}
}
async fn get_cache_entry(
remote_address: SocketAddr,
cache: &RwLock<ConnectionCache>,
) -> Arc<RwLock<Option<Connection>>> {
let key = (remote_address, /*remote_pubkey:*/ None);
if let Some(entry) = cache.read().await.get(&key) {
return entry.clone();
}
cache.write().await.entry(key).or_default().clone()
}
async fn cache_connection(
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
cache: &RwLock<ConnectionCache>,
cache: &Mutex<HashMap<Pubkey, Connection>>,
) {
// The 2nd cache entry with remote_pubkey == None allows to lookup an entry
// only by SocketAddr when establishing outgoing connections.
let entries: [Arc<RwLock<Option<Connection>>>; 2] = {
let mut cache = cache.write().await;
let old = {
let mut cache = cache.lock().await;
if cache.len() >= CONNECTION_CACHE_CAPACITY {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_DROPPED,
@ -492,15 +561,9 @@ async fn cache_connection(
);
return;
}
[Some(remote_pubkey), None].map(|remote_pubkey| {
let key = (remote_address, remote_pubkey);
cache.entry(key).or_default().clone()
})
cache.insert(remote_pubkey, connection)
};
let mut entry = entries[0].write().await;
*entries[1].write().await = Some(connection.clone());
if let Some(old) = entry.replace(connection) {
drop(entry);
if let Some(old) = old {
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
@ -509,26 +572,19 @@ async fn cache_connection(
}
async fn drop_connection(
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: &Connection,
cache: &RwLock<ConnectionCache>,
cache: &Mutex<HashMap<Pubkey, Connection>>,
) {
if connection.close_reason().is_none() {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_DROPPED,
CONNECTION_CLOSE_REASON_DROPPED,
);
}
let key = (remote_address, Some(remote_pubkey));
if let Entry::Occupied(entry) = cache.write().await.entry(key) {
if matches!(entry.get().read().await.deref(),
Some(entry) if entry.stable_id() == connection.stable_id())
{
connection.close(
CONNECTION_CLOSE_ERROR_CODE_DROPPED,
CONNECTION_CLOSE_REASON_DROPPED,
);
if let Entry::Occupied(entry) = cache.lock().await.entry(remote_pubkey) {
if entry.get().stable_id() == connection.stable_id() {
entry.remove();
}
}
// Cache entry for (remote_address, None) will be lazily evicted.
}
impl<T> From<crossbeam_channel::SendError<T>> for Error {