diff --git a/Cargo.lock b/Cargo.lock index 053c1c7..b97b8e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2457,6 +2457,7 @@ dependencies = [ "bincode", "itertools", "log", + "mio_channel", "quic-geyser-common", "quic-geyser-server", "rand 0.8.5", diff --git a/block-builder/Cargo.toml b/block-builder/Cargo.toml index 421c359..0dd62e3 100644 --- a/block-builder/Cargo.toml +++ b/block-builder/Cargo.toml @@ -14,6 +14,7 @@ log = { workspace = true } quic-geyser-common = { workspace = true } bincode = { workspace = true } itertools = { workspace = true } +mio_channel = { workspace = true } [dev-dependencies] rand = { workspace = true } diff --git a/block-builder/src/block_builder.rs b/block-builder/src/block_builder.rs index 41e44f4..f1683a6 100644 --- a/block-builder/src/block_builder.rs +++ b/block-builder/src/block_builder.rs @@ -1,6 +1,6 @@ use std::{ collections::{BTreeMap, HashMap}, - sync::mpsc::{Receiver, Sender}, + sync::mpsc::Receiver, }; use itertools::Itertools; @@ -16,7 +16,7 @@ use solana_sdk::pubkey::Pubkey; pub fn start_block_building_thread( channel_messages: Receiver, - output: Sender, + output: mio_channel::Sender, compression_type: CompressionType, ) { std::thread::spawn(move || { @@ -33,7 +33,7 @@ struct PartialBlock { pub fn build_blocks( channel_messages: Receiver, - output: Sender, + output: mio_channel::Sender, compression_type: CompressionType, ) { let mut partially_build_blocks = BTreeMap::::new(); @@ -156,7 +156,7 @@ pub fn build_blocks( fn dispatch_partial_block( partial_blocks: &mut BTreeMap, slot: u64, - output: &Sender, + output: &mio_channel::Sender, compression_type: CompressionType, ) { if let Some(dispatched_partial_block) = partial_blocks.remove(&slot) { diff --git a/block-builder/src/tests.rs b/block-builder/src/tests.rs index d92efd2..6ba7816 100644 --- a/block-builder/src/tests.rs +++ b/block-builder/src/tests.rs @@ -34,7 +34,7 @@ mod tests { #[test] fn test_block_creation_transactions_after_blockmeta() { let (channelmsg_sx, cm_rx) = channel(); - let (ms_sx, msg_rx) = channel(); + let (ms_sx, msg_rx) = mio_channel::channel(); start_block_building_thread( cm_rx, ms_sx, @@ -227,7 +227,8 @@ mod tests { .send(ChannelMessage::Transaction(Box::new(tx3.clone()))) .unwrap(); - let block_message = msg_rx.recv().unwrap(); + sleep(Duration::from_millis(1)); + let block_message = msg_rx.try_recv().unwrap(); let ChannelMessage::Block(block) = block_message else { unreachable!(); }; @@ -256,7 +257,7 @@ mod tests { #[test] fn test_block_creation_blockmeta_after_transactions() { let (channelmsg_sx, cm_rx) = channel(); - let (ms_sx, msg_rx) = channel(); + let (ms_sx, msg_rx) = mio_channel::channel(); start_block_building_thread( cm_rx, ms_sx, @@ -450,7 +451,8 @@ mod tests { .send(ChannelMessage::BlockMeta(block_meta.clone())) .unwrap(); - let block_message = msg_rx.recv().unwrap(); + sleep(Duration::from_millis(1)); + let block_message = msg_rx.try_recv().unwrap(); let ChannelMessage::Block(block) = block_message else { unreachable!(); }; @@ -479,7 +481,7 @@ mod tests { #[test] fn test_block_creation_incomplete_block_after_slot_notification() { let (channelmsg_sx, cm_rx) = channel(); - let (ms_sx, msg_rx) = channel(); + let (ms_sx, msg_rx) = mio_channel::channel(); start_block_building_thread( cm_rx, ms_sx, @@ -672,7 +674,8 @@ mod tests { .send(ChannelMessage::Transaction(Box::new(tx3.clone()))) .unwrap(); - let block_message = msg_rx.recv().unwrap(); + sleep(Duration::from_millis(1)); + let block_message = msg_rx.try_recv().unwrap(); let ChannelMessage::Block(block) = 
block_message else { unreachable!(); }; @@ -701,7 +704,7 @@ mod tests { #[test] fn test_block_creation_incomplete_slot() { let (channelmsg_sx, cm_rx) = channel(); - let (ms_sx, msg_rx) = channel(); + let (ms_sx, msg_rx) = mio_channel::channel(); start_block_building_thread( cm_rx, ms_sx, diff --git a/blocking_client/src/quiche_client_loop.rs b/blocking_client/src/quiche_client_loop.rs index bf73ae8..9ff9258 100644 --- a/blocking_client/src/quiche_client_loop.rs +++ b/blocking_client/src/quiche_client_loop.rs @@ -373,7 +373,7 @@ mod tests { ); // server loop - let (server_send_queue, rx_sent_queue) = mpsc::channel::(); + let (server_send_queue, rx_sent_queue) = mio_channel::channel::(); let _server_loop_jh = std::thread::spawn(move || { if let Err(e) = server_loop( QuicParameters::default(), diff --git a/server/src/quic_server.rs b/server/src/quic_server.rs index 708389b..d250b2e 100644 --- a/server/src/quic_server.rs +++ b/server/src/quic_server.rs @@ -1,11 +1,11 @@ use quic_geyser_common::{ channel_message::ChannelMessage, config::ConfigQuicPlugin, plugin_error::QuicGeyserError, }; -use std::{fmt::Debug, sync::mpsc}; +use std::fmt::Debug; use super::quiche_server_loop::server_loop; pub struct QuicServer { - pub data_channel_sender: mpsc::Sender, + pub data_channel_sender: mio_channel::Sender, pub quic_plugin_config: ConfigQuicPlugin, } @@ -20,7 +20,7 @@ impl QuicServer { let socket = config.address; let compression_type = config.compression_parameters.compression_type; - let (data_channel_sender, data_channel_tx) = mpsc::channel(); + let (data_channel_sender, data_channel_tx) = mio_channel::channel(); let _server_loop_jh = std::thread::spawn(move || { if let Err(e) = server_loop( diff --git a/server/src/quiche_server_loop.rs b/server/src/quiche_server_loop.rs index b5495b7..69f23e7 100644 --- a/server/src/quiche_server_loop.rs +++ b/server/src/quiche_server_loop.rs @@ -1,24 +1,14 @@ -use std::{ - collections::HashMap, - net::SocketAddr, - sync::{ - atomic::{AtomicBool, AtomicU64, AtomicUsize}, - mpsc::{self, Sender}, - Arc, Mutex, RwLock, - }, - time::{Duration, Instant}, -}; +use std::{collections::HashMap, net::SocketAddr}; -use anyhow::bail; use itertools::Itertools; use quiche::ConnectionId; -use ring::rand::SystemRandom; +use ring::rand::{SecureRandom, SystemRandom}; use quic_geyser_common::{ channel_message::ChannelMessage, compression::CompressionType, config::QuicParameters, - defaults::{MAX_ALLOWED_PARTIAL_RESPONSES, MAX_DATAGRAM_SIZE}, + defaults::MAX_DATAGRAM_SIZE, filters::Filter, message::Message, types::{account::Account, block_meta::SlotMeta, slot_identifier::SlotIdentifier}, @@ -27,106 +17,135 @@ use quic_geyser_common::{ use quic_geyser_quiche_utils::{ quiche_reciever::{recv_message, ReadStreams}, quiche_sender::{handle_writable, send_message}, - quiche_utils::{get_next_unidi, mint_token, validate_token, write_to_socket, PartialResponses}, + quiche_utils::{get_next_unidi, mint_token, validate_token, PartialResponses}, }; use crate::configure_server::configure_server; -struct DispatchingData { - pub sender: Sender<(Vec, u8)>, - pub filters: Arc>>, - pub messages_in_queue: Arc, +const MAX_BUFFER_SIZE: usize = 65507; +const MAX_MESSAGES_PER_LOOP: usize = 32; + +pub struct Client { + pub conn: quiche::Connection, + pub client_id: ClientId, + pub partial_requests: ReadStreams, + pub partial_responses: PartialResponses, + pub max_datagram_size: usize, + pub loss_rate: f64, + pub max_send_burst: usize, + pub max_burst_was_set: bool, + pub filters: Vec, + pub next_stream: u64, + 
pub closed: bool, } - -type DispachingConnections = Arc, DispatchingData>>>; - -const MAX_BUFFER_SIZE: usize = 65535; +pub type ClientId = u64; +pub type ClientIdMap = HashMap, ClientId>; +pub type ClientMap = HashMap; pub fn server_loop( - quic_parameters: QuicParameters, + mut quic_parameters: QuicParameters, socket_addr: SocketAddr, - message_send_queue: mpsc::Receiver, + mut message_send_queue: mio_channel::Receiver, compression_type: CompressionType, stop_laggy_client: bool, ) -> anyhow::Result<()> { let maximum_concurrent_streams_id = u64::MAX; - let mut config = configure_server(quic_parameters)?; - - let mut socket = mio::net::UdpSocket::bind(socket_addr)?; - let mut poll = mio::Poll::new()?; - let mut events = mio::Events::with_capacity(1024); - - let pacing = if quic_parameters.enable_pacing { - match set_txtime_sockopt(&socket) { - Ok(_) => { - log::debug!("successfully set SO_TXTIME socket option"); - true - } - Err(e) => { - log::debug!("setsockopt failed {:?}", e); - false - } - } - } else { - false - }; - - poll.registry().register( - &mut socket, - mio::Token(0), - mio::Interest::READABLE | mio::Interest::WRITABLE, - )?; let mut buf = [0; MAX_BUFFER_SIZE]; - let mut out = [0; MAX_DATAGRAM_SIZE]; + let mut out = [0; MAX_BUFFER_SIZE]; + let mut pacing = false; + // Setup the event loop. + let mut poll = mio::Poll::new().unwrap(); + let mut events = mio::Events::with_capacity(1024); - let local_addr = socket.local_addr()?; + // Create the UDP listening socket, and register it with the event loop. + let mut socket = mio::net::UdpSocket::bind(socket_addr)?; + poll.registry() + .register(&mut socket, mio::Token(0), mio::Interest::READABLE) + .unwrap(); + + poll.registry() + .register( + &mut message_send_queue, + mio::Token(1), + mio::Interest::READABLE, + ) + .unwrap(); + + let max_datagram_size = MAX_DATAGRAM_SIZE; let enable_gso = if quic_parameters.enable_gso { - detect_gso(&socket, MAX_DATAGRAM_SIZE) + detect_gso(&socket, max_datagram_size) } else { false }; + if !quic_parameters.enable_pacing { + match set_txtime_sockopt(&socket) { + Ok(_) => { + pacing = true; + log::debug!("successfully set SO_TXTIME socket option"); + } + Err(e) => log::debug!("setsockopt failed {:?}", e), + }; + } + quic_parameters.enable_pacing = pacing; + + let mut config = configure_server(quic_parameters)?; + let rng = SystemRandom::new(); let conn_id_seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap(); - let mut clients: HashMap< - quiche::ConnectionId<'static>, - mio_channel::Sender<(quiche::RecvInfo, Vec)>, - > = HashMap::new(); + let mut next_client_id = 0; + let mut clients_ids = ClientIdMap::new(); + let mut clients = ClientMap::new(); - let (write_sender, mut write_reciver) = mio_channel::channel::<(quiche::SendInfo, Vec)>(); - - poll.registry() - .register(&mut write_reciver, mio::Token(1), mio::Interest::READABLE)?; - - let dispatching_connections: DispachingConnections = Arc::new(Mutex::new(HashMap::< - ConnectionId<'static>, - DispatchingData, - >::new())); - - create_dispatching_thread( - message_send_queue, - dispatching_connections.clone(), - compression_type, - ); + let mut continue_write = false; + let local_addr = socket.local_addr().unwrap(); loop { - poll.poll(&mut events, Some(Duration::from_millis(10)))?; + // Find the shorter timeout from all the active connections. 
+ // + // TODO: use event loop that properly supports timers + let timeout = match continue_write { + true => Some(std::time::Duration::from_secs(0)), + + false => clients.values().filter_map(|c| c.conn.timeout()).min(), + }; + + poll.poll(&mut events, timeout).unwrap(); + + // Read incoming UDP packets from the socket and feed them to quiche, + // until there are no more packets to read. 'read: loop { + // If the event loop reported no events, it means that the timeout + // has expired, so handle it without attempting to read packets. We + // will then proceed with the send loop. + if events.is_empty() && !continue_write { + log::trace!("timed out"); + + clients.values_mut().for_each(|c| c.conn.on_timeout()); + + break 'read; + } + let (len, from) = match socket.recv_from(&mut buf) { Ok(v) => v, + Err(e) => { + // There are no more UDP packets to read, so end the read + // loop. if e.kind() == std::io::ErrorKind::WouldBlock { log::trace!("recv() would block"); break 'read; } - bail!("recv() failed: {:?}", e); + + panic!("recv() failed: {:?}", e); } }; - let pkt_buf = &mut buf[..len]; + log::trace!("got {} bytes", len); + let pkt_buf = &mut buf[..len]; // Parse the QUIC packet's header. let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) { Ok(v) => v, @@ -137,10 +156,22 @@ pub fn server_loop( } }; - let conn_id = ring::hmac::sign(&conn_id_seed, &hdr.dcid); - let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN]; - let conn_id: ConnectionId<'static> = conn_id.to_vec().into(); - if !clients.contains_key(&hdr.dcid) && !clients.contains_key(&conn_id) { + log::trace!("got packet {:?}", hdr); + + let conn_id = if !cfg!(feature = "fuzzing") { + let conn_id = ring::hmac::sign(&conn_id_seed, &hdr.dcid); + let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN]; + conn_id.to_vec().into() + } else { + // When fuzzing use an all zero connection ID. + [0; quiche::MAX_CONN_ID_LEN].to_vec().into() + }; + + // Lookup a connection based on the packet's connection ID. If there + // is no connection matching, create a new one. + let client = if !clients_ids.contains_key(&hdr.dcid) + && !clients_ids.contains_key(&conn_id) + { if hdr.ty != quiche::Type::Initial { log::error!("Packet is not Initial"); continue 'read; @@ -148,14 +179,17 @@ pub fn server_loop( if !quiche::version_is_supported(hdr.version) { log::warn!("Doing version negotiation"); + let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap(); let out = &out[..len]; if let Err(e) = socket.send_to(out, from) { if e.kind() == std::io::ErrorKind::WouldBlock { + log::trace!("send() would block"); break; } + panic!("send() failed: {:?}", e); } continue 'read; @@ -164,436 +198,364 @@ pub fn server_loop( let mut scid = [0; quiche::MAX_CONN_ID_LEN]; scid.copy_from_slice(&conn_id); - let scid = quiche::ConnectionId::from_ref(&scid); + #[allow(unused_assignments)] + let mut odcid = None; - // Token is always present in Initial packets. - let token = hdr.token.as_ref().unwrap(); + { + // Token is always present in Initial packets. + let token = hdr.token.as_ref().unwrap(); - // Do stateless retry if the client didn't send a token. - if token.is_empty() { - log::debug!("Doing stateless retry"); + // Do stateless retry if the client didn't send a token. 
+ if token.is_empty() { + log::warn!("Doing stateless retry"); - let new_token = mint_token(&hdr, &from); + let scid = quiche::ConnectionId::from_ref(&scid); + let new_token = mint_token(&hdr, &from); - let len = quiche::retry( - &hdr.scid, - &hdr.dcid, - &scid, - &new_token, - hdr.version, - &mut out, - ) - .unwrap(); + let len = quiche::retry( + &hdr.scid, + &hdr.dcid, + &scid, + &new_token, + hdr.version, + &mut out, + ) + .unwrap(); - if write_to_socket(&socket, &out[..len], from) { - break; - } - continue 'read; - } + let out = &out[..len]; - let odcid = validate_token(&from, token); + if let Err(e) = socket.send_to(out, from) { + if e.kind() == std::io::ErrorKind::WouldBlock { + log::trace!("send() would block"); + break; + } - if odcid.is_none() { - log::error!("Invalid address validation token"); - continue 'read; - } - - if scid.len() != hdr.dcid.len() { - log::error!("Invalid destination connection ID"); - continue 'read; - } - - let scid = hdr.dcid.clone(); - - log::info!("New connection: dcid={:?} scid={:?}", hdr.dcid, scid); - - let mut conn = - quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config)?; - - let recv_info = quiche::RecvInfo { - to: socket.local_addr().unwrap(), - from, - }; - // Process potentially coalesced packets. - match conn.recv(pkt_buf, recv_info) { - Ok(v) => v, - Err(e) => { - log::error!("{} recv failed: {:?}", conn.trace_id(), e); + panic!("send() failed: {:?}", e); + } continue 'read; } - }; - let (client_sender, client_reciver) = mio_channel::channel(); - let (client_message_sx, client_message_rx) = mpsc::channel(); - let messages_in_queue = Arc::new(AtomicUsize::new(0)); + odcid = validate_token(&from, token); - let filters = Arc::new(RwLock::new(Vec::new())); - create_client_task( + // The token was not valid, meaning the retry failed, so + // drop the packet. + if odcid.is_none() { + log::error!("Invalid address validation token"); + continue; + } + + if scid.len() != hdr.dcid.len() { + log::error!("Invalid destination connection ID"); + continue 'read; + } + + // Reuse the source connection ID we sent in the Retry + // packet, instead of changing it again. 
+ scid.copy_from_slice(&hdr.dcid); + } + + let scid = quiche::ConnectionId::from_vec(scid.to_vec()); + + log::debug!("New connection: dcid={:?} scid={:?}", hdr.dcid, scid); + + #[allow(unused_mut)] + let mut conn = + quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config).unwrap(); + + let client_id = next_client_id; + + let client = Client { conn, - client_reciver, - write_sender.clone(), - client_message_rx, - filters.clone(), - maximum_concurrent_streams_id, - stop_laggy_client, - messages_in_queue.clone(), - ); - let mut lk = dispatching_connections.lock().unwrap(); - lk.insert( - scid.clone(), - DispatchingData { - sender: client_message_sx, - filters, - messages_in_queue, - }, - ); - clients.insert(scid, client_sender); - } else { - // get the existing client - let client = match clients.get(&hdr.dcid) { - Some(v) => v, - None => clients - .get(&conn_id) - .expect("The client should exist in the map"), + client_id, + partial_requests: HashMap::new(), + partial_responses: HashMap::new(), + max_burst_was_set: false, + max_datagram_size, + loss_rate: 0.0, + max_send_burst: MAX_BUFFER_SIZE, + filters: vec![], + next_stream: 3, + closed: false, }; - let recv_info = quiche::RecvInfo { - to: socket.local_addr().unwrap(), - from, + clients.insert(client_id, client); + clients_ids.insert(scid.clone(), client_id); + + next_client_id += 1; + + clients.get_mut(&client_id).unwrap() + } else { + let cid = match clients_ids.get(&hdr.dcid) { + Some(v) => v, + + None => clients_ids.get(&conn_id).unwrap(), }; - if client.send((recv_info, pkt_buf.to_vec())).is_err() { - // client is closed - clients.remove(&hdr.dcid); - clients.remove(&conn_id); + + clients.get_mut(cid).unwrap() + }; + + let recv_info = quiche::RecvInfo { + to: local_addr, + from, + }; + + // Process potentially coalesced packets. 
+ let read = match client.conn.recv(pkt_buf, recv_info) { + Ok(v) => v, + + Err(e) => { + log::error!("{} recv failed: {:?}", client.conn.trace_id(), e); + continue 'read; } }; - } - while let Ok((send_info, buffer)) = write_reciver.try_recv() { - if let Ok(written) = send_to( - &socket, - &buffer, - &send_info, - MAX_DATAGRAM_SIZE, - pacing, - enable_gso, - ) { - log::debug!("socket wrote to : {written:?} to {:?}", send_info.to); - } - } - } -} + log::debug!("{} processed {} bytes", client.conn.trace_id(), read); -#[allow(clippy::too_many_arguments)] -fn create_client_task( - connection: quiche::Connection, - mut receiver: mio_channel::Receiver<(quiche::RecvInfo, Vec)>, - sender: mio_channel::Sender<(quiche::SendInfo, Vec)>, - message_channel: mpsc::Receiver<(Vec, u8)>, - filters: Arc>>, - maximum_concurrent_streams_id: u64, - stop_laggy_client: bool, - messages_in_queue: Arc, -) { - std::thread::spawn(move || { - let mut partial_responses = PartialResponses::new(); - let mut read_streams = ReadStreams::new(); - let mut next_stream: u64 = 3; - let mut connection = connection; - let mut instance = Instant::now(); - let mut closed = false; - let mut buf = [0; MAX_DATAGRAM_SIZE]; - - let mut poll = mio::Poll::new().unwrap(); - let mut events = mio::Events::with_capacity(1024); - let max_allowed_partial_responses = MAX_ALLOWED_PARTIAL_RESPONSES as usize; - - poll.registry() - .register(&mut receiver, mio::Token(0), mio::Interest::READABLE) - .unwrap(); - - let number_of_loops = Arc::new(AtomicU64::new(0)); - let number_of_meesages_from_network = Arc::new(AtomicU64::new(0)); - let number_of_meesages_to_network = Arc::new(AtomicU64::new(0)); - let number_of_readable_streams = Arc::new(AtomicU64::new(0)); - let number_of_writable_streams = Arc::new(AtomicU64::new(0)); - let messages_added = Arc::new(AtomicU64::new(0)); - let quit = Arc::new(AtomicBool::new(false)); - - { - let number_of_loops = number_of_loops.clone(); - let number_of_meesages_from_network = number_of_meesages_from_network.clone(); - let number_of_meesages_to_network = number_of_meesages_to_network.clone(); - let number_of_readable_streams = number_of_readable_streams.clone(); - let number_of_writable_streams = number_of_writable_streams.clone(); - let messages_added = messages_added.clone(); - let messages_in_queue = messages_in_queue.clone(); - let quit = quit.clone(); - std::thread::spawn(move || { - while !quit.load(std::sync::atomic::Ordering::Relaxed) { - std::thread::sleep(Duration::from_secs(1)); - log::info!("---------------------------------"); - log::info!( - "number of loop : {}", - number_of_loops.swap(0, std::sync::atomic::Ordering::Relaxed) - ); - log::info!( - "number of packets read : {}", - number_of_meesages_from_network - .swap(0, std::sync::atomic::Ordering::Relaxed) - ); - log::info!( - "number of packets write : {}", - number_of_meesages_to_network.swap(0, std::sync::atomic::Ordering::Relaxed) - ); - log::info!( - "number_of_readable_streams : {}", - number_of_readable_streams.swap(0, std::sync::atomic::Ordering::Relaxed) - ); - log::info!( - "number_of_writable_streams : {}", - number_of_writable_streams.swap(0, std::sync::atomic::Ordering::Relaxed) - ); - log::info!( - "messages_added : {}", - messages_added.swap(0, std::sync::atomic::Ordering::Relaxed) - ); - log::info!( - "messages in queue to be sent : {}", - messages_in_queue.load(std::sync::atomic::Ordering::Relaxed) - ); - } - }); - } - - loop { - number_of_loops.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - poll.poll(&mut events, 
Some(Duration::from_millis(1))) - .unwrap(); - - if !events.is_empty() { - while let Ok((info, mut buf)) = receiver.try_recv() { - let buf = buf.as_mut_slice(); - match connection.recv(buf, info) { - Ok(_) => { - number_of_meesages_from_network - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - } - Err(e) => { - log::error!("{} recv failed: {:?}", connection.trace_id(), e); - break; - } - }; - } - continue; - } - - if connection.is_in_early_data() || connection.is_established() { - // Process all readable streams. - for stream in connection.readable() { - number_of_readable_streams.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let message = recv_message(&mut connection, &mut read_streams, stream); - match message { - Ok(Some(message)) => match message { - Message::Filters(mut f) => { - let mut filter_lk = filters.write().unwrap(); - filter_lk.append(&mut f); - } - _ => { - log::error!("unknown message from the client"); - } - }, - Ok(None) => {} - Err(e) => { - log::error!("Error recieving message : {e}") - } - } - } - } - - if !connection.is_closed() - && (connection.is_established() || connection.is_in_early_data()) + // Create a new application protocol session as soon as the QUIC + // connection is established. + if !client.max_burst_was_set + && (client.conn.is_in_early_data() || client.conn.is_established()) { - for stream_id in connection.writable() { - number_of_writable_streams.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if let Err(e) = - handle_writable(&mut connection, &mut partial_responses, stream_id) - { - log::error!("Error writing {e:?}"); + client.max_burst_was_set = true; + // Update max_datagram_size after connection established. + client.max_datagram_size = client.conn.max_send_udp_payload_size(); + } + + let conn = &mut client.conn; + let partial_responses = &mut client.partial_responses; + let read_streams = &mut client.partial_requests; + let filters = &mut client.filters; + // Handle writable streams. + for stream_id in conn.writable() { + if let Err(e) = handle_writable(conn, partial_responses, stream_id) { + log::error!("Error writing {e:?}"); + } + } + + for stream in conn.readable() { + let message = recv_message(conn, read_streams, stream); + match message { + Ok(Some(message)) => match message { + Message::Filters(mut f) => { + log::debug!("filters added .."); + filters.append(&mut f); + } + _ => { + log::error!("unknown message from the client"); + } + }, + Ok(None) => {} + Err(e) => { + log::error!("Error recieving message : {e}") } } + } - while partial_responses.len() < max_allowed_partial_responses { - let close = match message_channel.try_recv() { - Ok((message, priority)) => { - messages_in_queue.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); - let stream_id = next_stream; - next_stream = - get_next_unidi(stream_id, true, maximum_concurrent_streams_id); + handle_path_events(client); - if let Err(e) = connection.stream_priority(stream_id, priority, true) { - if !closed { - log::error!( - "Unable to set priority for the stream {}, error {}", - stream_id, - e - ); - } - true - } else { - messages_added.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - match send_message( - &mut connection, - &mut partial_responses, - stream_id, - &message, - ) { - Ok(_) => false, - Err(quiche::Error::Done) => { - // done writing / queue is full - break; - } - Err(e) => { - log::error!("error sending message : {e:?}"); - true - } - } - } + // See whether source Connection IDs have been retired. 
+ while let Some(retired_scid) = client.conn.retired_scid_next() { + log::info!("Retiring source CID {:?}", retired_scid); + clients_ids.remove(&retired_scid); + } + + // Provides as many CIDs as possible. + while client.conn.scids_left() > 0 { + let (scid, reset_token) = generate_cid_and_reset_token(&rng); + if client.conn.new_scid(&scid, reset_token, false).is_err() { + break; + } + + clients_ids.insert(scid, client.client_id); + } + } + + for _ in 0..MAX_MESSAGES_PER_LOOP { + if let Ok(message) = message_send_queue.try_recv() { + let dispatching_connections = clients + .iter() + .filter_map(|(id, client)| { + if client.filters.iter().any(|x| x.allows(&message)) { + Some(*id) + } else { + None } - Err(e) => { - match e { - mpsc::TryRecvError::Empty => { - break; - } - mpsc::TryRecvError::Disconnected => { - // too many message the connection is lagging - log::error!("channel disconnected by dispatcher"); - true - } - } + }) + .collect_vec(); + + if !dispatching_connections.is_empty() { + let (message, priority) = match message { + ChannelMessage::Account(account, slot) => { + let slot_identifier = SlotIdentifier { slot }; + let geyser_account = Account::new( + account.pubkey, + account.account, + compression_type, + slot_identifier, + account.write_version, + ); + + (Message::AccountMsg(geyser_account), 3) } + ChannelMessage::Slot(slot, parent, commitment_config) => ( + Message::SlotMsg(SlotMeta { + slot, + parent, + commitment_config, + }), + 1, + ), + ChannelMessage::BlockMeta(block_meta) => { + (Message::BlockMetaMsg(block_meta), 2) + } + ChannelMessage::Transaction(transaction) => { + (Message::TransactionMsg(transaction), 3) + } + ChannelMessage::Block(block) => (Message::BlockMsg(block), 2), }; + let binary = bincode::serialize(&message) + .expect("Message should be serializable in binary"); + for id in dispatching_connections.iter() { + let client = clients.get_mut(id).unwrap(); - if close && !closed && stop_laggy_client { - if let Err(e) = connection.close(true, 1, b"laggy client") { - if e != quiche::Error::Done { - log::error!("error closing client : {}", e); + let stream_id = client.next_stream; + client.next_stream = + get_next_unidi(stream_id, true, maximum_concurrent_streams_id); + + if let Err(e) = client.conn.stream_priority(stream_id, priority, true) { + if !client.closed && stop_laggy_client { + client.closed = true; + log::error!("Error setting stream : {e:?}"); + let _ = client.conn.close(true, 0, b"laggy"); } } else { - log::info!("Stopping laggy client : {}", connection.trace_id(),); + log::debug!("sending message to {id}"); + match send_message( + &mut client.conn, + &mut client.partial_responses, + stream_id, + &binary, + ) { + Ok(_) => {} + Err(quiche::Error::Done) => { + // done writing / queue is full + break; + } + Err(e) => { + if !client.closed && stop_laggy_client { + client.closed = true; + log::error!("Error sending message : {e:?}"); + let _ = client.conn.close(true, 0, b"laggy"); + } + } + } } - closed = true; - break; } } - } - - if instance.elapsed() > Duration::from_secs(2) { - instance = Instant::now(); - connection.on_timeout(); - } - - let max_burst = connection.send_quantum(); - let mut total_send = 0; - while total_send < max_burst { - match connection.send(&mut buf[0..MAX_DATAGRAM_SIZE]) { - Ok((len, send_info)) => { - sender.send((send_info, buf[..len].to_vec())).unwrap(); - number_of_meesages_to_network - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - total_send += len; - } - Err(quiche::Error::Done) => { - break; - } - Err(e) => { - 
log::error!( - "{} send failed: {:?}, closing connection", - connection.trace_id(), - e - ); - connection.close(false, 0x1, b"fail").ok(); - break; - } - }; - } - - if connection.is_closed() { - log::info!( - "{} connection closed {:?}", - connection.trace_id(), - connection.stats() - ); + } else { break; } } - quit.store(true, std::sync::atomic::Ordering::Relaxed); - }); -} -fn create_dispatching_thread( - message_send_queue: mpsc::Receiver, - dispatching_connections: DispachingConnections, - compression_type: CompressionType, -) { - std::thread::spawn(move || { - while let Ok(message) = message_send_queue.recv() { - let mut dispatching_connections_lk = dispatching_connections.lock().unwrap(); + // Generate outgoing QUIC packets for all active connections and send + // them on the UDP socket, until quiche reports that there are no more + // packets to be sent. + continue_write = false; + for client in clients.values_mut() { + // Reduce max_send_burst by 25% if loss is increasing more than 0.1%. + let loss_rate = client.conn.stats().lost as f64 / client.conn.stats().sent as f64; + if loss_rate > client.loss_rate + 0.001 { + client.max_send_burst = client.max_send_burst / 4 * 3; + // Minimun bound of 10xMSS. + client.max_send_burst = client.max_send_burst.max(client.max_datagram_size * 10); + client.loss_rate = loss_rate; + } - let dispatching_connections = dispatching_connections_lk - .iter() - .filter_map(|(id, x)| { - let filters = x.filters.read().unwrap(); - if filters.iter().any(|x| x.allows(&message)) { - Some(id.clone()) - } else { - None - } - }) - .collect_vec(); + let max_send_burst = client.conn.send_quantum().min(client.max_send_burst) + / client.max_datagram_size + * client.max_datagram_size; + let mut total_write = 0; + let mut dst_info = None; - if !dispatching_connections.is_empty() { - let (message, priority) = match message { - ChannelMessage::Account(account, slot) => { - let slot_identifier = SlotIdentifier { slot }; - let geyser_account = Account::new( - account.pubkey, - account.account, - compression_type, - slot_identifier, - account.write_version, - ); + while total_write < max_send_burst { + let (write, send_info) = + match client.conn.send(&mut out[total_write..max_send_burst]) { + Ok(v) => v, - (Message::AccountMsg(geyser_account), 3) - } - ChannelMessage::Slot(slot, parent, commitment_config) => ( - Message::SlotMsg(SlotMeta { - slot, - parent, - commitment_config, - }), - 1, - ), - ChannelMessage::BlockMeta(block_meta) => (Message::BlockMetaMsg(block_meta), 2), - ChannelMessage::Transaction(transaction) => { - (Message::TransactionMsg(transaction), 3) - } - ChannelMessage::Block(block) => (Message::BlockMsg(block), 2), - }; - let binary = - bincode::serialize(&message).expect("Message should be serializable in binary"); - for id in dispatching_connections.iter() { - let data = dispatching_connections_lk.get(id).unwrap(); - data.messages_in_queue - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if data.sender.send((binary.clone(), priority)).is_err() { - // client is closed - dispatching_connections_lk.remove(id); - } + Err(quiche::Error::Done) => { + log::trace!("{} done writing", client.conn.trace_id()); + break; + } + + Err(e) => { + log::error!("{} send failed: {:?}", client.conn.trace_id(), e); + + client.conn.close(false, 0x1, b"fail").ok(); + break; + } + }; + + total_write += write; + + // Use the first packet time to send, not the last. 
+ let _ = dst_info.get_or_insert(send_info); + + if write < client.max_datagram_size { + continue_write = true; + break; } } + + if total_write == 0 || dst_info.is_none() { + break; + } + + if let Err(e) = send_to( + &socket, + &out[..total_write], + &dst_info.unwrap(), + client.max_datagram_size, + pacing, + enable_gso, + ) { + if e.kind() == std::io::ErrorKind::WouldBlock { + log::trace!("send() would block"); + break; + } + + panic!("send_to() failed: {:?}", e); + } + + log::trace!("{} written {} bytes", client.conn.trace_id(), total_write); + + if total_write >= max_send_burst { + log::trace!("{} pause writing", client.conn.trace_id(),); + continue_write = true; + break; + } } - }); + + // Garbage collect closed connections. + clients.retain(|_, ref mut c| { + log::trace!("Collecting garbage"); + + if c.conn.is_closed() { + log::info!( + "{} connection collected {:?} {:?}", + c.conn.trace_id(), + c.conn.stats(), + c.conn.path_stats().collect::>() + ); + + for id in c.conn.source_ids() { + let id_owned = id.clone().into_owned(); + clients_ids.remove(&id_owned); + } + } + + !c.conn.is_closed() + }); + } } fn set_txtime_sockopt(sock: &mio::net::UdpSocket) -> std::io::Result<()> { @@ -710,3 +672,83 @@ pub fn send_to( Ok(written) } + +/// Generate a new pair of Source Connection ID and reset token. +pub fn generate_cid_and_reset_token( + rng: &T, +) -> (quiche::ConnectionId<'static>, u128) { + let mut scid = [0; quiche::MAX_CONN_ID_LEN]; + rng.fill(&mut scid).unwrap(); + let scid = scid.to_vec().into(); + let mut reset_token = [0; 16]; + rng.fill(&mut reset_token).unwrap(); + let reset_token = u128::from_be_bytes(reset_token); + (scid, reset_token) +} + +fn handle_path_events(client: &mut Client) { + while let Some(qe) = client.conn.path_event_next() { + match qe { + quiche::PathEvent::New(local_addr, peer_addr) => { + log::info!( + "{} Seen new path ({}, {})", + client.conn.trace_id(), + local_addr, + peer_addr + ); + + // Directly probe the new path. + client + .conn + .probe_path(local_addr, peer_addr) + .expect("cannot probe"); + } + + quiche::PathEvent::Validated(local_addr, peer_addr) => { + log::info!( + "{} Path ({}, {}) is now validated", + client.conn.trace_id(), + local_addr, + peer_addr + ); + } + + quiche::PathEvent::FailedValidation(local_addr, peer_addr) => { + log::info!( + "{} Path ({}, {}) failed validation", + client.conn.trace_id(), + local_addr, + peer_addr + ); + } + + quiche::PathEvent::Closed(local_addr, peer_addr) => { + log::info!( + "{} Path ({}, {}) is now closed and unusable", + client.conn.trace_id(), + local_addr, + peer_addr + ); + } + + quiche::PathEvent::ReusedSourceConnectionId(cid_seq, old, new) => { + log::info!( + "{} Peer reused cid seq {} (initially {:?}) on {:?}", + client.conn.trace_id(), + cid_seq, + old, + new + ); + } + + quiche::PathEvent::PeerMigrated(local_addr, peer_addr) => { + log::info!( + "{} Connection migrated to ({}, {})", + client.conn.trace_id(), + local_addr, + peer_addr + ); + } + } + } +}
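
The core of this patch is replacing `std::sync::mpsc` with `mio_channel` so the message queue can be registered with the same `mio::Poll` that watches the UDP socket, letting one server loop wake on either a readable socket or a readable channel instead of spawning per-client threads. Below is a minimal standalone sketch of that pattern (not code from this repository), assuming only the `mio_channel` API the diff itself uses: `channel()`, `Sender::send`, `Receiver::try_recv`, and a `Receiver` that can be registered as a `mio` event source.

```rust
use std::time::Duration;

fn main() -> std::io::Result<()> {
    // Sender/receiver pair; the receiver is registrable with a mio Poll,
    // which is what the patch relies on for `message_send_queue`.
    let (tx, mut rx) = mio_channel::channel::<String>();

    let mut poll = mio::Poll::new()?;
    let mut events = mio::Events::with_capacity(1024);

    // Token(1) mirrors the token the patch assigns to the message queue;
    // Token(0) is the UDP socket in the real server loop.
    poll.registry()
        .register(&mut rx, mio::Token(1), mio::Interest::READABLE)?;

    let producer = std::thread::spawn(move || {
        tx.send("hello".to_string()).unwrap();
    });

    // Wake when either the timeout expires or the channel becomes readable.
    poll.poll(&mut events, Some(Duration::from_millis(100)))?;

    // Drain without blocking, the way the server loop bounds work per
    // iteration with MAX_MESSAGES_PER_LOOP.
    while let Ok(msg) = rx.try_recv() {
        println!("got {msg}");
    }

    producer.join().unwrap();
    Ok(())
}
```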
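The rewritten send path also replaces the old fixed `send_quantum` burst with a per-client budget: each burst is capped at `min(send_quantum, max_send_burst)` rounded down to a whole number of datagrams, and `max_send_burst` is cut by 25% (never below 10 datagrams) whenever the observed loss rate grows by more than 0.1%. A small sketch of just that arithmetic follows; the helper name and the numbers are hypothetical, chosen only to illustrate the rule used in the patch.

```rust
// Hypothetical helper isolating the burst-size arithmetic from the new send
// loop: shrink the per-client burst budget by 25% when the loss rate rises by
// more than 0.1%, floor it at 10 datagrams, and round the actual burst down
// to a whole number of datagrams.
fn next_burst(
    send_quantum: usize, // conn.send_quantum()
    max_send_burst: &mut usize,
    prev_loss_rate: &mut f64,
    lost: u64,
    sent: u64,
    max_datagram_size: usize,
) -> usize {
    let loss_rate = lost as f64 / sent as f64;
    if loss_rate > *prev_loss_rate + 0.001 {
        *max_send_burst = (*max_send_burst / 4 * 3).max(max_datagram_size * 10);
        *prev_loss_rate = loss_rate;
    }
    send_quantum.min(*max_send_burst) / max_datagram_size * max_datagram_size
}

fn main() {
    let mut budget = 65_507; // MAX_BUFFER_SIZE, the initial max_send_burst
    let mut loss = 0.0;

    // 5 lost out of 1000 sent (0.5% loss) trims the budget by 25%;
    // the burst itself is send_quantum rounded down to 10 datagrams of 1350 B.
    let burst = next_burst(14_000, &mut budget, &mut loss, 5, 1_000, 1_350);
    assert_eq!(burst, 13_500);
    println!("burst = {burst} bytes, budget now {budget}");
}
```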