diff --git a/core/src/serve_repair.rs b/core/src/serve_repair.rs index 0d36bea19..6a02fe189 100644 --- a/core/src/serve_repair.rs +++ b/core/src/serve_repair.rs @@ -17,7 +17,7 @@ use { solana_gossip::{ cluster_info::{ClusterInfo, ClusterInfoError}, contact_info::ContactInfo, - weighted_shuffle::{weighted_best, weighted_shuffle}, + weighted_shuffle::WeightedShuffle, }, solana_ledger::{ ancestor_iterator::{AncestorIterator, AncestorIteratorWithHash}, @@ -525,16 +525,17 @@ impl ServeRepair { if repair_peers.is_empty() { return Err(ClusterInfoError::NoPeers.into()); } - let weights = cluster_slots.compute_weights_exclude_nonfrozen(slot, &repair_peers); - let mut sampled_validators = weighted_shuffle( - weights.into_iter().map(|(stake, _i)| stake), - solana_sdk::pubkey::new_rand().to_bytes(), - ); - sampled_validators.truncate(ANCESTOR_HASH_REPAIR_SAMPLE_SIZE); - Ok(sampled_validators + let (weights, index): (Vec<_>, Vec<_>) = cluster_slots + .compute_weights_exclude_nonfrozen(slot, &repair_peers) .into_iter() + .unzip(); + let peers = WeightedShuffle::new("repair_request_ancestor_hashes", &weights) + .shuffle(&mut rand::thread_rng()) + .take(ANCESTOR_HASH_REPAIR_SAMPLE_SIZE) + .map(|i| index[i]) .map(|i| (repair_peers[i].id, repair_peers[i].serve_repair)) - .collect()) + .collect(); + Ok(peers) } pub fn repair_request_duplicate_compute_best_peer( @@ -547,8 +548,12 @@ impl ServeRepair { if repair_peers.is_empty() { return Err(ClusterInfoError::NoPeers.into()); } - let weights = cluster_slots.compute_weights_exclude_nonfrozen(slot, &repair_peers); - let n = weighted_best(&weights, solana_sdk::pubkey::new_rand().to_bytes()); + let (weights, index): (Vec<_>, Vec<_>) = cluster_slots + .compute_weights_exclude_nonfrozen(slot, &repair_peers) + .into_iter() + .unzip(); + let k = WeightedIndex::new(weights)?.sample(&mut rand::thread_rng()); + let n = index[k]; Ok((repair_peers[n].id, repair_peers[n].serve_repair)) } diff --git a/gossip/benches/weighted_shuffle.rs b/gossip/benches/weighted_shuffle.rs index e21f3eb20..309062f89 100644 --- a/gossip/benches/weighted_shuffle.rs +++ b/gossip/benches/weighted_shuffle.rs @@ -5,7 +5,7 @@ extern crate test; use { rand::{Rng, SeedableRng}, rand_chacha::ChaChaRng, - solana_gossip::weighted_shuffle::{weighted_shuffle, WeightedShuffle}, + solana_gossip::weighted_shuffle::WeightedShuffle, std::iter::repeat_with, test::Bencher, }; @@ -15,18 +15,7 @@ fn make_weights(rng: &mut R) -> Vec { } #[bench] -fn bench_weighted_shuffle_old(bencher: &mut Bencher) { - let mut seed = [0u8; 32]; - let mut rng = rand::thread_rng(); - let weights = make_weights(&mut rng); - bencher.iter(|| { - rng.fill(&mut seed[..]); - weighted_shuffle::>(weights.iter(), seed); - }); -} - -#[bench] -fn bench_weighted_shuffle_new(bencher: &mut Bencher) { +fn bench_weighted_shuffle(bencher: &mut Bencher) { let mut seed = [0u8; 32]; let mut rng = rand::thread_rng(); let weights = make_weights(&mut rng); diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 9975af352..8b2cf2311 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -1,18 +1,12 @@ //! The `weighted_shuffle` module provides an iterator over shuffled weights. use { - itertools::Itertools, - num_traits::{CheckedAdd, FromPrimitive, ToPrimitive}, + num_traits::CheckedAdd, rand::{ distributions::uniform::{SampleUniform, UniformSampler}, - Rng, SeedableRng, - }, - rand_chacha::ChaChaRng, - std::{ - borrow::Borrow, - iter, - ops::{AddAssign, Div, Sub, SubAssign}, + Rng, }, + std::ops::{AddAssign, Sub, SubAssign}, }; /// Implements an iterator where indices are shuffled according to their @@ -182,68 +176,12 @@ where } } -/// Returns a list of indexes shuffled based on the input weights -/// Note - The sum of all weights must not exceed `u64::MAX` -pub fn weighted_shuffle(weights: F, seed: [u8; 32]) -> Vec -where - T: Copy + PartialOrd + iter::Sum + Div + FromPrimitive + ToPrimitive, - B: Borrow, - F: Iterator + Clone, -{ - let total_weight: T = weights.clone().map(|x| *x.borrow()).sum(); - let mut rng = ChaChaRng::from_seed(seed); - weights - .enumerate() - .map(|(i, weight)| { - let weight = weight.borrow(); - // This generates an "inverse" weight but it avoids floating point math - let x = (total_weight / *weight) - .to_u64() - .expect("values > u64::max are not supported"); - ( - i, - // capture the u64 into u128s to prevent overflow - rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x), - ) - }) - // sort in ascending order - .sorted_by(|(_, l_val), (_, r_val)| l_val.cmp(r_val)) - .map(|x| x.0) - .collect() -} - -/// Returns the highest index after computing a weighted shuffle. -/// Saves doing any sorting for O(n) max calculation. -// TODO: Remove in favor of rand::distributions::WeightedIndex. -pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize { - if weights_and_indexes.is_empty() { - return 0; - } - let mut rng = ChaChaRng::from_seed(seed); - let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum(); - let mut lowest_weight = std::u128::MAX; - let mut best_index = 0; - for v in weights_and_indexes { - // This generates an "inverse" weight but it avoids floating point math - let x = (total_weight / v.0) - .to_u64() - .expect("values > u64::max are not supported"); - // capture the u64 into u128s to prevent overflow - let computed_weight = rng.gen_range(1, u128::from(std::u16::MAX)) * u128::from(x); - // The highest input weight maps to the lowest computed weight - if computed_weight < lowest_weight { - lowest_weight = computed_weight; - best_index = v.1; - } - } - - best_index -} - #[cfg(test)] mod tests { use { super::*, + rand::SeedableRng, + rand_chacha::ChaChaRng, std::{convert::TryInto, iter::repeat_with}, }; @@ -280,72 +218,6 @@ mod tests { shuffle } - #[test] - fn test_weighted_shuffle_iterator() { - let mut test_set = [0; 6]; - let mut count = 0; - let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]); - shuffle.into_iter().for_each(|x| { - assert_eq!(test_set[x], 0); - test_set[x] = 1; - count += 1; - }); - assert_eq!(count, 6); - } - - #[test] - fn test_weighted_shuffle_iterator_large() { - let mut test_set = [0; 100]; - let mut test_weights = vec![0; 100]; - (0..100).for_each(|i| test_weights[i] = (i + 1) as u64); - let mut count = 0; - let shuffle = weighted_shuffle(test_weights.into_iter(), [0xa5; 32]); - shuffle.into_iter().for_each(|x| { - assert_eq!(test_set[x], 0); - test_set[x] = 1; - count += 1; - }); - assert_eq!(count, 100); - } - - #[test] - fn test_weighted_shuffle_compare() { - let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]); - - let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1].into_iter(), [0x5a; 32]); - shuffle1 - .into_iter() - .zip(shuffle.into_iter()) - .for_each(|(x, y)| { - assert_eq!(x, y); - }); - } - - #[test] - fn test_weighted_shuffle_imbalanced() { - let mut weights = vec![std::u32::MAX as u64; 3]; - weights.push(1); - let shuffle = weighted_shuffle(weights.iter().cloned(), [0x5a; 32]); - shuffle.into_iter().for_each(|x| { - if x == weights.len() - 1 { - assert_eq!(weights[x], 1); - } else { - assert_eq!(weights[x], std::u32::MAX as u64); - } - }); - } - - #[test] - fn test_weighted_best() { - let weights_and_indexes: Vec<_> = vec![100u64, 1000, 10_000, 10] - .into_iter() - .enumerate() - .map(|(i, weight)| (weight, i)) - .collect(); - let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]); - assert_eq!(best_index, 2); - } - // Asserts that empty weights will return empty shuffle. #[test] fn test_weighted_shuffle_empty_weights() { diff --git a/gossip/tests/cluster_info.rs b/gossip/tests/cluster_info.rs index 3d99efad7..1b45ec2aa 100644 --- a/gossip/tests/cluster_info.rs +++ b/gossip/tests/cluster_info.rs @@ -2,12 +2,14 @@ use { crossbeam_channel::{unbounded, Receiver, Sender, TryRecvError}, itertools::Itertools, + rand::SeedableRng, + rand_chacha::ChaChaRng, rayon::{iter::ParallelIterator, prelude::*}, serial_test::serial, solana_gossip::{ cluster_info::{compute_retransmit_peers, ClusterInfo}, contact_info::ContactInfo, - weighted_shuffle::weighted_shuffle, + weighted_shuffle::WeightedShuffle, }, solana_sdk::{pubkey::Pubkey, signer::keypair::Keypair}, solana_streamer::socket::SocketAddrSpace, @@ -95,11 +97,13 @@ fn shuffle_peers_and_index( } fn stake_weighted_shuffle(stakes_and_index: &[(u64, usize)], seed: [u8; 32]) -> Vec<(u64, usize)> { - let stake_weights = stakes_and_index.iter().map(|(w, _)| *w); - - let shuffle = weighted_shuffle(stake_weights, seed); - - shuffle.iter().map(|x| stakes_and_index[*x]).collect() + let mut rng = ChaChaRng::from_seed(seed); + let stake_weights: Vec<_> = stakes_and_index.iter().map(|(w, _)| *w).collect(); + let shuffle = WeightedShuffle::new("stake_weighted_shuffle", &stake_weights); + shuffle + .shuffle(&mut rng) + .map(|i| stakes_and_index[i]) + .collect() } fn retransmit(