Fixing the quiche for large accounts

This commit is contained in:
godmodegalactus 2024-05-22 12:34:07 +02:00
parent b54e614de8
commit 55245dfe29
No known key found for this signature in database
GPG Key ID: 22DA4A30887FDA3C
5 changed files with 139 additions and 481 deletions

View File

@ -3,8 +3,10 @@ use std::net::SocketAddr;
use crate::{
message::Message,
quic::{
configure_server::MAX_DATAGRAM_SIZE, quiche_reciever::recv_message,
quiche_sender::send_message, quiche_utils::get_next_unidi,
configure_server::MAX_DATAGRAM_SIZE,
quiche_reciever::{recv_message, ReadStreams},
quiche_sender::{handle_writable, send_message},
quiche_utils::{get_next_unidi, PartialResponses},
},
};
use anyhow::bail;
@ -52,6 +54,8 @@ pub fn client_loop(
let mut current_stream_id = 3;
let mut buf = [0; 65535];
let mut out = [0; MAX_DATAGRAM_SIZE];
let mut partial_responses = PartialResponses::new();
let mut read_streams = ReadStreams::new();
'client: loop {
poll.poll(&mut events, conn.timeout()).unwrap();
@ -110,17 +114,24 @@ pub fn client_loop(
}
// io events
for stream in conn.readable() {
let message = recv_message(&mut conn, stream);
for stream_id in conn.readable() {
let message = recv_message(&mut conn, &mut read_streams, stream_id);
match message {
Ok(message) => {
Ok(Some(message)) => {
message_recv_queue.send(message).unwrap();
}
Ok(None) => {
// do nothing
}
Err(e) => {
log::error!("Error recieving message : {e}")
}
}
}
for stream_id in conn.writable() {
handle_writable(&mut conn, &mut partial_responses, stream_id);
}
}
// chanel updates
@ -128,8 +139,15 @@ pub fn client_loop(
// channel events
if let Ok(message_to_send) = message_send_queue.try_recv() {
current_stream_id = get_next_unidi(current_stream_id, false);
if let Err(e) = send_message(&mut conn, current_stream_id, &message_to_send) {
log::error!("Error sending message on stream : {}", e);
let binary =
bincode::serialize(&message_to_send).expect("Message should be serializable");
if let Err(e) = send_message(
&mut conn,
&mut partial_responses,
current_stream_id,
&binary,
) {
log::error!("Sending failed with error {e:?}");
}
}
}

View File

@ -1,3 +1,5 @@
use std::collections::BTreeMap;
use anyhow::bail;
use crate::message::Message;
@ -8,11 +10,17 @@ pub fn convert_binary_to_message(bytes: Vec<u8>) -> anyhow::Result<Message> {
Ok(bincode::deserialize::<Message>(&bytes)?)
}
pub type ReadStreams = BTreeMap<u64, Vec<u8>>;
pub fn recv_message(
connection: &mut quiche::Connection,
read_streams: &mut ReadStreams,
stream_id: u64,
) -> anyhow::Result<Message> {
let mut total_buf = vec![];
) -> anyhow::Result<Option<Message>> {
let mut total_buf = match read_streams.remove(&stream_id) {
Some(buf) => buf,
None => vec![],
};
loop {
let mut buf = [0; MAX_DATAGRAM_SIZE]; // 10kk buffer size
match connection.stream_recv(stream_id, &mut buf) {
@ -21,401 +29,22 @@ pub fn recv_message(
total_buf.extend_from_slice(&buf[..read]);
if fin {
log::debug!("fin stream : {}", stream_id);
return Ok(bincode::deserialize::<Message>(&total_buf)?);
return Ok(Some(bincode::deserialize::<Message>(&total_buf)?));
}
}
Err(e) => {
bail!("Fail to read from stream {stream_id} : error : {e}");
match &e {
quiche::Error::Done => {
// will be tried again later
log::debug!("stream saved : {}. len: {}", stream_id, total_buf.len());
read_streams.insert(stream_id, total_buf);
return Ok(None);
}
_ => {
bail!("read error on stream : {}, error: {}", stream_id, e);
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
str::FromStr,
thread::sleep,
time::Duration,
};
use itertools::Itertools;
use quiche::ConnectionId;
use ring::rand::{SecureRandom, SystemRandom};
use std::net::UdpSocket;
use crate::{
message::Message,
quic::{
configure_client::configure_client,
configure_server::{configure_server, MAX_DATAGRAM_SIZE},
quiche_reciever::recv_message,
quiche_sender::send_message,
quiche_utils::{mint_token, validate_token},
},
types::account::Account,
};
#[test]
fn test_send_and_recieve_of_small_account() {
let mut config = configure_server(1, 100000, 1).unwrap();
// Setup the event loop.
let socket_addr = SocketAddr::from_str("0.0.0.0:0").unwrap();
let socket = UdpSocket::bind(socket_addr).unwrap();
let port = socket.local_addr().unwrap().port();
let local_addr = socket.local_addr().unwrap();
let account = Account::get_account_for_test(123456, 2);
let message = Message::AccountMsg(account);
let jh = {
let message_to_send = message.clone();
let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
std::thread::spawn(move || {
let mut client_config = configure_client(1, 1000, 10).unwrap();
// Setup the event loop.
let socket_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
let socket = std::net::UdpSocket::bind(socket_addr).unwrap();
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
SystemRandom::new().fill(&mut scid[..]).unwrap();
let scid = quiche::ConnectionId::from_ref(&scid);
// Get local address.
let local_addr = socket.local_addr().unwrap();
println!("connecting");
let mut conn =
quiche::connect(None, &scid, local_addr, server_addr, &mut client_config)
.unwrap();
let mut out = [0; MAX_DATAGRAM_SIZE];
println!("sending message");
let (write, send_info) = conn.send(&mut out).expect("initial send failed");
while let Err(e) = socket.send_to(&out[..write], send_info.to) {
panic!("send() failed: {:?}", e);
}
while !conn.is_established() {
sleep(Duration::from_millis(100));
}
send_message(&mut conn, 4, &message_to_send).unwrap();
// Generate outgoing QUIC packets and send them on the UDP socket, until
// quiche reports that there are no more packets to be sent.
loop {
let (write, send_info) = match conn.send(&mut out) {
Ok(v) => v,
Err(quiche::Error::Done) => {
break;
}
Err(e) => {
conn.close(false, 0x1, b"fail").ok();
break;
}
};
if let Err(e) = socket.send_to(&out[..write], send_info.to) {
if e.kind() == std::io::ErrorKind::WouldBlock {
break;
}
panic!("send() failed: {:?}", e);
}
}
conn.close(true, 0, b"not required").unwrap();
})
};
loop {
let mut buf = [0; 65535];
let mut out = [0; MAX_DATAGRAM_SIZE];
let (len, from) = match socket.recv_from(&mut buf) {
Ok(v) => v,
Err(e) => {
panic!("recv() failed: {:?}", e);
}
};
println!("recieved first packet");
log::debug!("got {} bytes", len);
let pkt_buf = &mut buf[..len];
// Parse the QUIC packet's header.
let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) {
Ok(header) => header,
Err(e) => {
panic!("Parsing packet header failed: {:?}", e);
}
};
let rng = SystemRandom::new();
let conn_id_seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap();
let conn_id = ring::hmac::sign(&conn_id_seed, &hdr.dcid);
let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN];
let conn_id: ConnectionId<'static> = conn_id.to_vec().into();
if hdr.ty != quiche::Type::Initial {
panic!("Packet is not Initial");
}
if !quiche::version_is_supported(hdr.version) {
log::warn!("Doing version negotiation");
let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap();
let out = &out[..len];
if let Err(e) = socket.send_to(out, from) {
panic!("send() failed: {:?}", e);
}
}
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
scid.copy_from_slice(&conn_id);
let scid = quiche::ConnectionId::from_ref(&scid);
// Token is always present in Initial packets.
let token = hdr.token.as_ref().unwrap();
println!("token: {}", token.iter().map(|x| x.to_string()).join(", "));
// Do stateless retry if the client didn't send a token.
if token.is_empty() {
log::warn!("Doing stateless retry");
let new_token = mint_token(&hdr, &from);
let len = quiche::retry(
&hdr.scid,
&hdr.dcid,
&scid,
&new_token,
hdr.version,
&mut out,
)
.unwrap();
let out = &out[..len];
if let Err(e) = socket.send_to(out, from) {
panic!("send() failed: {:?}", e);
} else {
continue;
}
}
let odcid = validate_token(&from, token);
// The token was not valid, meaning the retry failed, so
// drop the packet.
if odcid.is_none() {
panic!("Invalid address validation token");
}
if scid.len() != hdr.dcid.len() {
panic!("Invalid destination connection ID");
}
// Reuse the source connection ID we sent in the Retry packet,
// instead of changing it again.
let scid = hdr.dcid.clone();
log::debug!("New connection: dcid={:?} scid={:?}", hdr.dcid, scid);
let mut conn =
quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config).unwrap();
let stream_id = {
let mut stream_id = 0;
loop {
let readable = conn.stream_readable_next();
match readable {
Some(id) => {
stream_id = id;
break;
}
None => {
sleep(Duration::from_millis(100));
}
}
}
stream_id
};
let recvd_message = recv_message(&mut conn, stream_id).unwrap();
assert_eq!(recvd_message, message);
std::thread::sleep(Duration::from_secs(1));
assert_eq!(conn.is_closed(), true);
jh.join().unwrap();
break;
}
}
#[test]
fn test_send_and_recieve_of_large_account() {
let mut config = configure_server(1, 100000, 1).unwrap();
// Setup the event loop.
let socket_addr = SocketAddr::from_str("0.0.0.0:0").unwrap();
let socket = UdpSocket::bind(socket_addr).unwrap();
let port = socket.local_addr().unwrap().port();
let local_addr = socket.local_addr().unwrap();
let account = Account::get_account_for_test(123456, 10_000_000);
let message = Message::AccountMsg(account);
let jh = {
let message = message.clone();
let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
std::thread::spawn(move || {
let mut client_config = configure_client(1, 12_000_000, 10).unwrap();
// Setup the event loop.
let socket_addr: SocketAddr = "0.0.0.0:0".parse().unwrap();
let socket = std::net::UdpSocket::bind(socket_addr).unwrap();
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
SystemRandom::new().fill(&mut scid[..]).unwrap();
let scid = quiche::ConnectionId::from_ref(&scid);
// Get local address.
let local_addr = socket.local_addr().unwrap();
println!("connecting");
let mut conn =
quiche::connect(None, &scid, local_addr, server_addr, &mut client_config)
.unwrap();
let mut out = [0; MAX_DATAGRAM_SIZE];
println!("sending message");
let (write, send_info) = conn.send(&mut out).expect("initial send failed");
while let Err(e) = socket.send_to(&out[..write], send_info.to) {
panic!("send() failed: {:?}", e);
}
let stream_id = conn.stream_readable_next().unwrap();
let recvd_message = recv_message(&mut conn, stream_id).unwrap();
assert_eq!(recvd_message, message);
std::thread::sleep(Duration::from_secs(1));
assert_eq!(conn.is_closed(), true);
})
};
loop {
let mut buf = [0; 65535];
let mut out = [0; MAX_DATAGRAM_SIZE];
let (len, from) = match socket.recv_from(&mut buf) {
Ok(v) => v,
Err(e) => {
panic!("recv() failed: {:?}", e);
}
};
println!("recieved first packet");
log::debug!("got {} bytes", len);
let pkt_buf = &mut buf[..len];
// Parse the QUIC packet's header.
let hdr = match quiche::Header::from_slice(pkt_buf, quiche::MAX_CONN_ID_LEN) {
Ok(header) => header,
Err(e) => {
panic!("Parsing packet header failed: {:?}", e);
}
};
let rng = SystemRandom::new();
let conn_id_seed = ring::hmac::Key::generate(ring::hmac::HMAC_SHA256, &rng).unwrap();
let conn_id = ring::hmac::sign(&conn_id_seed, &hdr.dcid);
let conn_id = &conn_id.as_ref()[..quiche::MAX_CONN_ID_LEN];
let conn_id: ConnectionId<'static> = conn_id.to_vec().into();
if hdr.ty != quiche::Type::Initial {
panic!("Packet is not Initial");
}
if !quiche::version_is_supported(hdr.version) {
log::warn!("Doing version negotiation");
let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap();
let out = &out[..len];
if let Err(e) = socket.send_to(out, from) {
panic!("send() failed: {:?}", e);
}
}
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
scid.copy_from_slice(&conn_id);
let scid = quiche::ConnectionId::from_ref(&scid);
// Token is always present in Initial packets.
let token = hdr.token.as_ref().unwrap();
println!("token: {}", token.iter().map(|x| x.to_string()).join(", "));
// Do stateless retry if the client didn't send a token.
if token.is_empty() {
log::warn!("Doing stateless retry");
let new_token = mint_token(&hdr, &from);
let len = quiche::retry(
&hdr.scid,
&hdr.dcid,
&scid,
&new_token,
hdr.version,
&mut out,
)
.unwrap();
let out = &out[..len];
if let Err(e) = socket.send_to(out, from) {
panic!("send() failed: {:?}", e);
} else {
continue;
}
}
let odcid = validate_token(&from, token);
// The token was not valid, meaning the retry failed, so
// drop the packet.
if odcid.is_none() {
panic!("Invalid address validation token");
}
if scid.len() != hdr.dcid.len() {
panic!("Invalid destination connection ID");
}
// Reuse the source connection ID we sent in the Retry packet,
// instead of changing it again.
let scid = hdr.dcid.clone();
log::debug!("New connection: dcid={:?} scid={:?}", hdr.dcid, scid);
let mut conn =
quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config).unwrap();
let stream_id = conn.stream_writable_next().unwrap();
send_message(&mut conn, stream_id, &message).unwrap();
conn.close(true, 0, b"not required").unwrap();
jh.join().unwrap();
break;
}
}
}

View File

@ -1,9 +1,9 @@
use itertools::Itertools;
use anyhow::bail;
use quiche::Connection;
use crate::message::Message;
use crate::{message::Message, quic::quiche_utils::PartialResponse};
use super::configure_server::MAX_DATAGRAM_SIZE;
use super::quiche_utils::PartialResponses;
pub fn convert_to_binary(message: &Message) -> anyhow::Result<Vec<u8>> {
Ok(bincode::serialize(&message)?)
@ -11,14 +11,64 @@ pub fn convert_to_binary(message: &Message) -> anyhow::Result<Vec<u8>> {
pub fn send_message(
connection: &mut Connection,
partial_responses: &mut PartialResponses,
stream_id: u64,
message: &Message,
message: &Vec<u8>,
) -> anyhow::Result<()> {
let binary = convert_to_binary(message)?;
let chunks = binary.chunks(MAX_DATAGRAM_SIZE).collect_vec();
let nb_chunks = chunks.len();
for (index, buf) in chunks.iter().enumerate() {
connection.stream_send(stream_id, buf, index + 1 == nb_chunks)?;
let written = match connection.stream_send(stream_id, &message, true) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
bail!("{} stream send failed {:?}", stream_id, e);
}
};
log::debug!("dispatched {} on stream id : {}", written, stream_id);
if written < message.len() {
let response = PartialResponse {
binary: message[written..].to_vec(),
written,
};
partial_responses.insert(stream_id, response);
}
Ok(())
}
/// Handles newly writable streams.
pub fn handle_writable(
conn: &mut quiche::Connection,
partial_responses: &mut PartialResponses,
stream_id: u64,
) {
log::debug!("{} stream {} is writable", conn.trace_id(), stream_id);
if !partial_responses.contains_key(&stream_id) {
return;
}
let resp = partial_responses
.get_mut(&stream_id)
.expect("should have a stream id");
let body = &resp.binary;
let written = match conn.stream_send(stream_id, body, true) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
partial_responses.remove(&stream_id);
log::error!("{} stream send failed {:?}", conn.trace_id(), e);
return;
}
};
if resp.written == resp.binary.len() {
partial_responses.remove(&stream_id);
} else {
resp.binary = resp.binary[written..].to_vec();
resp.written += written;
}
}

View File

@ -2,7 +2,7 @@ use std::{collections::HashMap, net::SocketAddr};
use itertools::Itertools;
use mio::net::UdpSocket;
use quiche::ConnectionId;
use quiche::{Connection, ConnectionId};
use ring::rand::SystemRandom;
use crate::{
@ -12,21 +12,21 @@ use crate::{
message::Message,
quic::{
quiche_reciever::recv_message,
quiche_sender::{handle_writable, send_message},
quiche_utils::{get_next_unidi, mint_token, validate_token},
},
types::{account::Account, block_meta::SlotMeta, slot_identifier::SlotIdentifier},
};
use super::{configure_server::MAX_DATAGRAM_SIZE, quiche_sender::convert_to_binary};
struct PartialResponse {
pub binary: Vec<u8>,
pub written: usize,
}
use super::{
configure_server::MAX_DATAGRAM_SIZE, quiche_reciever::ReadStreams,
quiche_utils::PartialResponses,
};
struct Client {
pub conn: quiche::Connection,
pub partial_responses: HashMap<u64, PartialResponse>,
pub conn: Connection,
pub partial_responses: PartialResponses,
pub read_streams: ReadStreams,
pub filters: Vec<Filter>,
pub next_stream: u64,
}
@ -191,7 +191,8 @@ pub fn server_loop(
let client = Client {
conn,
partial_responses: HashMap::new(),
partial_responses: PartialResponses::new(),
read_streams: ReadStreams::new(),
filters: Vec::new(),
next_stream: get_next_unidi(0, true),
};
@ -226,14 +227,15 @@ pub fn server_loop(
if client.conn.is_in_early_data() || client.conn.is_established() {
for stream_id in client.conn.writable() {
handle_writable(client, stream_id);
handle_writable(&mut client.conn, &mut client.partial_responses, stream_id);
}
// Process all readable streams.
for stream in client.conn.readable() {
let message = recv_message(&mut client.conn, stream);
let message =
recv_message(&mut client.conn, &mut client.read_streams, stream);
match message {
Ok(message) => {
Ok(Some(message)) => {
match message {
Message::Filters(mut filters) => {
client.filters.append(&mut filters);
@ -246,6 +248,7 @@ pub fn server_loop(
}
}
}
Ok(None) => {}
Err(e) => {
log::error!("Error recieving message : {e}")
}
@ -292,7 +295,7 @@ pub fn server_loop(
Message::TransactionMsg(transaction)
}
};
let binary = convert_to_binary(&message)
let binary = bincode::serialize(&message)
.expect("Message should be serializable in binary");
for client in dispatch_to {
let stream_id = client.next_stream;
@ -302,28 +305,13 @@ pub fn server_loop(
binary.len(),
stream_id
);
let written = match client.conn.stream_send(stream_id, &binary, true) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
log::error!(
"{} stream send failed {:?}",
client.conn.trace_id(),
e
);
continue;
}
};
log::debug!("dispatched {} on stream id : {}", written, stream_id);
if written < binary.len() {
let response = PartialResponse {
binary: binary[written..].to_vec(),
written,
};
client.partial_responses.insert(stream_id, response);
if let Err(e) = send_message(
&mut client.conn,
&mut client.partial_responses,
stream_id,
&binary,
) {
log::error!("Error sending message : {e}");
}
}
}
@ -381,39 +369,3 @@ pub fn server_loop(
}
}
}
/// Handles newly writable streams.
fn handle_writable(client: &mut Client, stream_id: u64) {
let conn = &mut client.conn;
log::debug!("{} stream {} is writable", conn.trace_id(), stream_id);
if !client.partial_responses.contains_key(&stream_id) {
return;
}
let resp = client
.partial_responses
.get_mut(&stream_id)
.expect("should have a stream id");
let body = &resp.binary;
let written = match conn.stream_send(stream_id, body, true) {
Ok(v) => v,
Err(quiche::Error::Done) => 0,
Err(e) => {
client.partial_responses.remove(&stream_id);
log::error!("{} stream send failed {:?}", conn.trace_id(), e);
return;
}
};
if resp.written == resp.binary.len() {
client.partial_responses.remove(&stream_id);
} else {
resp.binary = resp.binary[written..].to_vec();
resp.written += written;
}
}

View File

@ -1,3 +1,5 @@
use std::collections::BTreeMap;
pub fn validate_token<'a>(
src: &std::net::SocketAddr,
token: &'a [u8],
@ -65,3 +67,10 @@ pub fn get_next_unidi(current_stream_id: u64, is_server: bool) -> u64 {
}
panic!("stream not found");
}
pub struct PartialResponse {
pub binary: Vec<u8>,
pub written: usize,
}
pub type PartialResponses = BTreeMap<u64, PartialResponse>;