moving from mio::udp socket to std::net::udp socket

This commit is contained in:
godmodegalactus 2024-08-02 16:09:35 +02:00
parent 9b5a5fae7e
commit c927b713c6
No known key found for this signature in database
GPG Key ID: 22DA4A30887FDA3C
2 changed files with 193 additions and 204 deletions

View File

@ -76,7 +76,7 @@ pub fn client_loop(
let mut buf = [0; 65535];
'client: loop {
poll.poll(&mut events, Some(Duration::from_micros(100)))?;
poll.poll(&mut events, Some(Duration::from_millis(100)))?;
'read: loop {
match socket.recv_from(&mut buf) {

View File

@ -1,6 +1,7 @@
use std::{
collections::HashMap,
net::SocketAddr,
net::UdpSocket,
sync::{
atomic::{AtomicBool, AtomicU64},
mpsc::{self, Sender},
@ -11,7 +12,6 @@ use std::{
use anyhow::bail;
use itertools::Itertools;
use mio::Token;
use quiche::ConnectionId;
use ring::rand::SystemRandom;
@ -53,13 +53,7 @@ pub fn server_loop(
let maximum_concurrent_streams_id = u64::MAX;
let mut config = configure_server(quic_params)?;
let mut socket = mio::net::UdpSocket::bind(socket_addr)?;
let mut poll = mio::Poll::new()?;
let mut events = mio::Events::with_capacity(1024);
poll.registry()
.register(&mut socket, mio::Token(0), mio::Interest::READABLE)?;
let socket = Arc::new(UdpSocket::bind(socket_addr)?);
let mut buf = [0; 65535];
let mut out = [0; MAX_DATAGRAM_SIZE];
@ -73,11 +67,6 @@ pub fn server_loop(
let clients_by_id: Arc<Mutex<HashMap<ConnectionId<'static>, u64>>> =
Arc::new(Mutex::new(HashMap::new()));
let (write_sender, mut write_reciver) = mio_channel::channel::<(quiche::SendInfo, Vec<u8>)>();
poll.registry()
.register(&mut write_reciver, mio::Token(1), mio::Interest::READABLE)?;
let enable_pacing = if quic_params.enable_pacing {
set_txtime_sockopt(&socket).is_ok()
} else {
@ -105,16 +94,11 @@ pub fn server_loop(
let mut client_id_counter = 0;
loop {
poll.poll(&mut events, Some(Duration::from_millis(10)))?;
let do_read = events.is_empty() || events.iter().any(|x| x.token() == Token(0));
if do_read {
'read: loop {
let (len, from) = match socket.recv_from(&mut buf) {
Ok(v) => v,
Err(e) => {
if e.kind() == std::io::ErrorKind::WouldBlock {
log::trace!("recv() would block");
break 'read;
break;
}
bail!("recv() failed: {:?}", e);
}
@ -128,7 +112,7 @@ pub fn server_loop(
Err(e) => {
log::error!("Parsing packet header failed: {:?}", e);
continue 'read;
continue;
}
};
@ -137,15 +121,15 @@ pub fn server_loop(
let conn_id: ConnectionId<'static> = conn_id.to_vec().into();
let mut clients_lk = clients_by_id.lock().unwrap();
if !clients_lk.contains_key(&hdr.dcid) && !clients_lk.contains_key(&conn_id) {
drop(clients_lk);
if hdr.ty != quiche::Type::Initial {
log::error!("Packet is not Initial");
continue 'read;
continue;
}
if !quiche::version_is_supported(hdr.version) {
log::warn!("Doing version negotiation");
let len =
quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap();
let len = quiche::negotiate_version(&hdr.scid, &hdr.dcid, &mut out).unwrap();
let out = &out[..len];
@ -155,7 +139,7 @@ pub fn server_loop(
}
panic!("send() failed: {:?}", e);
}
continue 'read;
continue;
}
let mut scid = [0; quiche::MAX_CONN_ID_LEN];
@ -185,27 +169,26 @@ pub fn server_loop(
if let Err(e) = socket.send_to(&out[..len], from) {
log::error!("Error sending retry messages : {e:?}");
}
continue 'read;
continue;
}
let odcid = validate_token(&from, token);
if odcid.is_none() {
log::error!("Invalid address validation token");
continue 'read;
continue;
}
if scid.len() != hdr.dcid.len() {
log::error!("Invalid destination connection ID");
continue 'read;
continue;
}
let scid = hdr.dcid.clone();
log::info!("New connection: dcid={:?} scid={:?}", hdr.dcid, scid);
let mut conn =
quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config)?;
let mut conn = quiche::accept(&scid, odcid.as_ref(), local_addr, from, &mut config)?;
let recv_info = quiche::RecvInfo {
to: socket.local_addr().unwrap(),
@ -216,7 +199,7 @@ pub fn server_loop(
Ok(v) => v,
Err(e) => {
log::error!("{} recv failed: {:?}", conn.trace_id(), e);
continue 'read;
continue;
}
};
@ -227,17 +210,19 @@ pub fn server_loop(
let filters = Arc::new(RwLock::new(Vec::new()));
create_client_task(
socket.clone(),
conn,
current_client_id,
clients_by_id.clone(),
client_reciver,
write_sender.clone(),
client_message_rx,
filters.clone(),
maximum_concurrent_streams_id,
stop_laggy_client,
quic_params.incremental_priority,
rng.clone(),
enable_pacing,
enable_gso,
);
let mut lk = dispatching_connections.lock().unwrap();
lk.insert(
@ -247,6 +232,7 @@ pub fn server_loop(
filters,
},
);
let mut clients_lk = clients_by_id.lock().unwrap();
clients_lk.insert(scid, current_client_id);
client_messsage_channel_by_id.insert(current_client_id, client_sender);
} else {
@ -277,37 +263,24 @@ pub fn server_loop(
}
};
}
}
while let Ok((send_info, buffer)) = write_reciver.try_recv() {
let send_result = if enable_pacing {
send_with_pacing(&socket, &buffer, &send_info, enable_gso)
} else {
socket.send_to(&buffer, send_info.to)
};
match send_result {
Ok(_written) => {}
Err(e) => {
log::error!("sending failed with error : {e:?}");
}
}
}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn create_client_task(
socket: Arc<UdpSocket>,
connection: quiche::Connection,
client_id: u64,
client_id_by_scid: Arc<Mutex<HashMap<ConnectionId<'static>, u64>>>,
mut receiver: mio_channel::Receiver<(quiche::RecvInfo, Vec<u8>)>,
sender: mio_channel::Sender<(quiche::SendInfo, Vec<u8>)>,
message_channel: mpsc::Receiver<(Vec<u8>, u8)>,
filters: Arc<RwLock<Vec<Filter>>>,
maximum_concurrent_streams_id: u64,
stop_laggy_client: bool,
incremental_priority: bool,
rng: SystemRandom,
enable_pacing: bool,
enable_gso: bool,
) {
std::thread::spawn(move || {
let mut partial_responses = PartialResponses::new();
@ -332,7 +305,11 @@ fn create_client_task(
let number_of_writable_streams = Arc::new(AtomicU64::new(0));
let messages_added = Arc::new(AtomicU64::new(0));
let quit = Arc::new(AtomicBool::new(false));
let max_burst_size = MAX_DATAGRAM_SIZE * 10;
let max_burst_size = if enable_gso {
MAX_DATAGRAM_SIZE * 10
} else {
MAX_DATAGRAM_SIZE
};
{
let number_of_loops = number_of_loops.clone();
@ -575,9 +552,22 @@ fn create_client_task(
}
if total_length > 0 && send_message_to.is_some() {
sender
.send((send_message_to.unwrap(), out[..total_length].to_vec()))
.unwrap();
let send_result = if enable_pacing {
send_with_pacing(
&socket,
&out[..total_length],
&send_message_to.unwrap(),
enable_gso,
)
} else {
socket.send(&out[..total_length])
};
match send_result {
Ok(_written) => {}
Err(e) => {
log::error!("sending failed with error : {e:?}");
}
}
total_length = 0;
} else {
break;
@ -664,7 +654,7 @@ fn create_dispatching_thread(
});
}
fn set_txtime_sockopt(sock: &mio::net::UdpSocket) -> std::io::Result<()> {
fn set_txtime_sockopt(sock: &UdpSocket) -> std::io::Result<()> {
use nix::sys::socket::setsockopt;
use nix::sys::socket::sockopt::TxTime;
use std::os::unix::io::AsRawFd;
@ -674,7 +664,6 @@ fn set_txtime_sockopt(sock: &mio::net::UdpSocket) -> std::io::Result<()> {
flags: 0,
};
// mio::net::UdpSocket doesn't implement AsFd (yet?).
let fd = unsafe { std::os::fd::BorrowedFd::borrow_raw(sock.as_raw_fd()) };
setsockopt(&fd, TxTime, &config)?;
@ -696,7 +685,7 @@ fn std_time_to_u64(time: &std::time::Instant) -> u64 {
const GSO_SEGMENT_SIZE: u16 = MAX_DATAGRAM_SIZE as u16;
fn send_with_pacing(
socket: &mio::net::UdpSocket,
socket: &UdpSocket,
buf: &[u8],
send_info: &quiche::SendInfo,
enable_gso: bool,
@ -727,7 +716,7 @@ fn send_with_pacing(
}
}
pub fn detect_gso(socket: &mio::net::UdpSocket, segment_size: usize) -> bool {
pub fn detect_gso(socket: &UdpSocket, segment_size: usize) -> bool {
use nix::sys::socket::setsockopt;
use nix::sys::socket::sockopt::UdpGsoSegment;
use std::os::unix::io::AsRawFd;