From 801337a422101c6dde4cdcfd267c3b5f5e6d04fc Mon Sep 17 00:00:00 2001 From: Sagar Dhawan Date: Tue, 29 Oct 2019 21:02:11 -0700 Subject: [PATCH] Refactor Weighted Shuffle (#6614) automerge --- core/src/cluster_info.rs | 12 +++++------- core/src/crds_gossip_push.rs | 7 +++---- core/src/retransmit_stage.rs | 4 +--- core/src/weighted_shuffle.rs | 21 +++++++++++---------- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/core/src/cluster_info.rs b/core/src/cluster_info.rs index bbeffaeeb9..cebff05adb 100644 --- a/core/src/cluster_info.rs +++ b/core/src/cluster_info.rs @@ -26,9 +26,7 @@ use crate::weighted_shuffle::{weighted_best, weighted_shuffle}; use bincode::{deserialize, serialize, serialized_size}; use core::cmp; use itertools::Itertools; -use rand::SeedableRng; use rand::{thread_rng, Rng}; -use rand_chacha::ChaChaRng; use solana_ledger::bank_forks::BankForks; use solana_ledger::blocktree::Blocktree; use solana_ledger::staking_utils; @@ -510,11 +508,11 @@ impl ClusterInfo { fn stake_weighted_shuffle( stakes_and_index: &[(u64, usize)], - rng: ChaChaRng, + seed: [u8; 32], ) -> Vec<(u64, usize)> { let stake_weights = stakes_and_index.iter().map(|(w, _)| *w).collect(); - let shuffle = weighted_shuffle(stake_weights, rng); + let shuffle = weighted_shuffle(stake_weights, seed); shuffle.iter().map(|x| stakes_and_index[*x]).collect() } @@ -536,9 +534,9 @@ impl ClusterInfo { id: &Pubkey, peers: &[ContactInfo], stakes_and_index: &[(u64, usize)], - rng: ChaChaRng, + seed: [u8; 32], ) -> (usize, Vec<(u64, usize)>) { - let shuffled_stakes_and_index = ClusterInfo::stake_weighted_shuffle(stakes_and_index, rng); + let shuffled_stakes_and_index = ClusterInfo::stake_weighted_shuffle(stakes_and_index, seed); let mut self_index = 0; shuffled_stakes_and_index .iter() @@ -723,7 +721,7 @@ impl ClusterInfo { .into_iter() .zip(seeds) .map(|(shred, seed)| { - let broadcast_index = weighted_best(&peers_and_stakes, ChaChaRng::from_seed(*seed)); + let broadcast_index = weighted_best(&peers_and_stakes, *seed); (shred, &peers[broadcast_index].tvu) }) diff --git a/core/src/crds_gossip_push.rs b/core/src/crds_gossip_push.rs index 8f583212ea..5afbe9b7ff 100644 --- a/core/src/crds_gossip_push.rs +++ b/core/src/crds_gossip_push.rs @@ -20,8 +20,7 @@ use indexmap::map::IndexMap; use itertools::Itertools; use rand; use rand::seq::SliceRandom; -use rand::{thread_rng, RngCore, SeedableRng}; -use rand_chacha::ChaChaRng; +use rand::{thread_rng, RngCore}; use solana_runtime::bloom::Bloom; use solana_sdk::hash::Hash; use solana_sdk::pubkey::Pubkey; @@ -106,7 +105,7 @@ impl CrdsGossipPush { seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes()); let shuffle = weighted_shuffle( staked_peers.iter().map(|(_, stake)| *stake).collect_vec(), - ChaChaRng::from_seed(seed), + seed, ); let mut keep = HashSet::new(); @@ -244,7 +243,7 @@ impl CrdsGossipPush { seed[0..8].copy_from_slice(&thread_rng().next_u64().to_le_bytes()); let mut shuffle = weighted_shuffle( options.iter().map(|weighted| weighted.0).collect_vec(), - ChaChaRng::from_seed(seed), + seed, ) .into_iter(); diff --git a/core/src/retransmit_stage.rs b/core/src/retransmit_stage.rs index ad350e397b..6a395d2d1c 100644 --- a/core/src/retransmit_stage.rs +++ b/core/src/retransmit_stage.rs @@ -10,8 +10,6 @@ use crate::{ window_service::{should_retransmit_and_persist, WindowService}, }; use crossbeam_channel::Receiver as CrossbeamReceiver; -use rand::SeedableRng; -use rand_chacha::ChaChaRng; use solana_ledger::{ bank_forks::BankForks, blocktree::{Blocktree, CompletedSlotsReceiver}, @@ -92,7 +90,7 @@ fn retransmit( &me.id, &peers, &stakes_and_index, - ChaChaRng::from_seed(packet.meta.seed), + packet.meta.seed, ); peers_len = cmp::max(peers_len, shuffled_stakes_and_index.len()); shuffled_stakes_and_index.remove(my_index); diff --git a/core/src/weighted_shuffle.rs b/core/src/weighted_shuffle.rs index 28f40660fe..8220f77256 100644 --- a/core/src/weighted_shuffle.rs +++ b/core/src/weighted_shuffle.rs @@ -2,18 +2,19 @@ use itertools::Itertools; use num_traits::{FromPrimitive, ToPrimitive}; -use rand::Rng; +use rand::{Rng, SeedableRng}; use rand_chacha::ChaChaRng; use std::iter; use std::ops::Div; /// Returns a list of indexes shuffled based on the input weights /// Note - The sum of all weights must not exceed `u64::MAX` -pub fn weighted_shuffle(weights: Vec, mut rng: ChaChaRng) -> Vec +pub fn weighted_shuffle(weights: Vec, seed: [u8; 32]) -> Vec where T: Copy + PartialOrd + iter::Sum + Div + FromPrimitive + ToPrimitive, { let total_weight: T = weights.clone().into_iter().sum(); + let mut rng = ChaChaRng::from_seed(seed); weights .into_iter() .enumerate() @@ -36,10 +37,11 @@ where /// Returns the highest index after computing a weighted shuffle. /// Saves doing any sorting for O(n) max calculation. -pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) -> usize { +pub fn weighted_best(weights_and_indexes: &[(u64, usize)], seed: [u8; 32]) -> usize { if weights_and_indexes.is_empty() { return 0; } + let mut rng = ChaChaRng::from_seed(seed); let total_weight: u64 = weights_and_indexes.iter().map(|x| x.0).sum(); let mut lowest_weight = std::u128::MAX; let mut best_index = 0; @@ -63,13 +65,12 @@ pub fn weighted_best(weights_and_indexes: &[(u64, usize)], mut rng: ChaChaRng) - #[cfg(test)] mod tests { use super::*; - use rand::SeedableRng; #[test] fn test_weighted_shuffle_iterator() { let mut test_set = [0; 6]; let mut count = 0; - let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32])); + let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]); shuffle.into_iter().for_each(|x| { assert_eq!(test_set[x], 0); test_set[x] = 1; @@ -84,7 +85,7 @@ mod tests { let mut test_weights = vec![0; 100]; (0..100).for_each(|i| test_weights[i] = (i + 1) as u64); let mut count = 0; - let shuffle = weighted_shuffle(test_weights, ChaChaRng::from_seed([0xa5; 32])); + let shuffle = weighted_shuffle(test_weights, [0xa5; 32]); shuffle.into_iter().for_each(|x| { assert_eq!(test_set[x], 0); test_set[x] = 1; @@ -95,9 +96,9 @@ mod tests { #[test] fn test_weighted_shuffle_compare() { - let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32])); + let shuffle = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]); - let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], ChaChaRng::from_seed([0x5a; 32])); + let shuffle1 = weighted_shuffle(vec![50, 10, 2, 1, 1, 1], [0x5a; 32]); shuffle1 .into_iter() .zip(shuffle.into_iter()) @@ -110,7 +111,7 @@ mod tests { fn test_weighted_shuffle_imbalanced() { let mut weights = vec![std::u32::MAX as u64; 3]; weights.push(1); - let shuffle = weighted_shuffle(weights.clone(), ChaChaRng::from_seed([0x5a; 32])); + let shuffle = weighted_shuffle(weights.clone(), [0x5a; 32]); shuffle.into_iter().for_each(|x| { if x == weights.len() - 1 { assert_eq!(weights[x], 1); @@ -127,7 +128,7 @@ mod tests { .enumerate() .map(|(i, weight)| (weight, i)) .collect(); - let best_index = weighted_best(&weights_and_indexes, ChaChaRng::from_seed([0x5b; 32])); + let best_index = weighted_best(&weights_and_indexes, [0x5b; 32]); assert_eq!(best_index, 2); } }