use {
    bincode::Options,
    crossbeam_channel::Sender,
    futures::future::TryJoin,
    itertools::Itertools,
    log::error,
    quinn::{
        ClientConfig, ConnectError, Connecting, Connection, ConnectionError, Endpoint,
        EndpointConfig, ReadToEndError, RecvStream, SendStream, ServerConfig, TokioRuntime,
        TransportConfig, VarInt, WriteError,
    },
    rcgen::RcgenError,
    rustls::{Certificate, PrivateKey},
    serde_bytes::ByteBuf,
    solana_quic_client::nonblocking::quic_client::SkipServerVerification,
    solana_runtime::bank_forks::BankForks,
    solana_sdk::{packet::PACKET_DATA_SIZE, pubkey::Pubkey, signature::Keypair},
    solana_streamer::{
        quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate,
    },
    std::{
        cmp::Reverse,
        collections::{hash_map::Entry, HashMap},
        io::{Cursor, Error as IoError},
        net::{IpAddr, SocketAddr, UdpSocket},
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc, RwLock,
        },
        time::Duration,
    },
    thiserror::Error,
    tokio::{
        sync::{
            mpsc::{error::TrySendError, Receiver as AsyncReceiver, Sender as AsyncSender},
            oneshot::Sender as OneShotSender,
            Mutex, RwLock as AsyncRwLock,
        },
        task::JoinHandle,
    },
};

const ALPN_REPAIR_PROTOCOL_ID: &[u8] = b"solana-repair";
const CONNECT_SERVER_NAME: &str = "solana-repair";

const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const CONNECTION_CACHE_CAPACITY: usize = 3072;
const MAX_CONCURRENT_BIDI_STREAMS: VarInt = VarInt::from_u32(512);

const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1);
const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2);
const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3);
const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4);
const CONNECTION_CLOSE_ERROR_CODE_PRUNED: VarInt = VarInt::from_u32(5);

const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN";
const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED";
const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
const CONNECTION_CLOSE_REASON_PRUNED: &[u8] = b"PRUNED";

pub(crate) type AsyncTryJoinHandle = TryJoin<JoinHandle<()>, JoinHandle<()>>;

// Outgoing local requests.
pub struct LocalRequest {
    pub(crate) remote_address: SocketAddr,
    pub(crate) bytes: Vec<u8>,
    pub(crate) num_expected_responses: usize,
    pub(crate) response_sender: Sender<(SocketAddr, Vec<u8>)>,
}

// Incoming requests from remote nodes.
// remote_pubkey and response_sender are None only when adapting UDP packets.
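// Response chunks sent over response_sender are written back to the peer on
// the same bidirectional stream, each prefixed with its little-endian length.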
pub struct RemoteRequest {
    pub(crate) remote_pubkey: Option<Pubkey>,
    pub(crate) remote_address: SocketAddr,
    pub(crate) bytes: Vec<u8>,
    pub(crate) response_sender: Option<OneShotSender<Vec<Vec<u8>>>>,
}

#[derive(Error, Debug)]
#[allow(clippy::enum_variant_names)]
pub(crate) enum Error {
    #[error(transparent)]
    BincodeError(#[from] bincode::Error),
    #[error(transparent)]
    CertificateError(#[from] RcgenError),
    #[error(transparent)]
    ConnectError(#[from] ConnectError),
    #[error(transparent)]
    ConnectionError(#[from] ConnectionError),
    #[error("Channel Send Error")]
    ChannelSendError,
    #[error("Invalid Identity: {0:?}")]
    InvalidIdentity(SocketAddr),
    #[error(transparent)]
    IoError(#[from] IoError),
    #[error("No Response Received")]
    NoResponseReceived,
    #[error(transparent)]
    ReadToEndError(#[from] ReadToEndError),
    #[error("read_to_end Timeout")]
    ReadToEndTimeout,
    #[error(transparent)]
    WriteError(#[from] WriteError),
    #[error(transparent)]
    TlsError(#[from] rustls::Error),
}

#[allow(clippy::type_complexity)]
pub(crate) fn new_quic_endpoint(
    runtime: &tokio::runtime::Handle,
    keypair: &Keypair,
    socket: UdpSocket,
    address: IpAddr,
    remote_request_sender: Sender<RemoteRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
) -> Result<(Endpoint, AsyncSender<LocalRequest>, AsyncTryJoinHandle), Error> {
    let (cert, key) = new_self_signed_tls_certificate(keypair, address)?;
    let server_config = new_server_config(cert.clone(), key.clone())?;
    let client_config = new_client_config(cert, key)?;
    let mut endpoint = {
        // Endpoint::new requires entering the runtime context,
        // otherwise the code below will panic.
        let _guard = runtime.enter();
        Endpoint::new(
            EndpointConfig::default(),
            Some(server_config),
            socket,
            Arc::new(TokioRuntime),
        )?
    };
    endpoint.set_default_client_config(client_config);
    let prune_cache_pending = Arc::<AtomicBool>::default();
    let cache = Arc::<Mutex<HashMap<Pubkey, Connection>>>::default();
    let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
    let router = Arc::<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>::default();
    let server_task = runtime.spawn(run_server(
        endpoint.clone(),
        remote_request_sender.clone(),
        bank_forks.clone(),
        prune_cache_pending.clone(),
        router.clone(),
        cache.clone(),
    ));
    let client_task = runtime.spawn(run_client(
        endpoint.clone(),
        client_receiver,
        remote_request_sender,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
    ));
    let task = futures::future::try_join(server_task, client_task);
    Ok((endpoint, client_sender, task))
}

pub(crate) fn close_quic_endpoint(endpoint: &Endpoint) {
    endpoint.close(
        CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN,
        CONNECTION_CLOSE_REASON_SHUTDOWN,
    );
}

fn new_server_config(cert: Certificate, key: PrivateKey) -> Result<ServerConfig, rustls::Error> {
    let mut config = rustls::ServerConfig::builder()
        .with_safe_defaults()
        .with_client_cert_verifier(Arc::new(SkipClientVerification {}))
        .with_single_cert(vec![cert], key)?;
    config.alpn_protocols = vec![ALPN_REPAIR_PROTOCOL_ID.to_vec()];
    let mut config = ServerConfig::with_crypto(Arc::new(config));
    config
        .transport_config(Arc::new(new_transport_config()))
        .use_retry(true)
        .migration(false);
    Ok(config)
}

fn new_client_config(cert: Certificate, key: PrivateKey) -> Result<ClientConfig, rustls::Error> {
    let mut config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_custom_certificate_verifier(Arc::new(SkipServerVerification {}))
        .with_client_auth_cert(vec![cert], key)?;
    config.enable_early_data = true;
    config.alpn_protocols = vec![ALPN_REPAIR_PROTOCOL_ID.to_vec()];
    let mut config = ClientConfig::new(Arc::new(config));
    config.transport_config(Arc::new(new_transport_config()));
    Ok(config)
}

fn new_transport_config() -> TransportConfig {
    let mut config = TransportConfig::default();
    config
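        // Repair is strictly request/response over bidirectional streams, so
        // unidirectional streams are disabled and no datagram receive buffer
        // is allocated.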
        .max_concurrent_bidi_streams(MAX_CONCURRENT_BIDI_STREAMS)
        .max_concurrent_uni_streams(VarInt::from(0u8))
        .datagram_receive_buffer_size(None);
    config
}

async fn run_server(
    endpoint: Endpoint,
    remote_request_sender: Sender<RemoteRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    while let Some(connecting) = endpoint.accept().await {
        tokio::task::spawn(handle_connecting_error(
            endpoint.clone(),
            connecting,
            remote_request_sender.clone(),
            bank_forks.clone(),
            prune_cache_pending.clone(),
            router.clone(),
            cache.clone(),
        ));
    }
}

async fn run_client(
    endpoint: Endpoint,
    mut receiver: AsyncReceiver<LocalRequest>,
    remote_request_sender: Sender<RemoteRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    while let Some(request) = receiver.recv().await {
        let Some(request) = try_route_request(request, &*router.read().await) else {
            continue;
        };
        let remote_address = request.remote_address;
        let receiver = {
            let mut router = router.write().await;
            let Some(request) = try_route_request(request, &router) else {
                continue;
            };
            let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
            sender.try_send(request).unwrap();
            router.insert(remote_address, sender);
            receiver
        };
        tokio::task::spawn(make_connection_task(
            endpoint.clone(),
            remote_address,
            remote_request_sender.clone(),
            receiver,
            bank_forks.clone(),
            prune_cache_pending.clone(),
            router.clone(),
            cache.clone(),
        ));
    }
    close_quic_endpoint(&endpoint);
    // Drop sender channels to unblock threads waiting on the receiving end.
    router.write().await.clear();
}

// Routes the local request to the respective channel. Drops the request if
// the channel is full. Bounces the request back if the channel is closed or
// does not exist.
fn try_route_request(
    request: LocalRequest,
    router: &HashMap<SocketAddr, AsyncSender<LocalRequest>>,
) -> Option<LocalRequest> {
    match router.get(&request.remote_address) {
        None => Some(request),
        Some(sender) => match sender.try_send(request) {
            Ok(()) => None,
            Err(TrySendError::Full(request)) => {
                error!("TrySendError::Full {}", request.remote_address);
                None
            }
            Err(TrySendError::Closed(request)) => Some(request),
        },
    }
}

async fn handle_connecting_error(
    endpoint: Endpoint,
    connecting: Connecting,
    remote_request_sender: Sender<RemoteRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    if let Err(err) = handle_connecting(
        endpoint,
        connecting,
        remote_request_sender,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
    )
    .await
    {
        error!("handle_connecting: {err:?}");
    }
}

async fn handle_connecting(
    endpoint: Endpoint,
    connecting: Connecting,
    remote_request_sender: Sender<RemoteRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
    let connection = connecting.await?;
    let remote_address = connection.remote_address();
    let remote_pubkey = get_remote_pubkey(&connection)?;
    let receiver = {
        let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
        router.write().await.insert(remote_address, sender);
        receiver
    };
    handle_connection(
        endpoint,
        remote_address,
        remote_pubkey,
        connection,
        remote_request_sender,
        receiver,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
    )
    .await;
    Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn handle_connection(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    remote_pubkey: Pubkey,
    connection: Connection,
    remote_request_sender: Sender<RemoteRequest>,
    receiver: AsyncReceiver<LocalRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    cache_connection(
        remote_pubkey,
        connection.clone(),
        bank_forks,
        prune_cache_pending,
        router.clone(),
        cache.clone(),
    )
    .await;
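    // Run both halves of the connection concurrently: one task drains queued
    // local requests and sends them to the peer, the other serves the peer's
    // incoming requests.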
    let send_requests_task = tokio::task::spawn(send_requests_task(
        endpoint.clone(),
        connection.clone(),
        receiver,
    ));
    let recv_requests_task = tokio::task::spawn(recv_requests_task(
        endpoint,
        remote_address,
        remote_pubkey,
        connection.clone(),
        remote_request_sender,
    ));
    match futures::future::try_join(send_requests_task, recv_requests_task).await {
        Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"),
        Ok(((), Err(ref err))) => {
            error!("recv_requests_task: {remote_pubkey}, {remote_address}, {err:?}");
        }
        Ok(((), Ok(()))) => (),
    }
    drop_connection(remote_pubkey, &connection, &cache).await;
    if let Entry::Occupied(entry) = router.write().await.entry(remote_address) {
        if entry.get().is_closed() {
            entry.remove();
        }
    }
}

async fn recv_requests_task(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    remote_pubkey: Pubkey,
    connection: Connection,
    remote_request_sender: Sender<RemoteRequest>,
) -> Result<(), Error> {
    loop {
        let (send_stream, recv_stream) = connection.accept_bi().await?;
        tokio::task::spawn(handle_streams_task(
            endpoint.clone(),
            remote_address,
            remote_pubkey,
            send_stream,
            recv_stream,
            remote_request_sender.clone(),
        ));
    }
}

async fn handle_streams_task(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    remote_pubkey: Pubkey,
    send_stream: SendStream,
    recv_stream: RecvStream,
    remote_request_sender: Sender<RemoteRequest>,
) {
    if let Err(err) = handle_streams(
        &endpoint,
        remote_address,
        remote_pubkey,
        send_stream,
        recv_stream,
        &remote_request_sender,
    )
    .await
    {
        error!("handle_streams: {remote_address}, {remote_pubkey}, {err:?}");
    }
}

async fn handle_streams(
    endpoint: &Endpoint,
    remote_address: SocketAddr,
    remote_pubkey: Pubkey,
    mut send_stream: SendStream,
    mut recv_stream: RecvStream,
    remote_request_sender: &Sender<RemoteRequest>,
) -> Result<(), Error> {
    // Assert that send won't block.
    debug_assert_eq!(remote_request_sender.capacity(), None);
    const READ_TIMEOUT_DURATION: Duration = Duration::from_secs(2);
    let bytes = tokio::time::timeout(
        READ_TIMEOUT_DURATION,
        recv_stream.read_to_end(PACKET_DATA_SIZE),
    )
    .await
    .map_err(|_| Error::ReadToEndTimeout)??;
    let (response_sender, response_receiver) = tokio::sync::oneshot::channel();
    let remote_request = RemoteRequest {
        remote_pubkey: Some(remote_pubkey),
        remote_address,
        bytes,
        response_sender: Some(response_sender),
    };
    if let Err(err) = remote_request_sender.send(remote_request) {
        close_quic_endpoint(endpoint);
        return Err(Error::from(err));
    }
    let Ok(response) = response_receiver.await else {
        return Err(Error::NoResponseReceived);
    };
    for chunk in response {
        let size = chunk.len() as u64;
        send_stream.write_all(&size.to_le_bytes()).await?;
        send_stream.write_all(&chunk).await?;
    }
    send_stream.finish().await.map_err(Error::from)
}

async fn send_requests_task(
    endpoint: Endpoint,
    connection: Connection,
    mut receiver: AsyncReceiver<LocalRequest>,
) {
    while let Some(request) = receiver.recv().await {
        tokio::task::spawn(send_request_task(
            endpoint.clone(),
            connection.clone(),
            request,
        ));
    }
}

async fn send_request_task(endpoint: Endpoint, connection: Connection, request: LocalRequest) {
    if let Err(err) = send_request(endpoint, connection, request).await {
        error!("send_request: {err:?}")
    }
}

async fn send_request(
    endpoint: Endpoint,
    connection: Connection,
    LocalRequest {
        remote_address: _,
        bytes,
        num_expected_responses,
        response_sender,
    }: LocalRequest,
) -> Result<(), Error> {
    // Assert that send won't block.
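    // crossbeam_channel reports capacity() == None for unbounded channels.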
    debug_assert_eq!(response_sender.capacity(), None);
    const READ_TIMEOUT_DURATION: Duration = Duration::from_secs(10);
    let (mut send_stream, mut recv_stream) = connection.open_bi().await?;
    send_stream.write_all(&bytes).await?;
    send_stream.finish().await?;
    // Each response is at most PACKET_DATA_SIZE bytes and requires
    // an additional 8 bytes to encode its length.
    let size = PACKET_DATA_SIZE
        .saturating_add(8)
        .saturating_mul(num_expected_responses);
    let response = tokio::time::timeout(READ_TIMEOUT_DURATION, recv_stream.read_to_end(size))
        .await
        .map_err(|_| Error::ReadToEndTimeout)??;
    let remote_address = connection.remote_address();
    let mut cursor = Cursor::new(&response[..]);
    std::iter::repeat_with(|| {
        bincode::options()
            .with_limit(response.len() as u64)
            .with_fixint_encoding()
            .allow_trailing_bytes()
            .deserialize_from::<_, ByteBuf>(&mut cursor)
            .map(ByteBuf::into_vec)
            .ok()
    })
    .while_some()
    .try_for_each(|chunk| {
        response_sender
            .send((remote_address, chunk))
            .map_err(|err| {
                close_quic_endpoint(&endpoint);
                Error::from(err)
            })
    })
}

async fn make_connection_task(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    remote_request_sender: Sender<RemoteRequest>,
    receiver: AsyncReceiver<LocalRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    if let Err(err) = make_connection(
        endpoint,
        remote_address,
        remote_request_sender,
        receiver,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
    )
    .await
    {
        error!("make_connection: {remote_address}, {err:?}");
    }
}

async fn make_connection(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    remote_request_sender: Sender<RemoteRequest>,
    receiver: AsyncReceiver<LocalRequest>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
    let connection = endpoint
        .connect(remote_address, CONNECT_SERVER_NAME)?
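        // connect(..) only initiates the handshake; awaiting the returned
        // Connecting future drives it to completion.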
        .await?;
    handle_connection(
        endpoint,
        connection.remote_address(),
        get_remote_pubkey(&connection)?,
        connection,
        remote_request_sender,
        receiver,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
    )
    .await;
    Ok(())
}

fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
    match solana_streamer::nonblocking::quic::get_remote_pubkey(connection) {
        Some(remote_pubkey) => Ok(remote_pubkey),
        None => {
            connection.close(
                CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY,
                CONNECTION_CLOSE_REASON_INVALID_IDENTITY,
            );
            Err(Error::InvalidIdentity(connection.remote_address()))
        }
    }
}

async fn cache_connection(
    remote_pubkey: Pubkey,
    connection: Connection,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    let (old, should_prune_cache) = {
        let mut cache = cache.lock().await;
        (
            cache.insert(remote_pubkey, connection),
            cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2),
        )
    };
    if let Some(old) = old {
        old.close(
            CONNECTION_CLOSE_ERROR_CODE_REPLACED,
            CONNECTION_CLOSE_REASON_REPLACED,
        );
    }
    if should_prune_cache && !prune_cache_pending.swap(true, Ordering::Relaxed) {
        tokio::task::spawn(prune_connection_cache(
            bank_forks,
            prune_cache_pending,
            router,
            cache,
        ));
    }
}

async fn drop_connection(
    remote_pubkey: Pubkey,
    connection: &Connection,
    cache: &Mutex<HashMap<Pubkey, Connection>>,
) {
    connection.close(
        CONNECTION_CLOSE_ERROR_CODE_DROPPED,
        CONNECTION_CLOSE_REASON_DROPPED,
    );
    if let Entry::Occupied(entry) = cache.lock().await.entry(remote_pubkey) {
        if entry.get().stable_id() == connection.stable_id() {
            entry.remove();
        }
    }
}

async fn prune_connection_cache(
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    debug_assert!(prune_cache_pending.load(Ordering::Relaxed));
    let staked_nodes = {
        let root_bank = bank_forks.read().unwrap().root_bank();
        root_bank.staked_nodes()
    };
    {
        let mut cache = cache.lock().await;
        if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
            prune_cache_pending.store(false, Ordering::Relaxed);
            return;
        }
        let mut connections: Vec<_> = cache
            .drain()
            .filter(|(_, connection)| connection.close_reason().is_none())
            .map(|entry @ (pubkey, _)| {
                let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
                (stake, entry)
            })
            .collect();
        connections
            .select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake));
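        // Keyed by Reverse(stake), the CONNECTION_CACHE_CAPACITY
        // highest-staked entries end up in front; everything after that
        // index is closed and dropped.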
        for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] {
            connection.close(
                CONNECTION_CLOSE_ERROR_CODE_PRUNED,
                CONNECTION_CLOSE_REASON_PRUNED,
            );
        }
        cache.extend(
            connections
                .into_iter()
                .take(CONNECTION_CACHE_CAPACITY)
                .map(|(_, entry)| entry),
        );
        prune_cache_pending.store(false, Ordering::Relaxed);
    }
    router.write().await.retain(|_, sender| !sender.is_closed());
}

impl<T> From<crossbeam_channel::SendError<T>> for Error {
    fn from(_: crossbeam_channel::SendError<T>) -> Self {
        Error::ChannelSendError
    }
}

#[cfg(test)]
mod tests {
    use {
        super::*,
        itertools::{izip, multiunzip},
        solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo},
        solana_runtime::bank::Bank,
        solana_sdk::signature::Signer,
        std::{iter::repeat_with, net::Ipv4Addr, time::Duration},
    };

    #[test]
    fn test_quic_endpoint() {
        const NUM_ENDPOINTS: usize = 3;
        const RECV_TIMEOUT: Duration = Duration::from_secs(30);
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(8)
            .enable_all()
            .build()
            .unwrap();
        let keypairs: Vec<Keypair> = repeat_with(Keypair::new).take(NUM_ENDPOINTS).collect();
        let sockets: Vec<UdpSocket> = repeat_with(|| UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)))
            .take(NUM_ENDPOINTS)
            .collect::<Result<_, _>>()
            .unwrap();
        let addresses: Vec<SocketAddr> = sockets
            .iter()
            .map(UdpSocket::local_addr)
            .collect::<Result<_, _>>()
            .unwrap();
        let (remote_request_senders, remote_request_receivers): (Vec<_>, Vec<_>) =
            repeat_with(crossbeam_channel::unbounded::<RemoteRequest>)
                .take(NUM_ENDPOINTS)
                .unzip();
        let bank_forks = {
            let GenesisConfigInfo { genesis_config, .. } =
                create_genesis_config(/*mint_lamports:*/ 100_000);
            let bank = Bank::new_for_tests(&genesis_config);
            BankForks::new_rw_arc(bank)
        };
        let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
            keypairs
                .iter()
                .zip(sockets)
                .zip(remote_request_senders)
                .map(|((keypair, socket), remote_request_sender)| {
                    new_quic_endpoint(
                        runtime.handle(),
                        keypair,
                        socket,
                        IpAddr::V4(Ipv4Addr::LOCALHOST),
                        remote_request_sender,
                        bank_forks.clone(),
                    )
                    .unwrap()
                }),
        );
        let (response_senders, response_receivers): (Vec<_>, Vec<_>) =
            repeat_with(crossbeam_channel::unbounded::<(SocketAddr, Vec<u8>)>)
                .take(NUM_ENDPOINTS)
                .unzip();
        // Send a unique request from each endpoint to every other endpoint.
        for (i, (keypair, &address, sender)) in izip!(&keypairs, &addresses, &senders).enumerate()
        {
            for (j, (&remote_address, response_sender)) in
                addresses.iter().zip(&response_senders).enumerate()
            {
                if i != j {
                    let mut bytes: Vec<u8> = format!("{i}=>{j}").into_bytes();
                    bytes.resize(PACKET_DATA_SIZE, 0xa5);
                    let request = LocalRequest {
                        remote_address,
                        bytes,
                        num_expected_responses: j + 1,
                        response_sender: response_sender.clone(),
                    };
                    sender.blocking_send(request).unwrap();
                }
            }
            // Verify all requests are received and respond to each.
            for (j, remote_request_receiver) in remote_request_receivers.iter().enumerate() {
                if i != j {
                    let RemoteRequest {
                        remote_pubkey,
                        remote_address,
                        bytes,
                        response_sender,
                    } = remote_request_receiver.recv_timeout(RECV_TIMEOUT).unwrap();
                    assert_eq!(remote_pubkey, Some(keypair.pubkey()));
                    assert_eq!(remote_address, address);
                    assert_eq!(bytes, {
                        let mut bytes = format!("{i}=>{j}").into_bytes();
                        bytes.resize(PACKET_DATA_SIZE, 0xa5);
                        bytes
                    });
                    let response: Vec<Vec<u8>> = (0..=j)
                        .map(|k| {
                            let mut bytes = format!("{j}=>{i}({k})").into_bytes();
                            bytes.resize(PACKET_DATA_SIZE, 0xd5);
                            bytes
                        })
                        .collect();
                    response_sender.unwrap().send(response).unwrap();
                }
            }
            // Verify responses.
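            // Each request sent to peer j should come back as j + 1 chunks,
            // all from j's address.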
            for (j, (&remote_address, response_receiver)) in
                addresses.iter().zip(&response_receivers).enumerate()
            {
                if i != j {
                    for k in 0..=j {
                        let (address, response) =
                            response_receiver.recv_timeout(RECV_TIMEOUT).unwrap();
                        assert_eq!(address, remote_address);
                        assert_eq!(response, {
                            let mut bytes = format!("{j}=>{i}({k})").into_bytes();
                            bytes.resize(PACKET_DATA_SIZE, 0xd5);
                            bytes
                        });
                    }
                }
            }
        }
        drop(senders);
        for endpoint in endpoints {
            close_quic_endpoint(&endpoint);
        }
        for task in tasks {
            runtime.block_on(task).unwrap();
        }
    }
}