From c3048b451dc1d8da683ee97c828092092a8ce1f9 Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Thu, 3 Dec 2020 14:26:07 +0000 Subject: [PATCH] samples repair peers using WeightedIndex (#13919) To output one random sample, weighted_best generates n random numbers: https://github.com/solana-labs/solana/blob/f751a5d4e/core/src/weighted_shuffle.rs#L38-L63 WeightedIndex does so with only one random number: https://github.com/rust-random/rand/blob/eb02f0e46/src/distributions/weighted_index.rs#L223-L240 Additionally, if the index is already constructed, it only does a total of O(log(n)) amount of work; which can be achieved if RepairCache, caches the weighted index: https://github.com/solana-labs/solana/blob/f751a5d4e/core/src/serve_repair.rs#L83 Also, the repair-peers code can be reorganized to have fewer redundant unlock-then-lock code. --- core/src/cluster_info.rs | 48 +++++++++++++++++++----------------- core/src/cluster_slots.rs | 48 +++++++++++++++++++----------------- core/src/crds.rs | 11 ++++++--- core/src/result.rs | 6 +++++ core/src/serve_repair.rs | 27 +++++++++++--------- core/src/weighted_shuffle.rs | 1 + 6 files changed, 82 insertions(+), 59 deletions(-) diff --git a/core/src/cluster_info.rs b/core/src/cluster_info.rs index 7c4712a09a..3defb1c70e 100644 --- a/core/src/cluster_info.rs +++ b/core/src/cluster_info.rs @@ -1170,13 +1170,15 @@ impl ClusterInfo { /// all validators that have a valid tvu port and are on the same `shred_version`. pub fn tvu_peers(&self) -> Vec { + let self_pubkey = self.id(); + let self_shred_version = self.my_shred_version(); self.time_gossip_read_lock("tvu_peers", &self.stats.tvu_peers) .crds .get_nodes_contact_info() - .filter(|x| { - ContactInfo::is_valid_address(&x.tvu) - && x.id != self.id() - && x.shred_version == self.my_shred_version() + .filter(|node| { + node.id != self_pubkey + && node.shred_version == self_shred_version + && ContactInfo::is_valid_address(&node.tvu) }) .cloned() .collect() @@ -1200,22 +1202,24 @@ impl ClusterInfo { /// all tvu peers with valid gossip addrs that likely have the slot being requested pub fn repair_peers(&self, slot: Slot) -> Vec { let mut time = Measure::start("repair_peers"); - let ret = ClusterInfo::tvu_peers(self) - .into_iter() - .filter(|x| { - x.id != self.id() - && x.shred_version == self.my_shred_version() - && ContactInfo::is_valid_address(&x.serve_repair) - && { - self.get_lowest_slot_for_node(&x.id, None, |lowest_slot, _| { - lowest_slot.lowest <= slot - }) - .unwrap_or_else(|| /* fallback to legacy behavior */ true) - } - }) - .collect(); + // self.tvu_peers() already filters on: + // node.id != self.id() && + // node.shred_verion == self.my_shred_version() + let nodes = { + let gossip = self.gossip.read().unwrap(); + self.tvu_peers() + .into_iter() + .filter(|node| { + ContactInfo::is_valid_address(&node.serve_repair) + && match gossip.crds.get_lowest_slot(node.id) { + None => true, // fallback to legacy behavior + Some(lowest_slot) => lowest_slot.lowest <= slot, + } + }) + .collect() + }; self.stats.repair_peers.add_measure(&mut time); - ret + nodes } fn is_spy_node(contact_info: &ContactInfo) -> bool { @@ -1654,7 +1658,7 @@ impl ClusterInfo { push_messages .into_iter() .filter_map(|(pubkey, messages)| { - let peer = gossip.crds.get_contact_info(&pubkey)?; + let peer = gossip.crds.get_contact_info(pubkey)?; Some((peer.gossip, messages)) }) .collect() @@ -2351,7 +2355,7 @@ impl ClusterInfo { let gossip = self.gossip.read().unwrap(); messages .iter() - .map(|(from, _)| match gossip.crds.get_contact_info(from) { + .map(|(from, _)| match gossip.crds.get_contact_info(*from) { None => 0, Some(info) => info.shred_version, }) @@ -2424,7 +2428,7 @@ impl ClusterInfo { .into_par_iter() .with_min_len(256) .filter_map(|(from, prunes)| { - let peer = gossip.crds.get_contact_info(&from)?; + let peer = gossip.crds.get_contact_info(from)?; let mut prune_data = PruneData { pubkey: self_pubkey, prunes, diff --git a/core/src/cluster_slots.rs b/core/src/cluster_slots.rs index aefdb0139c..2235f09573 100644 --- a/core/src/cluster_slots.rs +++ b/core/src/cluster_slots.rs @@ -106,28 +106,30 @@ impl ClusterSlots { } } - pub fn compute_weights(&self, slot: Slot, repair_peers: &[ContactInfo]) -> Vec<(u64, usize)> { - let slot_peers = self.lookup(slot); + pub fn compute_weights(&self, slot: Slot, repair_peers: &[ContactInfo]) -> Vec { + let stakes = { + let validator_stakes = self.validator_stakes.read().unwrap(); + repair_peers + .iter() + .map(|peer| { + validator_stakes + .get(&peer.id) + .map(|node| node.total_stake) + .unwrap_or(0) + + 1 + }) + .collect() + }; + let slot_peers = match self.lookup(slot) { + None => return stakes, + Some(slot_peers) => slot_peers, + }; + let slot_peers = slot_peers.read().unwrap(); repair_peers .iter() - .enumerate() - .map(|(i, x)| { - let peer_stake = slot_peers - .as_ref() - .and_then(|v| v.read().unwrap().get(&x.id).cloned()) - .unwrap_or(0); - ( - 1 + peer_stake - + self - .validator_stakes - .read() - .unwrap() - .get(&x.id) - .map(|v| v.total_stake) - .unwrap_or(0), - i, - ) - }) + .map(|peer| slot_peers.get(&peer.id).cloned().unwrap_or(0)) + .zip(stakes) + .map(|(a, b)| a + b) .collect() } @@ -228,7 +230,7 @@ mod tests { fn test_compute_weights() { let cs = ClusterSlots::default(); let ci = ContactInfo::default(); - assert_eq!(cs.compute_weights(0, &[ci]), vec![(1, 0)]); + assert_eq!(cs.compute_weights(0, &[ci]), vec![1]); } #[test] @@ -249,7 +251,7 @@ mod tests { c2.id = k2; assert_eq!( cs.compute_weights(0, &[c1, c2]), - vec![(std::u64::MAX / 2 + 1, 0), (1, 1)] + vec![std::u64::MAX / 2 + 1, 1] ); } @@ -281,7 +283,7 @@ mod tests { c2.id = k2; assert_eq!( cs.compute_weights(0, &[c1, c2]), - vec![(std::u64::MAX / 2 + 1, 0), (1, 1)] + vec![std::u64::MAX / 2 + 1, 1] ); } diff --git a/core/src/crds.rs b/core/src/crds.rs index 53d390d066..544f523250 100644 --- a/core/src/crds.rs +++ b/core/src/crds.rs @@ -26,7 +26,7 @@ use crate::contact_info::ContactInfo; use crate::crds_shards::CrdsShards; -use crate::crds_value::{CrdsData, CrdsValue, CrdsValueLabel}; +use crate::crds_value::{CrdsData, CrdsValue, CrdsValueLabel, LowestSlot}; use bincode::serialize; use indexmap::map::{rayon::ParValues, Entry, IndexMap, Iter, Values}; use indexmap::set::IndexSet; @@ -182,11 +182,16 @@ impl Crds { self.table.get(label) } - pub fn get_contact_info(&self, pubkey: &Pubkey) -> Option<&ContactInfo> { - let label = CrdsValueLabel::ContactInfo(*pubkey); + pub fn get_contact_info(&self, pubkey: Pubkey) -> Option<&ContactInfo> { + let label = CrdsValueLabel::ContactInfo(pubkey); self.table.get(&label)?.value.contact_info() } + pub fn get_lowest_slot(&self, pubkey: Pubkey) -> Option<&LowestSlot> { + let lable = CrdsValueLabel::LowestSlot(pubkey); + self.table.get(&lable)?.value.lowest_slot() + } + /// Returns all entries which are ContactInfo. pub fn get_nodes(&self) -> impl Iterator { self.nodes.iter().map(move |i| self.table.index(*i)) diff --git a/core/src/result.rs b/core/src/result.rs index 6c53662936..c633be554b 100644 --- a/core/src/result.rs +++ b/core/src/result.rs @@ -31,6 +31,7 @@ pub enum Error { BlockstoreError(blockstore::BlockstoreError), FsExtra(fs_extra::error::Error), SnapshotError(snapshot_utils::SnapshotError), + WeightedIndexError(rand::distributions::weighted::WeightedError), } pub type Result = std::result::Result; @@ -143,6 +144,11 @@ impl std::convert::From for Error { Error::SnapshotError(e) } } +impl std::convert::From for Error { + fn from(e: rand::distributions::weighted::WeightedError) -> Error { + Error::WeightedIndexError(e) + } +} #[cfg(test)] mod tests { diff --git a/core/src/serve_repair.rs b/core/src/serve_repair.rs index c4183616cc..6772981cd4 100644 --- a/core/src/serve_repair.rs +++ b/core/src/serve_repair.rs @@ -8,6 +8,7 @@ use crate::{ weighted_shuffle::weighted_best, }; use bincode::serialize; +use rand::distributions::{Distribution, WeightedIndex}; use solana_ledger::{blockstore::Blockstore, shred::Nonce}; use solana_measure::measure::Measure; use solana_measure::thread_mem_usage; @@ -21,7 +22,7 @@ use solana_sdk::{ }; use solana_streamer::streamer::{PacketReceiver, PacketSender}; use std::{ - collections::{HashMap, HashSet}, + collections::{hash_map::Entry, HashMap, HashSet}, net::SocketAddr, sync::atomic::{AtomicBool, Ordering}, sync::{Arc, RwLock}, @@ -80,7 +81,7 @@ pub struct ServeRepair { cluster_info: Arc, } -type RepairCache = HashMap, Vec<(u64, usize)>)>; +type RepairCache = HashMap, WeightedIndex)>; impl ServeRepair { /// Without a valid keypair gossip will not function. Only useful for tests. @@ -387,16 +388,20 @@ impl ServeRepair { // find a peer that appears to be accepting replication and has the desired slot, as indicated // by a valid tvu port location let slot = repair_request.slot(); - if cache.get(&slot).is_none() { - let repair_peers = self.repair_peers(&repair_validators, slot); - if repair_peers.is_empty() { - return Err(ClusterInfoError::NoPeers.into()); + let (repair_peers, weighted_index) = match cache.entry(slot) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + let repair_peers = self.repair_peers(&repair_validators, slot); + if repair_peers.is_empty() { + return Err(Error::from(ClusterInfoError::NoPeers)); + } + let weights = cluster_slots.compute_weights(slot, &repair_peers); + debug_assert_eq!(weights.len(), repair_peers.len()); + let weighted_index = WeightedIndex::new(weights)?; + entry.insert((repair_peers, weighted_index)) } - let weights = cluster_slots.compute_weights(slot, &repair_peers); - cache.insert(slot, (repair_peers, weights)); - } - let (repair_peers, weights) = cache.get(&slot).unwrap(); - let n = weighted_best(&weights, solana_sdk::pubkey::new_rand().to_bytes()); + }; + let n = weighted_index.sample(&mut rand::thread_rng()); let addr = repair_peers[n].serve_repair; // send the request to the peer's serve_repair port let repair_peer_id = repair_peers[n].id; let out = self.map_repair_request( diff --git a/core/src/weighted_shuffle.rs b/core/src/weighted_shuffle.rs index 8220f77256..29afc6c64e 100644 --- a/core/src/weighted_shuffle.rs +++ b/core/src/weighted_shuffle.rs @@ -37,6 +37,7 @@ where /// Returns the highest index after computing a weighted shuffle. /// Saves doing any sorting for O(n) max calculation. +// TODO: Remove in favor of rand::distributions::WeightedIndex. pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize { if weights_and_indexes.is_empty() { return 0;