diff --git a/quic-forward-proxy/src/outbound/mod.rs b/quic-forward-proxy/src/outbound/mod.rs index 30b80e1d..968674c2 100644 --- a/quic-forward-proxy/src/outbound/mod.rs +++ b/quic-forward-proxy/src/outbound/mod.rs @@ -1 +1,2 @@ +mod sharder; pub mod tx_forward; diff --git a/quic-forward-proxy/src/outbound/tx_forward.rs b/quic-forward-proxy/src/outbound/tx_forward.rs index 4e9e7f4a..1a13ceb0 100644 --- a/quic-forward-proxy/src/outbound/tx_forward.rs +++ b/quic-forward-proxy/src/outbound/tx_forward.rs @@ -1,3 +1,4 @@ +use crate::outbound::sharder::Sharder; use crate::quic_util::SkipServerVerification; use crate::quinn_auto_reconnect::AutoReconnect; use crate::shared::ForwardPacket; @@ -52,12 +53,15 @@ pub async fn tx_forwarder( agents.entry(tpu_address).or_insert_with(|| { let mut agent_exit_signals = Vec::new(); - for connection_idx in 1..PARALLEL_TPU_CONNECTION_COUNT { + for connection_idx in 0..PARALLEL_TPU_CONNECTION_COUNT { + let sharder = + Sharder::new(connection_idx as u32, PARALLEL_TPU_CONNECTION_COUNT as u32); let global_exit_signal = exit_signal.clone(); let agent_exit_signal = Arc::new(AtomicBool::new(false)); let endpoint_copy = endpoint.clone(); let agent_exit_signal_copy = agent_exit_signal.clone(); // by subscribing we expect to get a copy of each packet + let mut per_connection_receiver = broadcast_in.subscribe(); tokio::spawn(async move { debug!( @@ -82,6 +86,9 @@ pub async fn tx_forwarder( if packet.tpu_address != tpu_address { continue; } + if !sharder.matching(packet.shard_hash()) { + continue; + } let mut transactions_batch: Vec = packet.transactions.clone(); @@ -90,6 +97,9 @@ pub async fn tx_forwarder( if more.tpu_address != tpu_address { continue; } + if !sharder.matching(more.shard_hash()) { + continue; + } transactions_batch.extend(more.transactions.clone()); } diff --git a/quic-forward-proxy/src/shared/mod.rs b/quic-forward-proxy/src/shared/mod.rs index 6b3f4c43..63557f81 100644 --- a/quic-forward-proxy/src/shared/mod.rs +++ b/quic-forward-proxy/src/shared/mod.rs @@ -1,4 +1,6 @@ use solana_sdk::transaction::VersionedTransaction; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::net::SocketAddr; /// internal structure with transactions and target TPU @@ -7,3 +9,12 @@ pub struct ForwardPacket { pub transactions: Vec, pub tpu_address: SocketAddr, } + +impl ForwardPacket { + pub fn shard_hash(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + // note: assumes that there are transactions with >=0 signatures + self.transactions[0].signatures[0].hash(&mut hasher); + hasher.finish() + } +}