diff --git a/src/packet.rs b/src/packet.rs index 043b8cd18a..53b79fc40c 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -117,6 +117,14 @@ impl Default for Packets { } } +impl Packets { + pub fn set_addr(&mut self, addr: &SocketAddr) { + for m in self.packets.iter_mut() { + m.meta.set_addr(&addr); + } + } +} + #[derive(Clone)] pub struct Blob { pub data: [u8; BLOB_SIZE], @@ -475,7 +483,17 @@ mod tests { use solana_sdk::transaction::Transaction; use std::io; use std::io::Write; - use std::net::UdpSocket; + use std::net::{Ipv4Addr, SocketAddr, UdpSocket}; + + #[test] + fn test_packets_set_addr() { + // test that the address is actually being updated + let send_addr = socketaddr!([127, 0, 0, 1], 123); + let packets = vec![Packet::default()]; + let mut msgs = Packets { packets }; + msgs.set_addr(&send_addr); + assert_eq!(SocketAddr::from(msgs.packets[0].meta.addr()), send_addr); + } #[test] pub fn packet_send_recv() { diff --git a/src/tpu_forwarder.rs b/src/tpu_forwarder.rs index 5e2796aa15..203e233679 100644 --- a/src/tpu_forwarder.rs +++ b/src/tpu_forwarder.rs @@ -5,20 +5,27 @@ use crate::cluster_info::ClusterInfo; use crate::contact_info::ContactInfo; use crate::counter::Counter; -use crate::packet::Packets; use crate::result::Result; use crate::service::Service; use crate::streamer::{self, PacketReceiver}; use log::Level; use solana_sdk::pubkey::Pubkey; -use std::error::Error; -use std::net::UdpSocket; -use std::result; +use std::net::{SocketAddr, UdpSocket}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::channel; use std::sync::{Arc, RwLock}; use std::thread::{self, Builder, JoinHandle}; +fn get_forwarding_addr(leader_data: Option<&ContactInfo>, my_id: &Pubkey) -> Option { + let leader_data = leader_data?; + if leader_data.id == *my_id || !ContactInfo::is_valid_address(&leader_data.tpu) { + // weird cases, but we don't want to broadcast, send to ANY, or + // induce an infinite loop, but this shouldn't happen, or shouldn't be true for long... + return None; + } + Some(leader_data.tpu) +} + pub struct TpuForwarder { exit: Arc, thread_hdls: Vec>, @@ -28,10 +35,7 @@ impl TpuForwarder { fn forward(receiver: &PacketReceiver, cluster_info: &Arc>) -> Result<()> { let socket = UdpSocket::bind("0.0.0.0:0")?; - let my_id = cluster_info - .read() - .expect("cluster_info.read() in TpuForwarder::forward()") - .id(); + let my_id = cluster_info.read().unwrap().id(); loop { let msgs = receiver.recv()?; @@ -40,37 +44,13 @@ impl TpuForwarder { "tpu_forwarder-msgs_received", msgs.read().unwrap().packets.len() ); - let leader_data = cluster_info - .read() - .expect("cluster_info.read() in TpuForwarder::forward()") - .leader_data() - .cloned(); - match TpuForwarder::update_addrs(leader_data, &my_id, &msgs.clone()) { - Ok(_) => msgs.read().unwrap().send_to(&socket)?, - Err(_) => continue, - } - } - } - fn update_addrs( - leader_data: Option, - my_id: &Pubkey, - msgs: &Arc>, - ) -> result::Result<(), Box> { - match leader_data { - Some(leader_data) => { - if leader_data.id == *my_id || !ContactInfo::is_valid_address(&leader_data.tpu) { - // weird cases, but we don't want to broadcast, send to ANY, or - // induce an infinite loop, but this shouldn't happen, or shouldn't be true for long... - return Err("Invalid leader addr")?; - } + let send_addr = get_forwarding_addr(cluster_info.read().unwrap().leader_data(), &my_id); - for m in msgs.write().unwrap().packets.iter_mut() { - m.meta.set_addr(&leader_data.tpu); - } - Ok(()) + if let Some(send_addr) = send_addr { + msgs.write().unwrap().set_addr(&send_addr); + msgs.read().unwrap().send_to(&socket)?; } - _ => Err("No leader contact data")?, } } @@ -123,51 +103,26 @@ impl Service for TpuForwarder { mod tests { use super::*; use crate::contact_info::ContactInfo; - use crate::packet::Packet; - use solana_netutil::bind_in_range; use solana_sdk::signature::{Keypair, KeypairUtil}; - use std::net::{Ipv4Addr, SocketAddr}; #[test] - pub fn test_update_addrs() { - let keypair = Keypair::new(); - let my_id = keypair.pubkey(); - // test with no pubkey - assert!(!TpuForwarder::update_addrs( - None, - &my_id, - &Arc::new(RwLock::new(Packets::default())) - ) - .is_ok()); - // test with no tpu - assert!(!TpuForwarder::update_addrs( - Some(ContactInfo::default()), - &my_id, - &Arc::new(RwLock::new(Packets::default())) - ) - .is_ok()); - // test with my pubkey + fn test_get_forwarding_addr() { + let my_id = Keypair::new().pubkey(); + + // Test with no leader + assert_eq!(get_forwarding_addr(None, &my_id,), None); + + // Test with no TPU + let leader_data = ContactInfo::default(); + assert_eq!(get_forwarding_addr(Some(&leader_data), &my_id,), None); + + // Test with my pubkey let leader_data = ContactInfo::new_localhost(my_id, 0); - assert!(!TpuForwarder::update_addrs( - Some(leader_data), - &my_id, - &Arc::new(RwLock::new(Packets::default())) - ) - .is_ok()); - // test that the address is actually being updated - let (port, _) = bind_in_range((8000, 10000)).unwrap(); - let leader_data = ContactInfo::new_with_socketaddr(&socketaddr!([127, 0, 0, 1], port)); - let packet = Packet::default(); - let p = Packets { - packets: vec![packet], - }; - let msgs = Arc::new(RwLock::new(p)); - assert!( - TpuForwarder::update_addrs(Some(leader_data.clone()), &my_id, &msgs.clone()).is_ok() - ); - assert_eq!( - SocketAddr::from(msgs.read().unwrap().packets[0].meta.addr()), - leader_data.tpu - ); + assert_eq!(get_forwarding_addr(Some(&leader_data), &my_id,), None); + + // Test with pubkey other than mine + let alice_id = Keypair::new().pubkey(); + let leader_data = ContactInfo::new_localhost(alice_id, 0); + assert!(get_forwarding_addr(Some(&leader_data), &my_id,).is_some()); } }