diff --git a/common/src/quic/quiche_client_loop.rs b/common/src/quic/quiche_client_loop.rs index 5ca2b65..b019b79 100644 --- a/common/src/quic/quiche_client_loop.rs +++ b/common/src/quic/quiche_client_loop.rs @@ -3,8 +3,10 @@ use std::net::SocketAddr; use crate::{ message::Message, quic::{ - configure_server::MAX_DATAGRAM_SIZE, quiche_reciever::recv_message, - quiche_sender::send_message, quiche_utils::get_next_unidi, + configure_server::MAX_DATAGRAM_SIZE, + quiche_reciever::{recv_message, ReadStreams}, + quiche_sender::{handle_writable, send_message}, + quiche_utils::{get_next_unidi, PartialResponses}, }, }; use anyhow::bail; @@ -52,6 +54,8 @@ pub fn client_loop( let mut current_stream_id = 3; let mut buf = [0; 65535]; let mut out = [0; MAX_DATAGRAM_SIZE]; + let mut partial_responses = PartialResponses::new(); + let mut read_streams = ReadStreams::new(); 'client: loop { poll.poll(&mut events, conn.timeout()).unwrap(); @@ -110,17 +114,24 @@ pub fn client_loop( } // io events - for stream in conn.readable() { - let message = recv_message(&mut conn, stream); + for stream_id in conn.readable() { + let message = recv_message(&mut conn, &mut read_streams, stream_id); match message { - Ok(message) => { + Ok(Some(message)) => { message_recv_queue.send(message).unwrap(); } + Ok(None) => { + // do nothing + } Err(e) => { log::error!("Error recieving message : {e}") } } } + + for stream_id in conn.writable() { + handle_writable(&mut conn, &mut partial_responses, stream_id); + } } // chanel updates @@ -128,8 +139,15 @@ pub fn client_loop( // channel events if let Ok(message_to_send) = message_send_queue.try_recv() { current_stream_id = get_next_unidi(current_stream_id, false); - if let Err(e) = send_message(&mut conn, current_stream_id, &message_to_send) { - log::error!("Error sending message on stream : {}", e); + let binary = + bincode::serialize(&message_to_send).expect("Message should be serializable"); + if let Err(e) = send_message( + &mut conn, + &mut partial_responses, + current_stream_id, + &binary, + ) { + log::error!("Sending failed with error {e:?}"); } } } diff --git a/common/src/quic/quiche_reciever.rs b/common/src/quic/quiche_reciever.rs index 22e29a9..5fee9fb 100644 --- a/common/src/quic/quiche_reciever.rs +++ b/common/src/quic/quiche_reciever.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use anyhow::bail; use crate::message::Message; @@ -8,11 +10,17 @@ pub fn convert_binary_to_message(bytes: Vec) -> anyhow::Result { Ok(bincode::deserialize::(&bytes)?) } +pub type ReadStreams = BTreeMap>; + pub fn recv_message( connection: &mut quiche::Connection, + read_streams: &mut ReadStreams, stream_id: u64, -) -> anyhow::Result { - let mut total_buf = vec![]; +) -> anyhow::Result> { + let mut total_buf = match read_streams.remove(&stream_id) { + Some(buf) => buf, + None => vec![], + }; loop { let mut buf = [0; MAX_DATAGRAM_SIZE]; // 10kk buffer size match connection.stream_recv(stream_id, &mut buf) { @@ -21,401 +29,22 @@ pub fn recv_message( total_buf.extend_from_slice(&buf[..read]); if fin { log::debug!("fin stream : {}", stream_id); - return Ok(bincode::deserialize::(&total_buf)?); + return Ok(Some(bincode::deserialize::(&total_buf)?)); } } Err(e) => { - bail!("Fail to read from stream {stream_id} : error : {e}"); + match &e { + quiche::Error::Done => { + // will be tried again later + log::debug!("stream saved : {}. len: {}", stream_id, total_buf.len()); + read_streams.insert(stream_id, total_buf); + return Ok(None); + } + _ => { + bail!("read error on stream : {}, error: {}", stream_id, e); + } + } } } } } - -#[cfg(test)] -mod tests { - use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - str::FromStr, - thread::sleep, - time::Duration, - }; - - use itertools::Itertools; - use quiche::ConnectionId; - use ring::rand::{SecureRandom, SystemRandom}; - use std::net::UdpSocket; - - use crate::{ - message::Message, - quic::{ - configure_client::configure_client, - configure_server::{configure_server, MAX_DATAGRAM_SIZE}, - quiche_reciever::recv_message, - quiche_sender::send_message, - quiche_utils::{mint_token, validate_token}, - }, - types::account::Account, - }; - - #[test] - fn test_send_and_recieve_of_small_account() { - let mut config = configure_server(1, 100000, 1).unwrap(); - - // Setup the event loop. - let socket_addr = SocketAddr::from_str("0.0.0.0:0").unwrap(); - let socket = UdpSocket::bind(socket_addr).unwrap(); - - let port = socket.local_addr().unwrap().port(); - let local_addr = socket.local_addr().unwrap(); - - let account = Account::get_account_for_test(123456, 2); - let message = Message::AccountMsg(account); - - let jh = { - let message_to_send = message.clone(); - let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); - std::thread::spawn(move || { - let mut client_config = configure_client(1, 1000, 10).unwrap(); - - // Setup the event loop. - let socket_addr: SocketAddr = "0.0.0.0:0".parse().unwrap(); - let socket = std::net::UdpSocket::bind(socket_addr).unwrap(); - - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - SystemRandom::new().fill(&mut scid[..]).unwrap(); - - let scid = quiche::ConnectionId::from_ref(&scid); - - // Get local address. - let local_addr = socket.local_addr().unwrap(); - println!("connecting"); - let mut conn = - quiche::connect(None, &scid, local_addr, server_addr, &mut client_config) - .unwrap(); - let mut out = [0; MAX_DATAGRAM_SIZE]; - println!("sending message"); - let (write, send_info) = conn.send(&mut out).expect("initial send failed"); - - while let Err(e) = socket.send_to(&out[..write], send_info.to) { - panic!("send() failed: {:?}", e); - } - - while !conn.is_established() { - sleep(Duration::from_millis(100)); - } - - send_message(&mut conn, 4, &message_to_send).unwrap(); - - // Generate outgoing QUIC packets and send them on the UDP socket, until - // quiche reports that there are no more packets to be sent. - loop { - let (write, send_info) = match conn.send(&mut out) { - Ok(v) => v, - - Err(quiche::Error::Done) => { - break; - } - - Err(e) => { - conn.close(false, 0x1, b"fail").ok(); - break; - } - }; - - if let Err(e) = socket.send_to(&out[..write], send_info.to) { - if e.kind() == std::io::ErrorKind::WouldBlock { - break; - } - - panic!("send() failed: {:?}", e); - } - } - conn.close(true, 0, b"not required").unwrap(); - }) - }; - - loop { - let mut buf = [0; 65535]; - let mut out = [0; MAX_DATAGRAM_SIZE]; - - let (len, from) = match socket.recv_from(&mut buf) { - Ok(v) => v, - Err(e) => { - panic!("recv() failed: {:?}", e); - } - }; - println!("recieved first packet"); - - log::debug!("got {} bytes", len); - - let pkt_buf = &mut buf[..len]; - - // Parse the QUIC packet's header. - let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) { - Ok(header) => header, - - Err(e) => { - panic!("Parsing packet header failed: {:?}", e); - } - }; - let rng = SystemRandom::new(); - let conn_id_seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap(); - let conn_id = ring::hmac::sign(&conn_id_seed, &hdr.dcid); - let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN]; - let conn_id: ConnectionId<'static> = conn_id.to_vec().into(); - - if hdr.ty != quiche::Type::Initial { - panic!("Packet is not Initial"); - } - - if !quiche::version_is_supported(hdr.version) { - log::warn!("Doing version negotiation"); - - let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap(); - - let out = &out[..len]; - - if let Err(e) = socket.send_to(out, from) { - panic!("send() failed: {:?}", e); - } - } - - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - scid.copy_from_slice(&conn_id); - - let scid = quiche::ConnectionId::from_ref(&scid); - - // Token is always present in Initial packets. - let token = hdr.token.as_ref().unwrap(); - - println!("token: {}", token.iter().map(|x| x.to_string()).join(", ")); - - // Do stateless retry if the client didn't send a token. - if token.is_empty() { - log::warn!("Doing stateless retry"); - - let new_token = mint_token(&hdr, &from); - - let len = quiche::retry( - &hdr.scid, - &hdr.dcid, - &scid, - &new_token, - hdr.version, - &mut out, - ) - .unwrap(); - - let out = &out[..len]; - - if let Err(e) = socket.send_to(out, from) { - panic!("send() failed: {:?}", e); - } else { - continue; - } - } - let odcid = validate_token(&from, token); - // The token was not valid, meaning the retry failed, so - // drop the packet. - if odcid.is_none() { - panic!("Invalid address validation token"); - } - - if scid.len() != hdr.dcid.len() { - panic!("Invalid destination connection ID"); - } - - // Reuse the source connection ID we sent in the Retry packet, - // instead of changing it again. - let scid = hdr.dcid.clone(); - - log::debug!("New connection: dcid={:?} scid={:?}", hdr.dcid, scid); - - let mut conn = - quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config).unwrap(); - - let stream_id = { - let mut stream_id = 0; - loop { - let readable = conn.stream_readable_next(); - match readable { - Some(id) => { - stream_id = id; - break; - } - None => { - sleep(Duration::from_millis(100)); - } - } - } - stream_id - }; - let recvd_message = recv_message(&mut conn, stream_id).unwrap(); - assert_eq!(recvd_message, message); - std::thread::sleep(Duration::from_secs(1)); - assert_eq!(conn.is_closed(), true); - jh.join().unwrap(); - break; - } - } - - #[test] - fn test_send_and_recieve_of_large_account() { - let mut config = configure_server(1, 100000, 1).unwrap(); - - // Setup the event loop. - let socket_addr = SocketAddr::from_str("0.0.0.0:0").unwrap(); - let socket = UdpSocket::bind(socket_addr).unwrap(); - - let port = socket.local_addr().unwrap().port(); - let local_addr = socket.local_addr().unwrap(); - - let account = Account::get_account_for_test(123456, 10_000_000); - let message = Message::AccountMsg(account); - - let jh = { - let message = message.clone(); - let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); - std::thread::spawn(move || { - let mut client_config = configure_client(1, 12_000_000, 10).unwrap(); - - // Setup the event loop. - let socket_addr: SocketAddr = "0.0.0.0:0".parse().unwrap(); - let socket = std::net::UdpSocket::bind(socket_addr).unwrap(); - - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - SystemRandom::new().fill(&mut scid[..]).unwrap(); - - let scid = quiche::ConnectionId::from_ref(&scid); - - // Get local address. - let local_addr = socket.local_addr().unwrap(); - println!("connecting"); - let mut conn = - quiche::connect(None, &scid, local_addr, server_addr, &mut client_config) - .unwrap(); - let mut out = [0; MAX_DATAGRAM_SIZE]; - println!("sending message"); - let (write, send_info) = conn.send(&mut out).expect("initial send failed"); - - while let Err(e) = socket.send_to(&out[..write], send_info.to) { - panic!("send() failed: {:?}", e); - } - - let stream_id = conn.stream_readable_next().unwrap(); - let recvd_message = recv_message(&mut conn, stream_id).unwrap(); - assert_eq!(recvd_message, message); - std::thread::sleep(Duration::from_secs(1)); - assert_eq!(conn.is_closed(), true); - }) - }; - - loop { - let mut buf = [0; 65535]; - let mut out = [0; MAX_DATAGRAM_SIZE]; - - let (len, from) = match socket.recv_from(&mut buf) { - Ok(v) => v, - Err(e) => { - panic!("recv() failed: {:?}", e); - } - }; - println!("recieved first packet"); - - log::debug!("got {} bytes", len); - - let pkt_buf = &mut buf[..len]; - - // Parse the QUIC packet's header. - let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) { - Ok(header) => header, - - Err(e) => { - panic!("Parsing packet header failed: {:?}", e); - } - }; - let rng = SystemRandom::new(); - let conn_id_seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap(); - let conn_id = ring::hmac::sign(&conn_id_seed, &hdr.dcid); - let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN]; - let conn_id: ConnectionId<'static> = conn_id.to_vec().into(); - - if hdr.ty != quiche::Type::Initial { - panic!("Packet is not Initial"); - } - - if !quiche::version_is_supported(hdr.version) { - log::warn!("Doing version negotiation"); - - let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap(); - - let out = &out[..len]; - - if let Err(e) = socket.send_to(out, from) { - panic!("send() failed: {:?}", e); - } - } - - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - scid.copy_from_slice(&conn_id); - - let scid = quiche::ConnectionId::from_ref(&scid); - - // Token is always present in Initial packets. - let token = hdr.token.as_ref().unwrap(); - - println!("token: {}", token.iter().map(|x| x.to_string()).join(", ")); - - // Do stateless retry if the client didn't send a token. - if token.is_empty() { - log::warn!("Doing stateless retry"); - - let new_token = mint_token(&hdr, &from); - - let len = quiche::retry( - &hdr.scid, - &hdr.dcid, - &scid, - &new_token, - hdr.version, - &mut out, - ) - .unwrap(); - - let out = &out[..len]; - - if let Err(e) = socket.send_to(out, from) { - panic!("send() failed: {:?}", e); - } else { - continue; - } - } - let odcid = validate_token(&from, token); - // The token was not valid, meaning the retry failed, so - // drop the packet. - if odcid.is_none() { - panic!("Invalid address validation token"); - } - - if scid.len() != hdr.dcid.len() { - panic!("Invalid destination connection ID"); - } - - // Reuse the source connection ID we sent in the Retry packet, - // instead of changing it again. - let scid = hdr.dcid.clone(); - - log::debug!("New connection: dcid={:?} scid={:?}", hdr.dcid, scid); - - let mut conn = - quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config).unwrap(); - - let stream_id = conn.stream_writable_next().unwrap(); - send_message(&mut conn, stream_id, &message).unwrap(); - conn.close(true, 0, b"not required").unwrap(); - - jh.join().unwrap(); - break; - } - } -} diff --git a/common/src/quic/quiche_sender.rs b/common/src/quic/quiche_sender.rs index ad89104..fda7d7f 100644 --- a/common/src/quic/quiche_sender.rs +++ b/common/src/quic/quiche_sender.rs @@ -1,9 +1,9 @@ -use itertools::Itertools; +use anyhow::bail; use quiche::Connection; -use crate::message::Message; +use crate::{message::Message, quic::quiche_utils::PartialResponse}; -use super::configure_server::MAX_DATAGRAM_SIZE; +use super::quiche_utils::PartialResponses; pub fn convert_to_binary(message: &Message) -> anyhow::Result> { Ok(bincode::serialize(&message)?) @@ -11,14 +11,64 @@ pub fn convert_to_binary(message: &Message) -> anyhow::Result> { pub fn send_message( connection: &mut Connection, + partial_responses: &mut PartialResponses, stream_id: u64, - message: &Message, + message: &Vec, ) -> anyhow::Result<()> { - let binary = convert_to_binary(message)?; - let chunks = binary.chunks(MAX_DATAGRAM_SIZE).collect_vec(); - let nb_chunks = chunks.len(); - for (index, buf) in chunks.iter().enumerate() { - connection.stream_send(stream_id, buf, index + 1 == nb_chunks)?; + let written = match connection.stream_send(stream_id, &message, true) { + Ok(v) => v, + + Err(quiche::Error::Done) => 0, + + Err(e) => { + bail!("{} stream send failed {:?}", stream_id, e); + } + }; + log::debug!("dispatched {} on stream id : {}", written, stream_id); + + if written < message.len() { + let response = PartialResponse { + binary: message[written..].to_vec(), + written, + }; + partial_responses.insert(stream_id, response); } Ok(()) } + +/// Handles newly writable streams. +pub fn handle_writable( + conn: &mut quiche::Connection, + partial_responses: &mut PartialResponses, + stream_id: u64, +) { + log::debug!("{} stream {} is writable", conn.trace_id(), stream_id); + + if !partial_responses.contains_key(&stream_id) { + return; + } + + let resp = partial_responses + .get_mut(&stream_id) + .expect("should have a stream id"); + let body = &resp.binary; + + let written = match conn.stream_send(stream_id, body, true) { + Ok(v) => v, + + Err(quiche::Error::Done) => 0, + + Err(e) => { + partial_responses.remove(&stream_id); + + log::error!("{} stream send failed {:?}", conn.trace_id(), e); + return; + } + }; + if resp.written == resp.binary.len() { + partial_responses.remove(&stream_id); + } else { + resp.binary = resp.binary[written..].to_vec(); + resp.written += written; + } +} diff --git a/common/src/quic/quiche_server_loop.rs b/common/src/quic/quiche_server_loop.rs index 0de2a97..a553fe4 100644 --- a/common/src/quic/quiche_server_loop.rs +++ b/common/src/quic/quiche_server_loop.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, net::SocketAddr}; use itertools::Itertools; use mio::net::UdpSocket; -use quiche::ConnectionId; +use quiche::{Connection, ConnectionId}; use ring::rand::SystemRandom; use crate::{ @@ -12,21 +12,21 @@ use crate::{ message::Message, quic::{ quiche_reciever::recv_message, + quiche_sender::{handle_writable, send_message}, quiche_utils::{get_next_unidi, mint_token, validate_token}, }, types::{account::Account, block_meta::SlotMeta, slot_identifier::SlotIdentifier}, }; -use super::{configure_server::MAX_DATAGRAM_SIZE, quiche_sender::convert_to_binary}; - -struct PartialResponse { - pub binary: Vec, - pub written: usize, -} +use super::{ + configure_server::MAX_DATAGRAM_SIZE, quiche_reciever::ReadStreams, + quiche_utils::PartialResponses, +}; struct Client { - pub conn: quiche::Connection, - pub partial_responses: HashMap, + pub conn: Connection, + pub partial_responses: PartialResponses, + pub read_streams: ReadStreams, pub filters: Vec, pub next_stream: u64, } @@ -191,7 +191,8 @@ pub fn server_loop( let client = Client { conn, - partial_responses: HashMap::new(), + partial_responses: PartialResponses::new(), + read_streams: ReadStreams::new(), filters: Vec::new(), next_stream: get_next_unidi(0, true), }; @@ -226,14 +227,15 @@ pub fn server_loop( if client.conn.is_in_early_data() || client.conn.is_established() { for stream_id in client.conn.writable() { - handle_writable(client, stream_id); + handle_writable(&mut client.conn, &mut client.partial_responses, stream_id); } // Process all readable streams. for stream in client.conn.readable() { - let message = recv_message(&mut client.conn, stream); + let message = + recv_message(&mut client.conn, &mut client.read_streams, stream); match message { - Ok(message) => { + Ok(Some(message)) => { match message { Message::Filters(mut filters) => { client.filters.append(&mut filters); @@ -246,6 +248,7 @@ pub fn server_loop( } } } + Ok(None) => {} Err(e) => { log::error!("Error recieving message : {e}") } @@ -292,7 +295,7 @@ pub fn server_loop( Message::TransactionMsg(transaction) } }; - let binary = convert_to_binary(&message) + let binary = bincode::serialize(&message) .expect("Message should be serializable in binary"); for client in dispatch_to { let stream_id = client.next_stream; @@ -302,28 +305,13 @@ pub fn server_loop( binary.len(), stream_id ); - let written = match client.conn.stream_send(stream_id, &binary, true) { - Ok(v) => v, - - Err(quiche::Error::Done) => 0, - - Err(e) => { - log::error!( - "{} stream send failed {:?}", - client.conn.trace_id(), - e - ); - continue; - } - }; - log::debug!("dispatched {} on stream id : {}", written, stream_id); - - if written < binary.len() { - let response = PartialResponse { - binary: binary[written..].to_vec(), - written, - }; - client.partial_responses.insert(stream_id, response); + if let Err(e) = send_message( + &mut client.conn, + &mut client.partial_responses, + stream_id, + &binary, + ) { + log::error!("Error sending message : {e}"); } } } @@ -381,39 +369,3 @@ pub fn server_loop( } } } - -/// Handles newly writable streams. -fn handle_writable(client: &mut Client, stream_id: u64) { - let conn = &mut client.conn; - - log::debug!("{} stream {} is writable", conn.trace_id(), stream_id); - - if !client.partial_responses.contains_key(&stream_id) { - return; - } - - let resp = client - .partial_responses - .get_mut(&stream_id) - .expect("should have a stream id"); - let body = &resp.binary; - - let written = match conn.stream_send(stream_id, body, true) { - Ok(v) => v, - - Err(quiche::Error::Done) => 0, - - Err(e) => { - client.partial_responses.remove(&stream_id); - - log::error!("{} stream send failed {:?}", conn.trace_id(), e); - return; - } - }; - if resp.written == resp.binary.len() { - client.partial_responses.remove(&stream_id); - } else { - resp.binary = resp.binary[written..].to_vec(); - resp.written += written; - } -} diff --git a/common/src/quic/quiche_utils.rs b/common/src/quic/quiche_utils.rs index af414a3..0e66ad1 100644 --- a/common/src/quic/quiche_utils.rs +++ b/common/src/quic/quiche_utils.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + pub fn validate_token<'a>( src: &std::net::SocketAddr, token: &'a [u8], @@ -65,3 +67,10 @@ pub fn get_next_unidi(current_stream_id: u64, is_server: bool) -> u64 { } panic!("stream not found"); } + +pub struct PartialResponse { + pub binary: Vec, + pub written: usize, +} + +pub type PartialResponses = BTreeMap;