diff --git a/Cargo.lock b/Cargo.lock
index 2b9aa0d9e6..d0b7111f2c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7099,14 +7099,19 @@ name = "solana-turbine"
 version = "1.17.0"
 dependencies = [
  "bincode",
+ "bytes",
  "crossbeam-channel",
+ "futures 0.3.28",
  "itertools",
  "log",
  "lru",
  "matches",
+ "quinn",
  "rand 0.7.3",
  "rand_chacha 0.2.2",
  "rayon",
+ "rcgen",
+ "rustls 0.20.8",
  "solana-client",
  "solana-entry",
  "solana-gossip",
@@ -7124,6 +7129,7 @@ dependencies = [
  "solana-sdk",
  "solana-streamer",
  "thiserror",
+ "tokio",
 ]
 
 [[package]]
diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock
index c47a5e138a..861425b116 100644
--- a/programs/sbf/Cargo.lock
+++ b/programs/sbf/Cargo.lock
@@ -6174,13 +6174,18 @@ name = "solana-turbine"
 version = "1.17.0"
 dependencies = [
  "bincode",
+ "bytes",
  "crossbeam-channel",
+ "futures 0.3.28",
  "itertools",
  "log",
  "lru",
+ "quinn",
  "rand 0.7.3",
  "rand_chacha 0.2.2",
  "rayon",
+ "rcgen",
+ "rustls 0.20.8",
  "solana-client",
  "solana-entry",
  "solana-gossip",
@@ -6197,6 +6202,7 @@ dependencies = [
  "solana-sdk",
  "solana-streamer",
  "thiserror",
+ "tokio",
 ]
 
 [[package]]
diff --git a/quic-client/src/nonblocking/quic_client.rs b/quic-client/src/nonblocking/quic_client.rs
index dd5b858865..3e9fd27d72 100644
--- a/quic-client/src/nonblocking/quic_client.rs
+++ b/quic-client/src/nonblocking/quic_client.rs
@@ -38,7 +38,7 @@ use {
     tokio::{sync::OnceCell, time::timeout},
 };
 
-struct SkipServerVerification;
+pub struct SkipServerVerification;
 
 impl SkipServerVerification {
     pub fn new() -> Arc<Self> {
diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs
index 4299a0424e..02614bbd14 100644
--- a/streamer/src/nonblocking/quic.rs
+++ b/streamer/src/nonblocking/quic.rs
@@ -190,18 +190,22 @@ fn prune_unstaked_connection_table(
     }
 }
 
-fn get_connection_stake(
-    connection: &Connection,
-    staked_nodes: &RwLock<StakedNodes>,
-) -> Option<(Pubkey, u64, u64, u64, u64)> {
+pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
     // Use the client cert only if it is self signed and the chain length is 1.
-    let pubkey = connection
+    connection
         .peer_identity()?
         .downcast::<Vec<rustls::Certificate>>()
         .ok()
         .filter(|certs| certs.len() == 1)?
         .first()
-        .and_then(get_pubkey_from_tls_certificate)?;
+        .and_then(get_pubkey_from_tls_certificate)
+}
+
+fn get_connection_stake(
+    connection: &Connection,
+    staked_nodes: &RwLock<StakedNodes>,
+) -> Option<(Pubkey, u64, u64, u64, u64)> {
+    let pubkey = get_remote_pubkey(connection)?;
     debug!("Peer public key is {pubkey:?}");
     let staked_nodes = staked_nodes.read().unwrap();
     Some((
diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs
index 20b013f010..e64630e36c 100644
--- a/streamer/src/quic.rs
+++ b/streamer/src/quic.rs
@@ -28,7 +28,7 @@ use {
 pub const MAX_STAKED_CONNECTIONS: usize = 2000;
 pub const MAX_UNSTAKED_CONNECTIONS: usize = 500;
 
-struct SkipClientVerification;
+pub struct SkipClientVerification;
 
 impl SkipClientVerification {
     pub fn new() -> Arc<Self> {
diff --git a/turbine/Cargo.toml b/turbine/Cargo.toml
index 32880737ce..9807959fd0 100644
--- a/turbine/Cargo.toml
+++ b/turbine/Cargo.toml
@@ -11,13 +11,18 @@ edition = { workspace = true }
 
 [dependencies]
 bincode = { workspace = true }
+bytes = { workspace = true }
 crossbeam-channel = { workspace = true }
+futures = { workspace = true }
 itertools = { workspace = true }
 log = { workspace = true }
 lru = { workspace = true }
+quinn = { workspace = true }
 rand = { workspace = true }
 rand_chacha = { workspace = true }
 rayon = { workspace = true }
+rcgen = { workspace = true }
+rustls = { workspace = true }
 solana-client = { workspace = true }
 solana-entry = { workspace = true }
 solana-gossip = { workspace = true }
@@ -34,6 +39,7 @@ solana-runtime = { workspace = true }
 solana-sdk = { workspace = true }
 solana-streamer = { workspace = true }
 thiserror = { workspace = true }
+tokio = { workspace = true }
 
 [dev-dependencies]
 matches = { workspace = true }
diff --git a/turbine/src/lib.rs b/turbine/src/lib.rs
index 5d44d9a992..d73713549d 100644
--- a/turbine/src/lib.rs
+++ b/turbine/src/lib.rs
@@ -2,6 +2,7 @@
 
 pub mod broadcast_stage;
 pub mod cluster_nodes;
+pub mod quic_endpoint;
 pub mod retransmit_stage;
 pub mod sigverify_shreds;
 
diff --git a/turbine/src/quic_endpoint.rs b/turbine/src/quic_endpoint.rs
new file mode 100644
index 0000000000..513d2c9acf
--- /dev/null
+++ b/turbine/src/quic_endpoint.rs
@@ -0,0 +1,441 @@
+use {
+    bytes::Bytes,
+    crossbeam_channel::Sender,
+    futures::future::TryJoin,
+    log::error,
+    quinn::{
+        ClientConfig, ConnectError, Connecting, Connection, ConnectionError, Endpoint,
+        EndpointConfig, SendDatagramError, ServerConfig, TokioRuntime, TransportConfig, VarInt,
+    },
+    rcgen::RcgenError,
+    rustls::{Certificate, PrivateKey},
+    solana_quic_client::nonblocking::quic_client::SkipServerVerification,
+    solana_sdk::{pubkey::Pubkey, signature::Keypair},
+    solana_streamer::{
+        quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate,
+    },
+    std::{
+        collections::{hash_map::Entry, HashMap},
+        io::Error as IoError,
+        net::{IpAddr, SocketAddr, UdpSocket},
+        ops::Deref,
+        sync::Arc,
+    },
+    thiserror::Error,
+    tokio::{
+        runtime::Runtime,
+        sync::{
+            mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender},
+            RwLock,
+        },
+        task::JoinHandle,
+    },
+};
+
+const CLIENT_CHANNEL_CAPACITY: usize = 1 << 20;
+const INITIAL_MAX_UDP_PAYLOAD_SIZE: u16 = 1280;
+const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine";
+const CONNECT_SERVER_NAME: &str = "solana-turbine";
+
+const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1);
+const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2);
+const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3);
+const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4);
+
+const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN";
+const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED";
+const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
+const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
+
+type ConnectionCache = HashMap<(SocketAddr, Option<Pubkey>), Arc<RwLock<Option<Connection>>>>;
+
+#[derive(Error, Debug)]
+pub enum Error {
+    #[error(transparent)]
+    CertificateError(#[from] RcgenError),
+    #[error(transparent)]
+    ConnectError(#[from] ConnectError),
+    #[error(transparent)]
+    ConnectionError(#[from] ConnectionError),
+    #[error("Channel Send Error")]
+    ChannelSendError,
+    #[error("Invalid Identity: {0:?}")]
+    InvalidIdentity(SocketAddr),
+    #[error(transparent)]
+    IoError(#[from] IoError),
+    #[error(transparent)]
+    SendDatagramError(#[from] SendDatagramError),
+    #[error(transparent)]
+    TlsError(#[from] rustls::Error),
+}
+
+#[allow(clippy::type_complexity)]
+pub fn new_quic_endpoint(
+    runtime: &Runtime,
+    keypair: &Keypair,
+    socket: UdpSocket,
+    address: IpAddr,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+) -> Result<
+    (
+        Endpoint,
+        AsyncSender<(SocketAddr, Bytes)>,
+        TryJoin<JoinHandle<()>, JoinHandle<()>>,
+    ),
+    Error,
+> {
+    let (cert, key) = new_self_signed_tls_certificate(keypair, address)?;
+    let server_config = new_server_config(cert.clone(), key.clone())?;
+    let client_config = new_client_config(cert, key)?;
+    let mut endpoint = {
+        // Endpoint::new requires entering the runtime context,
+        // otherwise the code below will panic.
+        let _guard = runtime.enter();
+        Endpoint::new(
+            EndpointConfig::default(),
+            Some(server_config),
+            socket,
+            TokioRuntime,
+        )?
+    };
+    endpoint.set_default_client_config(client_config);
+    let cache = Arc::<RwLock<ConnectionCache>>::default();
+    let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_CAPACITY);
+    let server_task = runtime.spawn(run_server(endpoint.clone(), sender.clone(), cache.clone()));
+    let client_task = runtime.spawn(run_client(endpoint.clone(), client_receiver, sender, cache));
+    let task = futures::future::try_join(server_task, client_task);
+    Ok((endpoint, client_sender, task))
+}
+
+pub fn close_quic_endpoint(endpoint: &Endpoint) {
+    endpoint.close(
+        CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN,
+        CONNECTION_CLOSE_REASON_SHUTDOWN,
+    );
+}
+
+fn new_server_config(cert: Certificate, key: PrivateKey) -> Result<ServerConfig, rustls::Error> {
+    let mut config = rustls::ServerConfig::builder()
+        .with_safe_defaults()
+        .with_client_cert_verifier(Arc::new(SkipClientVerification {}))
+        .with_single_cert(vec![cert], key)?;
+    config.alpn_protocols = vec![ALPN_TURBINE_PROTOCOL_ID.to_vec()];
+    let mut config = ServerConfig::with_crypto(Arc::new(config));
+    config
+        .transport_config(Arc::new(new_transport_config()))
+        .use_retry(true)
+        .migration(false);
+    Ok(config)
+}
+
+fn new_client_config(cert: Certificate, key: PrivateKey) -> Result<ClientConfig, rustls::Error> {
+    let mut config = rustls::ClientConfig::builder()
+        .with_safe_defaults()
+        .with_custom_certificate_verifier(Arc::new(SkipServerVerification {}))
+        .with_single_cert(vec![cert], key)?;
+    config.enable_early_data = true;
+    config.alpn_protocols = vec![ALPN_TURBINE_PROTOCOL_ID.to_vec()];
+    let mut config = ClientConfig::new(Arc::new(config));
+    config.transport_config(Arc::new(new_transport_config()));
+    Ok(config)
+}
+
+fn new_transport_config() -> TransportConfig {
+    let mut config = TransportConfig::default();
+    config
+        .max_concurrent_bidi_streams(VarInt::from(0u8))
+        .max_concurrent_uni_streams(VarInt::from(0u8))
+        .initial_max_udp_payload_size(INITIAL_MAX_UDP_PAYLOAD_SIZE);
+    config
+}
+
+async fn run_server(
+    endpoint: Endpoint,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) {
+    while let Some(connecting) = endpoint.accept().await {
+        tokio::task::spawn(handle_connecting_error(
+            connecting,
+            sender.clone(),
+            cache.clone(),
+        ));
+    }
+}
+
+async fn run_client(
+    endpoint: Endpoint,
+    mut receiver: AsyncReceiver<(SocketAddr, Bytes)>,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) {
+    while let Some((remote_address, bytes)) = receiver.recv().await {
+        tokio::task::spawn(send_datagram_task(
+            endpoint.clone(),
+            remote_address,
+            bytes,
+            sender.clone(),
+            cache.clone(),
+        ));
+    }
+    close_quic_endpoint(&endpoint);
+}
+
+async fn handle_connecting_error(
+    connecting: Connecting,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) {
+    if let Err(err) = handle_connecting(connecting, sender, cache).await {
+        error!("handle_connecting: {err:?}");
+    }
+}
+
+async fn handle_connecting(
+    connecting: Connecting,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) -> Result<(), Error> {
+    let connection = connecting.await?;
+    let remote_address = connection.remote_address();
+    let remote_pubkey = get_remote_pubkey(&connection)?;
+    handle_connection_error(remote_address, remote_pubkey, connection, sender, cache).await;
+    Ok(())
+}
+
+async fn handle_connection_error(
+    remote_address: SocketAddr,
+    remote_pubkey: Pubkey,
+    connection: Connection,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) {
+    cache_connection(remote_address, remote_pubkey, connection.clone(), &cache).await;
+    if let Err(err) = handle_connection(remote_address, remote_pubkey, &connection, sender).await {
+        drop_connection(remote_address, remote_pubkey, &connection, &cache).await;
+        error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}");
+    }
+}
+
+async fn handle_connection(
+    remote_address: SocketAddr,
+    remote_pubkey: Pubkey,
+    connection: &Connection,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+) -> Result<(), Error> {
+    // Assert that send won't block.
+    debug_assert_eq!(sender.capacity(), None);
+    loop {
+        match connection.read_datagram().await {
+            Ok(bytes) => sender.send((remote_pubkey, remote_address, bytes))?,
+            Err(err) => {
+                if let Some(err) = connection.close_reason() {
+                    return Err(Error::from(err));
+                }
+                error!("connection.read_datagram: {remote_pubkey}, {remote_address}, {err:?}");
+            }
+        };
+    }
+}
+
+async fn send_datagram_task(
+    endpoint: Endpoint,
+    remote_address: SocketAddr,
+    bytes: Bytes,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) {
+    if let Err(err) = send_datagram(&endpoint, remote_address, bytes, sender, cache).await {
+        error!("send_datagram: {remote_address}, {err:?}");
+    }
+}
+
+async fn send_datagram(
+    endpoint: &Endpoint,
+    remote_address: SocketAddr,
+    bytes: Bytes,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) -> Result<(), Error> {
+    let connection = get_connection(endpoint, remote_address, sender, cache).await?;
+    connection.send_datagram(bytes)?;
+    Ok(())
+}
+
+async fn get_connection(
+    endpoint: &Endpoint,
+    remote_address: SocketAddr,
+    sender: Sender<(Pubkey, SocketAddr, Bytes)>,
+    cache: Arc<RwLock<ConnectionCache>>,
+) -> Result<Connection, Error> {
+    let key = (remote_address, /*remote_pubkey:*/ None);
+    let entry = cache.write().await.entry(key).or_default().clone();
+    {
+        let connection: Option<Connection> = entry.read().await.clone();
+        if let Some(connection) = connection {
+            if connection.close_reason().is_none() {
+                return Ok(connection);
+            }
+        }
+    }
+    let connection = {
+        // Need to write lock here so that only one task initiates
+        // a new connection to the same remote_address.
+        let mut entry = entry.write().await;
+        if let Some(connection) = entry.deref() {
+            if connection.close_reason().is_none() {
+                return Ok(connection.clone());
+            }
+        }
+        let connection = endpoint
+            .connect(remote_address, CONNECT_SERVER_NAME)?
+            .await?;
+        entry.insert(connection).clone()
+    };
+    tokio::task::spawn(handle_connection_error(
+        connection.remote_address(),
+        get_remote_pubkey(&connection)?,
+        connection.clone(),
+        sender,
+        cache,
+    ));
+    Ok(connection)
+}
+
+fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
+    match solana_streamer::nonblocking::quic::get_remote_pubkey(connection) {
+        Some(remote_pubkey) => Ok(remote_pubkey),
+        None => {
+            connection.close(
+                CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY,
+                CONNECTION_CLOSE_REASON_INVALID_IDENTITY,
+            );
+            Err(Error::InvalidIdentity(connection.remote_address()))
+        }
+    }
+}
+
+async fn cache_connection(
+    remote_address: SocketAddr,
+    remote_pubkey: Pubkey,
+    connection: Connection,
+    cache: &RwLock<ConnectionCache>,
+) {
+    let entries: [Arc<RwLock<Option<Connection>>>; 2] = {
+        let mut cache = cache.write().await;
+        [Some(remote_pubkey), None].map(|remote_pubkey| {
+            let key = (remote_address, remote_pubkey);
+            cache.entry(key).or_default().clone()
+        })
+    };
+    let mut entry = entries[0].write().await;
+    *entries[1].write().await = Some(connection.clone());
+    if let Some(old) = entry.replace(connection) {
+        drop(entry);
+        old.close(
+            CONNECTION_CLOSE_ERROR_CODE_REPLACED,
+            CONNECTION_CLOSE_REASON_REPLACED,
+        );
+    }
+}
+
+async fn drop_connection(
+    remote_address: SocketAddr,
+    remote_pubkey: Pubkey,
+    connection: &Connection,
+    cache: &RwLock<ConnectionCache>,
+) {
+    if connection.close_reason().is_none() {
+        connection.close(
+            CONNECTION_CLOSE_ERROR_CODE_DROPPED,
+            CONNECTION_CLOSE_REASON_DROPPED,
+        );
+    }
+    let key = (remote_address, Some(remote_pubkey));
+    if let Entry::Occupied(entry) = cache.write().await.entry(key) {
+        if matches!(entry.get().read().await.deref(),
+            Some(entry) if entry.stable_id() == connection.stable_id())
+        {
+            entry.remove();
+        }
+    }
+    // Cache entry for (remote_address, None) will be lazily evicted.
+}
+
+impl<T> From<crossbeam_channel::SendError<T>> for Error {
+    fn from(_: crossbeam_channel::SendError<T>) -> Self {
+        Error::ChannelSendError
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use {
+        super::*,
+        itertools::{izip, multiunzip},
+        solana_sdk::signature::Signer,
+        std::{iter::repeat_with, net::Ipv4Addr, time::Duration},
+    };
+
+    #[test]
+    fn test_quic_endpoint() {
+        const NUM_ENDPOINTS: usize = 3;
+        const RECV_TIMEOUT: Duration = Duration::from_secs(60);
+        let runtime = tokio::runtime::Builder::new_multi_thread()
+            .worker_threads(8)
+            .enable_all()
+            .build()
+            .unwrap();
+        let keypairs: Vec<Keypair> = repeat_with(Keypair::new).take(NUM_ENDPOINTS).collect();
+        let sockets: Vec<UdpSocket> = repeat_with(|| UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)))
+            .take(NUM_ENDPOINTS)
+            .collect::<Result<_, _>>()
+            .unwrap();
+        let addresses: Vec<SocketAddr> = sockets
+            .iter()
+            .map(UdpSocket::local_addr)
+            .collect::<Result<_, _>>()
+            .unwrap();
+        let (senders, receivers): (Vec<_>, Vec<_>) =
+            repeat_with(crossbeam_channel::unbounded::<(Pubkey, SocketAddr, Bytes)>)
+                .take(NUM_ENDPOINTS)
+                .unzip();
+        let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) =
+            multiunzip(keypairs.iter().zip(sockets).zip(senders).map(
+                |((keypair, socket), sender)| {
+                    new_quic_endpoint(
+                        &runtime,
+                        keypair,
+                        socket,
+                        IpAddr::V4(Ipv4Addr::LOCALHOST),
+                        sender,
+                    )
+                    .unwrap()
+                },
+            ));
+        // Send a unique message from each endpoint to every other endpoint.
+        for (i, (keypair, &address, sender)) in izip!(&keypairs, &addresses, &senders).enumerate() {
+            for (j, &address) in addresses.iter().enumerate() {
+                if i != j {
+                    let bytes = Bytes::from(format!("{i}=>{j}"));
+                    sender.blocking_send((address, bytes)).unwrap();
+                }
+            }
+            // Verify all messages are received.
+            for (j, receiver) in receivers.iter().enumerate() {
+                if i != j {
+                    let bytes = Bytes::from(format!("{i}=>{j}"));
+                    let entry = (keypair.pubkey(), address, bytes);
+                    assert_eq!(receiver.recv_timeout(RECV_TIMEOUT).unwrap(), entry);
+                }
+            }
+        }
+        drop(senders);
+        for endpoint in endpoints {
+            close_quic_endpoint(&endpoint);
+        }
+        for task in tasks {
+            runtime.block_on(task).unwrap();
+        }
+    }
+}
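
Usage sketch (not part of the diff): the new `turbine::quic_endpoint` module exposes `new_quic_endpoint`, which returns the QUIC `Endpoint`, an async sender for outgoing datagrams, and a joined handle for the server and client tasks. The wiring below is distilled from the unit test above; the `main` scaffolding and localhost bind address are illustrative assumptions, not part of this change.

```rust
use {
    bytes::Bytes,
    solana_sdk::{pubkey::Pubkey, signature::Keypair},
    solana_turbine::quic_endpoint::{close_quic_endpoint, new_quic_endpoint},
    std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?;
    let keypair = Keypair::new();
    let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0))?;
    // Inbound datagrams are delivered on this crossbeam channel as
    // (peer identity, peer address, payload) tuples.
    let (sender, receiver) = crossbeam_channel::unbounded::<(Pubkey, SocketAddr, Bytes)>();
    let (endpoint, client_sender, task) = new_quic_endpoint(
        &runtime,
        &keypair,
        socket,
        IpAddr::V4(Ipv4Addr::LOCALHOST),
        sender,
    )?;
    // Outbound datagrams are queued on the async channel; get_connection
    // dials and caches one connection per remote address on demand, e.g.:
    //     client_sender.blocking_send((remote_address, Bytes::from("payload")))?;
    drop((client_sender, receiver)); // closing the client channel stops run_client
    close_quic_endpoint(&endpoint); // makes Endpoint::accept return None in run_server
    runtime.block_on(task)?;
    Ok(())
}
```

Note that the crossbeam channel handed to `new_quic_endpoint` must be unbounded: `handle_connection` debug-asserts `sender.capacity() == None` so that `send` never blocks the async task.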