Adding a connection manager to manage quic connections for server

This commit is contained in:
Godmode Galactus 2024-05-13 16:38:10 +02:00
parent 211f439d37
commit da2dd63704
No known key found for this signature in database
GPG Key ID: 22DA4A30887FDA3C
10 changed files with 305 additions and 26 deletions

View File

@ -3,8 +3,6 @@ name = "quic-geyser-common"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
solana-sdk = { workspace = "true" }
solana-streamer = { workspace = "true" }

View File

@ -3,8 +3,41 @@ use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use solana_sdk::pubkey::Pubkey;
#[derive(Serialize, Deserialize, Clone)]
use crate::message::Message;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum Filter {
Account(AccountFilter),
}
impl Filter {
pub fn allows(&self, message: &Message) -> bool {
match &self {
Filter::Account(account) => account.allows(message),
}
}
}
// setting owner to 11111111111111111111111111111111 will subscribe to all the accounts
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct AccountFilter {
owner: Option<Pubkey>,
accounts: Option<HashSet<Pubkey>>,
}
impl AccountFilter {
pub fn allows(&self, message: &Message) -> bool {
if let Message::AccountMsg(account) = message {
if let Some(owner) = self.owner {
// check if filter subscribes to all the accounts
if owner == Pubkey::default() || owner == account.owner {
return true;
}
}
if let Some(accounts) = &self.accounts {
return accounts.contains(&account.pubkey);
}
}
false
}
}

View File

@ -1,8 +1,9 @@
use serde::{Deserialize, Serialize};
use crate::types::account::Account;
use crate::{filters::Filter, types::account::Account};
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum Message {
AccountMsg(Account),
Filters(Vec<Filter>), // sent from client to server
}

View File

@ -1,12 +1,24 @@
use std::{net::{IpAddr, Ipv4Addr}, sync::Arc, time::Duration};
use std::{
net::{IpAddr, Ipv4Addr},
sync::Arc,
time::Duration,
};
use quinn::{ClientConfig, Endpoint, EndpointConfig, IdleTimeout, TokioRuntime, TransportConfig, VarInt};
use quinn::{
ClientConfig, Endpoint, EndpointConfig, IdleTimeout, TokioRuntime, TransportConfig, VarInt,
};
use solana_sdk::signature::Keypair;
use solana_streamer::tls_certificates::new_self_signed_tls_certificate;
use crate::quic::{configure_server::ALPN_GEYSER_PROTOCOL_ID, skip_verification::ClientSkipServerVerification};
use crate::quic::{
configure_server::ALPN_GEYSER_PROTOCOL_ID, skip_verification::ClientSkipServerVerification,
};
pub fn create_client_endpoint(certificate: rustls::Certificate, key: rustls::PrivateKey, maximum_streams: u32) -> Endpoint {
pub fn create_client_endpoint(
certificate: rustls::Certificate,
key: rustls::PrivateKey,
maximum_streams: u32,
) -> Endpoint {
const DATAGRAM_RECEIVE_BUFFER_SIZE: usize = 64 * 1024 * 1024;
const DATAGRAM_SEND_BUFFER_SIZE: usize = 64 * 1024 * 1024;
const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = MINIMUM_MAXIMUM_TRANSMISSION_UNIT;
@ -42,7 +54,7 @@ pub fn create_client_endpoint(certificate: rustls::Certificate, key: rustls::Pri
transport_config.datagram_receive_buffer_size(Some(DATAGRAM_RECEIVE_BUFFER_SIZE));
transport_config.datagram_send_buffer_size(DATAGRAM_SEND_BUFFER_SIZE);
transport_config.initial_mtu(INITIAL_MAXIMUM_TRANSMISSION_UNIT);
transport_config.max_concurrent_bidi_streams(VarInt::from(maximum_streams));
transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0));
transport_config.max_concurrent_uni_streams(VarInt::from(maximum_streams));
transport_config.min_mtu(MINIMUM_MAXIMUM_TRANSMISSION_UNIT);
transport_config.mtu_discovery_config(None);
@ -58,9 +70,11 @@ pub async fn configure_client(
identity: &Keypair,
maximum_concurrent_streams: u32,
) -> anyhow::Result<Endpoint> {
let (certificate, key) = new_self_signed_tls_certificate(
&identity,
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
)?;
Ok(create_client_endpoint(certificate, key, maximum_concurrent_streams))
}
let (certificate, key) =
new_self_signed_tls_certificate(identity, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)))?;
Ok(create_client_endpoint(
certificate,
key,
maximum_concurrent_streams,
))
}

View File

@ -34,8 +34,8 @@ pub fn configure_server(
server_config.use_retry(true);
let config = Arc::get_mut(&mut server_config.transport).unwrap();
config.max_concurrent_uni_streams((max_concurrent_streams as u32).into());
let recv_size = (recieve_window_size as u32).into();
config.max_concurrent_uni_streams((max_concurrent_streams).into());
let recv_size = recieve_window_size.into();
config.stream_receive_window(recv_size);
config.receive_window(recv_size);
@ -44,7 +44,7 @@ pub fn configure_server(
config.max_idle_timeout(Some(timeout));
// disable bidi & datagrams
config.max_concurrent_bidi_streams(max_concurrent_streams.into());
config.max_concurrent_bidi_streams(0u32.into());
config.datagram_receive_buffer_size(None);
Ok((server_config, cert_chain_pem))

View File

@ -0,0 +1,154 @@
use quinn::{Connection, Endpoint};
use std::sync::Arc;
use std::{collections::VecDeque, time::Duration};
use tokio::{sync::RwLock, task::JoinHandle, time::Instant};
use crate::{filters::Filter, message::Message};
use super::{quinn_reciever::recv_message, quinn_sender::send_message};
pub struct ConnectionData {
pub id: u64,
pub connection: Connection,
pub filters: Vec<Filter>,
pub since: Instant,
}
impl ConnectionData {
pub fn new(id: u64, connection: Connection) -> Self {
Self {
id,
connection,
filters: vec![],
since: Instant::now(),
}
}
}
/*
This class will take care of adding connections and filters etc
*/
pub struct ConnectionManager {
connections: Arc<RwLock<VecDeque<ConnectionData>>>,
}
impl ConnectionManager {
pub fn new(endpoint: Endpoint) -> (Self, JoinHandle<()>) {
let connections: Arc<RwLock<VecDeque<ConnectionData>>> =
Arc::new(RwLock::new(VecDeque::new()));
// create a task to add incoming connections
let connection_adder_jh = {
let connections = connections.clone();
tokio::spawn(async move {
let mut id = 0;
loop {
// acceept incoming connections
if let Some(connecting) = endpoint.accept().await {
let connection_result = connecting.await;
match connection_result {
Ok(connection) => {
// connection established
// add the connection in the connections list
let mut lk = connections.write().await;
id += 1;
let current_id = id;
lk.push_back(ConnectionData::new(current_id, connection.clone()));
drop(lk);
let connections_tmp = connections.clone();
// task to add filters
let connection_to_listen = connection.clone();
tokio::spawn(async move {
loop {
if let Ok(recv_stream) =
connection_to_listen.accept_uni().await
{
match tokio::time::timeout(
Duration::from_secs(10),
recv_message(recv_stream),
)
.await
{
Ok(Ok(filters)) => {
let Message::Filters(mut filters) = filters
else {
continue;
};
let mut lk = connections_tmp.write().await;
let connection_data =
lk.iter_mut().find(|x| x.id == current_id);
if let Some(connection_data) = connection_data {
connection_data
.filters
.append(&mut filters);
}
}
Ok(Err(e)) => {
log::error!("error getting message from the client : {}", e);
}
Err(_timeout) => {
log::warn!("Client request timeout");
}
}
}
}
});
let connections = connections.clone();
// create a connection removing task
tokio::spawn(async move {
// if connection is closed remove it
let closed_error = connection.closed().await;
log::info!("connection closed with error {}", closed_error);
let mut lk = connections.write().await;
lk.retain(|x| x.id != current_id);
});
}
Err(e) => log::error!("Error connecting {}", e),
}
}
}
})
};
(Self { connections }, connection_adder_jh)
}
pub async fn dispach(&self, message: Message, retry_count: u64) {
let lk = self.connections.read().await;
for connection_data in lk.iter() {
if connection_data.filters.iter().any(|x| x.allows(&message)) {
let connection = connection_data.connection.clone();
let message = message.clone();
tokio::spawn(async move {
for _ in 0..retry_count {
let send_stream = connection.open_uni().await;
match send_stream {
Ok(send_stream) => {
match send_message(send_stream, message.clone()).await {
Ok(_) => {
log::debug!("Message sucessfully sent");
}
Err(e) => {
log::error!(
"error dispatching message and sending data : {}",
e
)
}
}
}
Err(e) => {
log::error!(
"error dispatching message while creating stream : {}",
e
);
}
}
}
});
}
}
}
}

View File

@ -1,5 +1,6 @@
pub mod configure_client;
pub mod configure_server;
pub mod connection_manager;
pub mod quinn_reciever;
pub mod quinn_sender;
pub mod skip_verification;

View File

@ -30,10 +30,10 @@ pub async fn recv_message(mut recv_stream: RecvStream) -> anyhow::Result<Message
}
}
#[cfg(test)]
mod tests {
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
str::FromStr,
sync::Arc,
};
@ -77,6 +77,8 @@ mod tests {
blockhash: Hash::new_unique(),
},
pubkey: Pubkey::new_unique(),
owner: Pubkey::new_unique(),
write_version: 0,
data: vec![6; 2],
};
let message = Message::AccountMsg(account);
@ -88,12 +90,15 @@ mod tests {
let endpoint = configure_client(&Keypair::new(), 1).await.unwrap();
let connecting = endpoint
.connect(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port), "tmp")
.connect(
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port),
"tmp",
)
.unwrap();
let connection = connecting.await.unwrap();
let recv_stream = connection.accept_uni().await.unwrap();
let recved_message = recv_message(recv_stream).await.unwrap();
// assert if sent and recieved message match
// assert if sent and recieved message match
assert_eq!(sent_message, recved_message);
})
};
@ -134,7 +139,9 @@ mod tests {
blockhash: Hash::new_unique(),
},
pubkey: Pubkey::new_unique(),
data: vec![6; 100_000_000],
owner: Pubkey::new_unique(),
write_version: 0,
data: vec![9; 100_000_000],
};
let message = Message::AccountMsg(account);
@ -145,7 +152,10 @@ mod tests {
let endpoint = configure_client(&Keypair::new(), 0).await.unwrap();
let connecting = endpoint
.connect(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port), "tmp")
.connect(
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port),
"tmp",
)
.unwrap();
let connection = connecting.await.unwrap();
let send_stream = connection.open_uni().await.unwrap();
@ -163,5 +173,4 @@ mod tests {
assert_eq!(message, recved_message);
endpoint.close(VarInt::from_u32(0), b"");
}
}
}

View File

@ -7,5 +7,7 @@ use super::slot_identifier::SlotIdentifier;
pub struct Account {
pub slot_identifier: SlotIdentifier,
pub pubkey: Pubkey,
pub owner: Pubkey,
pub write_version: u64,
pub data: Vec<u8>,
}

67
plugin/src/quic_plugin.rs Normal file
View File

@ -0,0 +1,67 @@
use agave_geyser_plugin_interface::geyser_plugin_interface::GeyserPlugin;
#[derive(Debug)]
pub struct QuicGeyserPlugin {
}
impl GeyserPlugin for QuicGeyserPlugin {
fn name(&self) -> &'static str {
"quic_geyser_plugin"
}
fn on_load(&mut self, _config_file: &str, _is_reload: bool) -> Result<()> {
Ok(())
}
fn on_unload(&mut self) {}
fn update_account(
&self,
account: ReplicaAccountInfoVersions,
slot: Slot,
is_startup: bool,
) -> Result<()> {
Ok(())
}
fn notify_end_of_startup(&self) -> Result<()> {
Ok(())
}
fn update_slot_status(
&self,
slot: Slot,
parent: Option<u64>,
status: SlotStatus,
) -> Result<()> {
Ok(())
}
fn notify_transaction(
&self,
transaction: ReplicaTransactionInfoVersions,
slot: Slot,
) -> Result<()> {
Ok(())
}
fn notify_entry(&self, entry: ReplicaEntryInfoVersions) -> Result<()> {
Ok(())
}
fn notify_block_metadata(&self, blockinfo: ReplicaBlockInfoVersions) -> Result<()> {
Ok(())
}
fn account_data_notifications_enabled(&self) -> bool {
true
}
fn transaction_notifications_enabled(&self) -> bool {
true
}
fn entry_notifications_enabled(&self) -> bool {
false
}
}