diff --git a/gossip/src/cluster_info.rs b/gossip/src/cluster_info.rs index eb13d93691..4c7ececef9 100644 --- a/gossip/src/cluster_info.rs +++ b/gossip/src/cluster_info.rs @@ -1324,7 +1324,7 @@ impl ClusterInfo { fn append_entrypoint_to_pulls( &self, thread_pool: &ThreadPool, - pulls: &mut Vec<(ContactInfo, Vec)>, + pulls: &mut HashMap>, ) { const THROTTLE_DELAY: u64 = CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS / 2; let entrypoint = { @@ -1349,17 +1349,16 @@ impl ClusterInfo { } entrypoint.clone() }; - let filters = match pulls.first() { - Some((_, filters)) => filters.clone(), - None => { - let _st = ScopedTimer::from(&self.stats.entrypoint2); - self.gossip - .pull - .build_crds_filters(thread_pool, &self.gossip.crds, MAX_BLOOM_SIZE) - } + let filters = if pulls.is_empty() { + let _st = ScopedTimer::from(&self.stats.entrypoint2); + self.gossip + .pull + .build_crds_filters(thread_pool, &self.gossip.crds, MAX_BLOOM_SIZE) + } else { + pulls.values().flatten().cloned().collect() }; self.stats.pull_from_entrypoint_count.add_relaxed(1); - pulls.push((entrypoint, filters)); + pulls.insert(entrypoint, filters); } /// Splits an input feed of serializable data into chunks where the sum of @@ -1424,30 +1423,29 @@ impl ClusterInfo { ) { let now = timestamp(); let mut pings = Vec::new(); - let mut pulls: Vec<_> = { + let mut pulls = { let _st = ScopedTimer::from(&self.stats.new_pull_requests); - match self.gossip.new_pull_request( - thread_pool, - self.keypair().deref(), - self.my_shred_version(), - now, - gossip_validators, - stakes, - MAX_BLOOM_SIZE, - &self.ping_cache, - &mut pings, - &self.socket_addr_space, - ) { - Err(_) => Vec::default(), - Ok((peer, filters)) => vec![(peer, filters)], - } + self.gossip + .new_pull_request( + thread_pool, + self.keypair().deref(), + self.my_shred_version(), + now, + gossip_validators, + stakes, + MAX_BLOOM_SIZE, + &self.ping_cache, + &mut pings, + &self.socket_addr_space, + ) + .unwrap_or_default() }; self.append_entrypoint_to_pulls(thread_pool, &mut pulls); - let num_requests = pulls.iter().map(|(_, filters)| filters.len() as u64).sum(); + let num_requests = pulls.values().map(Vec::len).sum::() as u64; self.stats.new_pull_requests_count.add_relaxed(num_requests); { let _st = ScopedTimer::from(&self.stats.mark_pull_request); - for (peer, _) in &pulls { + for peer in pulls.keys() { self.gossip.mark_pull_request_creation_time(peer.id, now); } } diff --git a/gossip/src/contact_info.rs b/gossip/src/contact_info.rs index a614fee7e5..1a5a33aa79 100644 --- a/gossip/src/contact_info.rs +++ b/gossip/src/contact_info.rs @@ -12,7 +12,9 @@ use { }; /// Structure representing a node on the network -#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd, AbiExample, Deserialize, Serialize)] +#[derive( + Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, AbiExample, Deserialize, Serialize, +)] pub struct ContactInfo { pub id: Pubkey, /// gossip address diff --git a/gossip/src/crds_gossip.rs b/gossip/src/crds_gossip.rs index 0820ab75a4..afb31c8de1 100644 --- a/gossip/src/crds_gossip.rs +++ b/gossip/src/crds_gossip.rs @@ -217,7 +217,7 @@ impl CrdsGossip { ping_cache: &Mutex, pings: &mut Vec<(SocketAddr, Ping)>, socket_addr_space: &SocketAddrSpace, - ) -> Result<(ContactInfo, Vec), CrdsGossipError> { + ) -> Result>, CrdsGossipError> { self.pull.new_pull_request( thread_pool, &self.crds, diff --git a/gossip/src/crds_gossip_pull.rs b/gossip/src/crds_gossip_pull.rs index aecefead4c..2780bf7dab 100644 --- a/gossip/src/crds_gossip_pull.rs +++ b/gossip/src/crds_gossip_pull.rs @@ -21,10 +21,13 @@ use { crds_gossip_error::CrdsGossipError, crds_value::CrdsValue, ping_pong::PingCache, - weighted_shuffle::WeightedShuffle, }, + itertools::Itertools, lru::LruCache, - rand::Rng, + rand::{ + distributions::{Distribution, WeightedIndex}, + Rng, + }, rayon::{prelude::*, ThreadPool}, solana_bloom::bloom::{AtomicBloom, Bloom}, solana_sdk::{ @@ -228,52 +231,43 @@ impl CrdsGossipPull { ping_cache: &Mutex, pings: &mut Vec<(SocketAddr, Ping)>, socket_addr_space: &SocketAddrSpace, - ) -> Result<(ContactInfo, Vec), CrdsGossipError> { + ) -> Result>, CrdsGossipError> { + // Gossip peers and respective sampling weights. + let peers = self.pull_options( + crds, + &self_keypair.pubkey(), + self_shred_version, + now, + gossip_validators, + stakes, + socket_addr_space, + ); + // Check for nodes which have responded to ping messages. + let mut rng = rand::thread_rng(); let (weights, peers): (Vec<_>, Vec<_>) = { - self.pull_options( - crds, - &self_keypair.pubkey(), - self_shred_version, - now, - gossip_validators, - stakes, - socket_addr_space, - ) - .into_iter() - .map(|(weight, node, gossip_addr)| (weight, (node, gossip_addr))) - .unzip() + let mut ping_cache = ping_cache.lock().unwrap(); + let mut pingf = move || Ping::new_rand(&mut rng, self_keypair).ok(); + let now = Instant::now(); + peers + .into_iter() + .filter_map(|(weight, peer)| { + let node = (peer.id, peer.gossip); + let (check, ping) = ping_cache.check(now, node, &mut pingf); + if let Some(ping) = ping { + pings.push((peer.gossip, ping)); + } + check.then(|| (weight, peer)) + }) + .unzip() }; if peers.is_empty() { return Err(CrdsGossipError::NoPeers); } - let mut rng = rand::thread_rng(); - let mut peers = WeightedShuffle::new("pull-options", &weights) - .shuffle(&mut rng) - .map(|i| peers[i]); - let peer = { - let mut rng = rand::thread_rng(); - let mut ping_cache = ping_cache.lock().unwrap(); - let mut pingf = move || Ping::new_rand(&mut rng, self_keypair).ok(); - let now = Instant::now(); - peers.find(|node| { - let (_, gossip_addr) = *node; - let (check, ping) = ping_cache.check(now, *node, &mut pingf); - if let Some(ping) = ping { - pings.push((gossip_addr, ping)); - } - check - }) - }; - let peer = match peer { - None => return Err(CrdsGossipError::NoPeers), - Some((node, _gossip_addr)) => node, - }; + // Associate each pull-request filter with a randomly selected peer. let filters = self.build_crds_filters(thread_pool, crds, bloom_size); - let peer = match crds.read().unwrap().get::<&ContactInfo>(peer) { - None => return Err(CrdsGossipError::NoPeers), - Some(node) => node.clone(), - }; - Ok((peer, filters)) + let dist = WeightedIndex::new(&weights).unwrap(); + let peers = repeat_with(|| peers[dist.sample(&mut rng)].clone()); + Ok(peers.zip(filters).into_group_map()) } fn pull_options( @@ -285,11 +279,7 @@ impl CrdsGossipPull { gossip_validators: Option<&HashSet>, stakes: &HashMap, socket_addr_space: &SocketAddrSpace, - ) -> Vec<( - u64, // weight - Pubkey, // node - SocketAddr, // gossip address - )> { + ) -> Vec<(/*weight:*/ u64, ContactInfo)> { let mut rng = rand::thread_rng(); let active_cutoff = now.saturating_sub(PULL_ACTIVE_TIMEOUT_MS); let pull_request_time = self.pull_request_time.read().unwrap(); @@ -327,7 +317,7 @@ impl CrdsGossipPull { let weight = get_weight(max_weight, since, stake); // Weights are bounded by max_weight defined above. // So this type-cast should be safe. - ((weight * 100.0) as u64, item.id, item.gossip) + ((weight * 100.0) as u64, item.clone()) }) .collect() } @@ -757,10 +747,9 @@ pub(crate) mod tests { &SocketAddrSpace::Unspecified, ); assert!(!options.is_empty()); - options - .sort_by(|(weight_l, _, _), (weight_r, _, _)| weight_r.partial_cmp(weight_l).unwrap()); + options.sort_by(|(weight_l, _), (weight_r, _)| weight_r.partial_cmp(weight_l).unwrap()); // check that the highest stake holder is also the heaviest weighted. - assert_eq!(stakes[&options[0].1], 3000_u64); + assert_eq!(stakes[&options[0].1.id], 3000_u64); } #[test] @@ -818,7 +807,7 @@ pub(crate) mod tests { &SocketAddrSpace::Unspecified, ) .iter() - .map(|(_, pk, _)| *pk) + .map(|(_, peer)| peer.id) .collect::>(); assert_eq!(options.len(), 1); assert!(!options.contains(&spy.pubkey())); @@ -836,7 +825,7 @@ pub(crate) mod tests { &SocketAddrSpace::Unspecified, ) .iter() - .map(|(_, pk, _)| *pk) + .map(|(_, peer)| peer.id) .collect::>(); assert_eq!(options.len(), 3); assert!(options.contains(&me.pubkey())); @@ -906,7 +895,7 @@ pub(crate) mod tests { &SocketAddrSpace::Unspecified, ); assert_eq!(options.len(), 1); - assert_eq!(options[0].1, node_123.pubkey()); + assert_eq!(options[0].1.id, node_123.pubkey()); } #[test] @@ -1085,8 +1074,8 @@ pub(crate) mod tests { &mut pings, &SocketAddrSpace::Unspecified, ); - let (peer, _) = req.unwrap(); - assert_eq!(peer, *new.contact_info().unwrap()); + let peers: Vec<_> = req.unwrap().into_keys().collect(); + assert_eq!(peers, vec![new.contact_info().unwrap().clone()]); node.mark_pull_request_creation_time(new.contact_info().unwrap().id, now); let offline = ContactInfo::new_localhost(&solana_sdk::pubkey::new_rand(), now); @@ -1110,8 +1099,8 @@ pub(crate) mod tests { ); // Even though the offline node should have higher weight, we shouldn't request from it // until we receive a ping. - let (peer, _) = req.unwrap(); - assert_eq!(peer, *new.contact_info().unwrap()); + let peers: Vec<_> = req.unwrap().into_keys().collect(); + assert_eq!(peers, vec![new.contact_info().unwrap().clone()]); } #[test] @@ -1152,7 +1141,7 @@ pub(crate) mod tests { let ping_cache = Mutex::new(ping_cache); let old = old.contact_info().unwrap(); let count = repeat_with(|| { - let (peer, _filters) = node + let requests = node .new_pull_request( &thread_pool, &crds, @@ -1167,8 +1156,9 @@ pub(crate) mod tests { &SocketAddrSpace::Unspecified, ) .unwrap(); - peer + requests.into_keys() }) + .flatten() .take(100) .filter(|peer| peer != old) .count(); @@ -1250,7 +1240,7 @@ pub(crate) mod tests { ); let dest_crds = RwLock::::default(); - let (_, filters) = req.unwrap(); + let filters = req.unwrap().into_values().flatten(); let mut filters: Vec<_> = filters.into_iter().map(|f| (caller.clone(), f)).collect(); let rsp = CrdsGossipPull::generate_pull_responses( &thread_pool, @@ -1353,7 +1343,7 @@ pub(crate) mod tests { ); let dest_crds = RwLock::::default(); - let (_, filters) = req.unwrap(); + let filters = req.unwrap().into_values().flatten(); let filters: Vec<_> = filters.into_iter().map(|f| (caller.clone(), f)).collect(); let rsp = CrdsGossipPull::generate_pull_responses( &thread_pool, @@ -1439,7 +1429,7 @@ pub(crate) mod tests { &mut pings, &SocketAddrSpace::Unspecified, ); - let (_, filters) = req.unwrap(); + let filters = req.unwrap().into_values().flatten(); let filters: Vec<_> = filters.into_iter().map(|f| (caller.clone(), f)).collect(); let rsp = CrdsGossipPull::generate_pull_responses( &thread_pool, diff --git a/gossip/tests/crds_gossip.rs b/gossip/tests/crds_gossip.rs index 9fb2770b29..7095465c38 100644 --- a/gossip/tests/crds_gossip.rs +++ b/gossip/tests/crds_gossip.rs @@ -490,9 +490,9 @@ fn network_run_pull( let requests: Vec<_> = { network_values .par_iter() - .filter_map(|from| { + .flat_map_iter(|from| { let mut pings = Vec::new(); - let (peer, filters) = from + let requests = from .gossip .new_pull_request( thread_pool, @@ -506,12 +506,14 @@ fn network_run_pull( &mut pings, &SocketAddrSpace::Unspecified, ) - .ok()?; + .unwrap_or_default(); let from_pubkey = from.keypair.pubkey(); let label = CrdsValueLabel::ContactInfo(from_pubkey); let gossip_crds = from.gossip.crds.read().unwrap(); let self_info = gossip_crds.get::<&CrdsValue>(&label).unwrap().clone(); - Some((peer.id, filters, self_info)) + requests + .into_iter() + .map(move |(peer, filters)| (peer.id, filters, self_info.clone())) }) .collect() };