diff --git a/common/Cargo.toml b/common/Cargo.toml index f385a54..dc7e7cf 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -3,8 +3,6 @@ name = "quic-geyser-common" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] solana-sdk = { workspace = "true" } solana-streamer = { workspace = "true" } diff --git a/common/src/filters.rs b/common/src/filters.rs index 8694484..ecad576 100644 --- a/common/src/filters.rs +++ b/common/src/filters.rs @@ -3,8 +3,41 @@ use std::collections::HashSet; use serde::{Deserialize, Serialize}; use solana_sdk::pubkey::Pubkey; -#[derive(Serialize, Deserialize, Clone)] +use crate::message::Message; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +pub enum Filter { + Account(AccountFilter), +} + +impl Filter { + pub fn allows(&self, message: &Message) -> bool { + match &self { + Filter::Account(account) => account.allows(message), + } + } +} + +// setting owner to 11111111111111111111111111111111 will subscribe to all the accounts +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct AccountFilter { owner: Option, accounts: Option>, } + +impl AccountFilter { + pub fn allows(&self, message: &Message) -> bool { + if let Message::AccountMsg(account) = message { + if let Some(owner) = self.owner { + // check if filter subscribes to all the accounts + if owner == Pubkey::default() || owner == account.owner { + return true; + } + } + if let Some(accounts) = &self.accounts { + return accounts.contains(&account.pubkey); + } + } + false + } +} diff --git a/common/src/message.rs b/common/src/message.rs index be7c6ef..0b07448 100644 --- a/common/src/message.rs +++ b/common/src/message.rs @@ -1,8 +1,9 @@ use serde::{Deserialize, Serialize}; -use crate::types::account::Account; +use crate::{filters::Filter, types::account::Account}; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum Message { AccountMsg(Account), + Filters(Vec), // sent from client to server } diff --git a/common/src/quic/configure_client.rs b/common/src/quic/configure_client.rs index f7559b3..b6a4709 100644 --- a/common/src/quic/configure_client.rs +++ b/common/src/quic/configure_client.rs @@ -1,12 +1,24 @@ -use std::{net::{IpAddr, Ipv4Addr}, sync::Arc, time::Duration}; +use std::{ + net::{IpAddr, Ipv4Addr}, + sync::Arc, + time::Duration, +}; -use quinn::{ClientConfig, Endpoint, EndpointConfig, IdleTimeout, TokioRuntime, TransportConfig, VarInt}; +use quinn::{ + ClientConfig, Endpoint, EndpointConfig, IdleTimeout, TokioRuntime, TransportConfig, VarInt, +}; use solana_sdk::signature::Keypair; use solana_streamer::tls_certificates::new_self_signed_tls_certificate; -use crate::quic::{configure_server::ALPN_GEYSER_PROTOCOL_ID, skip_verification::ClientSkipServerVerification}; +use crate::quic::{ + configure_server::ALPN_GEYSER_PROTOCOL_ID, skip_verification::ClientSkipServerVerification, +}; -pub fn create_client_endpoint(certificate: rustls::Certificate, key: rustls::PrivateKey, maximum_streams: u32) -> Endpoint { +pub fn create_client_endpoint( + certificate: rustls::Certificate, + key: rustls::PrivateKey, + maximum_streams: u32, +) -> Endpoint { const DATAGRAM_RECEIVE_BUFFER_SIZE: usize = 64 * 1024 * 1024; const DATAGRAM_SEND_BUFFER_SIZE: usize = 64 * 1024 * 1024; const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = MINIMUM_MAXIMUM_TRANSMISSION_UNIT; @@ -42,7 +54,7 @@ pub fn create_client_endpoint(certificate: rustls::Certificate, key: rustls::Pri transport_config.datagram_receive_buffer_size(Some(DATAGRAM_RECEIVE_BUFFER_SIZE)); transport_config.datagram_send_buffer_size(DATAGRAM_SEND_BUFFER_SIZE); transport_config.initial_mtu(INITIAL_MAXIMUM_TRANSMISSION_UNIT); - transport_config.max_concurrent_bidi_streams(VarInt::from(maximum_streams)); + transport_config.max_concurrent_bidi_streams(VarInt::from_u32(0)); transport_config.max_concurrent_uni_streams(VarInt::from(maximum_streams)); transport_config.min_mtu(MINIMUM_MAXIMUM_TRANSMISSION_UNIT); transport_config.mtu_discovery_config(None); @@ -58,9 +70,11 @@ pub async fn configure_client( identity: &Keypair, maximum_concurrent_streams: u32, ) -> anyhow::Result { - let (certificate, key) = new_self_signed_tls_certificate( - &identity, - IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), - )?; - Ok(create_client_endpoint(certificate, key, maximum_concurrent_streams)) -} \ No newline at end of file + let (certificate, key) = + new_self_signed_tls_certificate(identity, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)))?; + Ok(create_client_endpoint( + certificate, + key, + maximum_concurrent_streams, + )) +} diff --git a/common/src/quic/configure_server.rs b/common/src/quic/configure_server.rs index a433321..6328db9 100644 --- a/common/src/quic/configure_server.rs +++ b/common/src/quic/configure_server.rs @@ -34,8 +34,8 @@ pub fn configure_server( server_config.use_retry(true); let config = Arc::get_mut(&mut server_config.transport).unwrap(); - config.max_concurrent_uni_streams((max_concurrent_streams as u32).into()); - let recv_size = (recieve_window_size as u32).into(); + config.max_concurrent_uni_streams((max_concurrent_streams).into()); + let recv_size = recieve_window_size.into(); config.stream_receive_window(recv_size); config.receive_window(recv_size); @@ -44,7 +44,7 @@ pub fn configure_server( config.max_idle_timeout(Some(timeout)); // disable bidi & datagrams - config.max_concurrent_bidi_streams(max_concurrent_streams.into()); + config.max_concurrent_bidi_streams(0u32.into()); config.datagram_receive_buffer_size(None); Ok((server_config, cert_chain_pem)) diff --git a/common/src/quic/connection_manager.rs b/common/src/quic/connection_manager.rs new file mode 100644 index 0000000..04893c2 --- /dev/null +++ b/common/src/quic/connection_manager.rs @@ -0,0 +1,154 @@ +use quinn::{Connection, Endpoint}; +use std::sync::Arc; +use std::{collections::VecDeque, time::Duration}; +use tokio::{sync::RwLock, task::JoinHandle, time::Instant}; + +use crate::{filters::Filter, message::Message}; + +use super::{quinn_reciever::recv_message, quinn_sender::send_message}; + +pub struct ConnectionData { + pub id: u64, + pub connection: Connection, + pub filters: Vec, + pub since: Instant, +} + +impl ConnectionData { + pub fn new(id: u64, connection: Connection) -> Self { + Self { + id, + connection, + filters: vec![], + since: Instant::now(), + } + } +} + +/* + This class will take care of adding connections and filters etc +*/ +pub struct ConnectionManager { + connections: Arc>>, +} + +impl ConnectionManager { + pub fn new(endpoint: Endpoint) -> (Self, JoinHandle<()>) { + let connections: Arc>> = + Arc::new(RwLock::new(VecDeque::new())); + // create a task to add incoming connections + let connection_adder_jh = { + let connections = connections.clone(); + tokio::spawn(async move { + let mut id = 0; + loop { + // acceept incoming connections + if let Some(connecting) = endpoint.accept().await { + let connection_result = connecting.await; + match connection_result { + Ok(connection) => { + // connection established + // add the connection in the connections list + let mut lk = connections.write().await; + id += 1; + let current_id = id; + lk.push_back(ConnectionData::new(current_id, connection.clone())); + drop(lk); + + let connections_tmp = connections.clone(); + + // task to add filters + let connection_to_listen = connection.clone(); + tokio::spawn(async move { + loop { + if let Ok(recv_stream) = + connection_to_listen.accept_uni().await + { + match tokio::time::timeout( + Duration::from_secs(10), + recv_message(recv_stream), + ) + .await + { + Ok(Ok(filters)) => { + let Message::Filters(mut filters) = filters + else { + continue; + }; + let mut lk = connections_tmp.write().await; + let connection_data = + lk.iter_mut().find(|x| x.id == current_id); + if let Some(connection_data) = connection_data { + connection_data + .filters + .append(&mut filters); + } + } + Ok(Err(e)) => { + log::error!("error getting message from the client : {}", e); + } + Err(_timeout) => { + log::warn!("Client request timeout"); + } + } + } + } + }); + + let connections = connections.clone(); + // create a connection removing task + tokio::spawn(async move { + // if connection is closed remove it + let closed_error = connection.closed().await; + log::info!("connection closed with error {}", closed_error); + let mut lk = connections.write().await; + lk.retain(|x| x.id != current_id); + }); + } + Err(e) => log::error!("Error connecting {}", e), + } + } + } + }) + }; + + (Self { connections }, connection_adder_jh) + } + + pub async fn dispach(&self, message: Message, retry_count: u64) { + let lk = self.connections.read().await; + + for connection_data in lk.iter() { + if connection_data.filters.iter().any(|x| x.allows(&message)) { + let connection = connection_data.connection.clone(); + let message = message.clone(); + tokio::spawn(async move { + for _ in 0..retry_count { + let send_stream = connection.open_uni().await; + match send_stream { + Ok(send_stream) => { + match send_message(send_stream, message.clone()).await { + Ok(_) => { + log::debug!("Message sucessfully sent"); + } + Err(e) => { + log::error!( + "error dispatching message and sending data : {}", + e + ) + } + } + } + Err(e) => { + log::error!( + "error dispatching message while creating stream : {}", + e + ); + } + } + } + }); + } + } + } +} diff --git a/common/src/quic/mod.rs b/common/src/quic/mod.rs index 86768e8..c008a35 100644 --- a/common/src/quic/mod.rs +++ b/common/src/quic/mod.rs @@ -1,5 +1,6 @@ pub mod configure_client; pub mod configure_server; +pub mod connection_manager; pub mod quinn_reciever; pub mod quinn_sender; pub mod skip_verification; diff --git a/common/src/quic/quinn_reciever.rs b/common/src/quic/quinn_reciever.rs index a4de320..fe81bed 100644 --- a/common/src/quic/quinn_reciever.rs +++ b/common/src/quic/quinn_reciever.rs @@ -30,10 +30,10 @@ pub async fn recv_message(mut recv_stream: RecvStream) -> anyhow::Result, } diff --git a/plugin/src/quic_plugin.rs b/plugin/src/quic_plugin.rs new file mode 100644 index 0000000..4599485 --- /dev/null +++ b/plugin/src/quic_plugin.rs @@ -0,0 +1,67 @@ +use agave_geyser_plugin_interface::geyser_plugin_interface::GeyserPlugin; + +#[derive(Debug)] +pub struct QuicGeyserPlugin { +} + +impl GeyserPlugin for QuicGeyserPlugin { + fn name(&self) -> &'static str { + "quic_geyser_plugin" + } + + fn on_load(&mut self, _config_file: &str, _is_reload: bool) -> Result<()> { + Ok(()) + } + + fn on_unload(&mut self) {} + + fn update_account( + &self, + account: ReplicaAccountInfoVersions, + slot: Slot, + is_startup: bool, + ) -> Result<()> { + Ok(()) + } + + fn notify_end_of_startup(&self) -> Result<()> { + Ok(()) + } + + fn update_slot_status( + &self, + slot: Slot, + parent: Option, + status: SlotStatus, + ) -> Result<()> { + Ok(()) + } + + fn notify_transaction( + &self, + transaction: ReplicaTransactionInfoVersions, + slot: Slot, + ) -> Result<()> { + Ok(()) + } + + fn notify_entry(&self, entry: ReplicaEntryInfoVersions) -> Result<()> { + Ok(()) + } + + fn notify_block_metadata(&self, blockinfo: ReplicaBlockInfoVersions) -> Result<()> { + Ok(()) + } + + fn account_data_notifications_enabled(&self) -> bool { + true + } + + fn transaction_notifications_enabled(&self) -> bool { + true + } + + fn entry_notifications_enabled(&self) -> bool { + false + } +} \ No newline at end of file