diff --git a/core/src/cluster_nodes.rs b/core/src/cluster_nodes.rs index 1df141607a..231c55f464 100644 --- a/core/src/cluster_nodes.rs +++ b/core/src/cluster_nodes.rs @@ -326,7 +326,7 @@ pub fn new_cluster_nodes( .collect(); let broadcast = TypeId::of::() == TypeId::of::(); let stakes: Vec = nodes.iter().map(|node| node.stake).collect(); - let mut weighted_shuffle = WeightedShuffle::new(&stakes).unwrap(); + let mut weighted_shuffle = WeightedShuffle::new("cluster-nodes", &stakes); if broadcast { weighted_shuffle.remove_index(index[&self_pubkey]); } diff --git a/gossip/benches/weighted_shuffle.rs b/gossip/benches/weighted_shuffle.rs index 72f3b6dbcc..e21f3eb203 100644 --- a/gossip/benches/weighted_shuffle.rs +++ b/gossip/benches/weighted_shuffle.rs @@ -32,8 +32,7 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) { let weights = make_weights(&mut rng); bencher.iter(|| { rng.fill(&mut seed[..]); - let shuffle = WeightedShuffle::new(&weights).unwrap(); - shuffle + WeightedShuffle::new("", &weights) .shuffle(&mut ChaChaRng::from_seed(seed)) .collect::>() }); diff --git a/gossip/src/cluster_info.rs b/gossip/src/cluster_info.rs index 3b246b5b49..541d71f446 100644 --- a/gossip/src/cluster_info.rs +++ b/gossip/src/cluster_info.rs @@ -2019,7 +2019,7 @@ impl ClusterInfo { return packet_batch; } let mut rng = rand::thread_rng(); - let shuffle = WeightedShuffle::new(&scores).unwrap().shuffle(&mut rng); + let shuffle = WeightedShuffle::new("handle-pull-requests", &scores).shuffle(&mut rng); let mut total_bytes = 0; let mut sent = 0; for (addr, response) in shuffle.map(|i| &responses[i]) { diff --git a/gossip/src/crds_gossip_pull.rs b/gossip/src/crds_gossip_pull.rs index 30ebe48328..ca26ed1bf1 100644 --- a/gossip/src/crds_gossip_pull.rs +++ b/gossip/src/crds_gossip_pull.rs @@ -246,8 +246,7 @@ impl CrdsGossipPull { return Err(CrdsGossipError::NoPeers); } let mut rng = rand::thread_rng(); - let mut peers = WeightedShuffle::new(&weights) - .unwrap() + let mut peers = WeightedShuffle::new("pull-options", &weights) .shuffle(&mut rng) .map(|i| peers[i]); let peer = { diff --git a/gossip/src/crds_gossip_push.rs b/gossip/src/crds_gossip_push.rs index cde6ff442f..76356ab215 100644 --- a/gossip/src/crds_gossip_push.rs +++ b/gossip/src/crds_gossip_push.rs @@ -169,8 +169,7 @@ impl CrdsGossipPush { .filter(|(_, stake)| *stake > 0) .collect(); let weights: Vec<_> = peers.iter().map(|(_, stake)| *stake).collect(); - WeightedShuffle::new(&weights) - .unwrap() + WeightedShuffle::new("prune-received-cache", &weights) .shuffle(&mut rng) .map(move |i| peers[i]) }; @@ -370,7 +369,7 @@ impl CrdsGossipPush { return; } let num_bloom_items = MIN_NUM_BLOOM_ITEMS.max(network_size); - let shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng); + let shuffle = WeightedShuffle::new("push-options", &weights).shuffle(&mut rng); let mut active_set = self.active_set.write().unwrap(); let need = Self::compute_need(self.num_active, active_set.len(), ratio); for peer in shuffle.map(|i| peers[i]) { diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 52da902fbc..9975af3521 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -15,12 +15,6 @@ use { }, }; -#[derive(Debug)] -pub enum WeightedShuffleError { - NegativeWeight(T), - SumOverflow, -} - /// Implements an iterator where indices are shuffled according to their /// weights: /// - Returned indices are unique in the range [0, weights.len()). @@ -43,34 +37,48 @@ impl WeightedShuffle where T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, { - /// Returns error if: - /// - any of the weights are negative. - /// - sum of weights overflows. - pub fn new(weights: &[T]) -> Result> { + /// If weights are negative or overflow the total sum + /// they are treated as zero. + pub fn new(name: &'static str, weights: &[T]) -> Self { let size = weights.len() + 1; let zero = ::default(); let mut arr = vec![zero; size]; let mut sum = zero; let mut zeros = Vec::default(); + let mut num_negative = 0; + let mut num_overflow = 0; for (mut k, &weight) in (1usize..).zip(weights) { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. if !(weight >= zero) { - return Err(WeightedShuffleError::NegativeWeight(weight)); + zeros.push(k - 1); + num_negative += 1; + continue; } if weight == zero { zeros.push(k - 1); continue; } - sum = sum - .checked_add(&weight) - .ok_or(WeightedShuffleError::SumOverflow)?; + sum = match sum.checked_add(&weight) { + Some(val) => val, + None => { + zeros.push(k - 1); + num_overflow += 1; + continue; + } + }; while k < size { arr[k] += weight; k += k & k.wrapping_neg(); } } - Ok(Self { arr, sum, zeros }) + if num_negative > 0 { + datapoint_error!("weighted-shuffle-negative", (name, num_negative, i64)); + } + if num_overflow > 0 { + datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64)); + } + Self { arr, sum, zeros } } } @@ -343,7 +351,7 @@ mod tests { fn test_weighted_shuffle_empty_weights() { let weights = Vec::::new(); let mut rng = rand::thread_rng(); - let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle = WeightedShuffle::new("", &weights); assert!(shuffle.clone().shuffle(&mut rng).next().is_none()); assert!(shuffle.first(&mut rng).is_none()); } @@ -354,7 +362,7 @@ mod tests { let weights = vec![0u64; 5]; let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [1, 4, 2, 3, 0] @@ -372,14 +380,14 @@ mod tests { let weights = [1, 0, 1000, 0, 0, 10, 100, 0]; let mut counts = [0; 8]; for _ in 0..100000 { - let mut shuffle = WeightedShuffle::new(&weights).unwrap().shuffle(&mut rng); + let mut shuffle = WeightedShuffle::new("", &weights).shuffle(&mut rng); counts[shuffle.next().unwrap()] += 1; let _ = shuffle.count(); // consume the rest. } assert_eq!(counts, [95, 0, 90069, 0, 0, 908, 8928, 0]); let mut counts = [0; 8]; for _ in 0..100000 { - let mut shuffle = WeightedShuffle::new(&weights).unwrap(); + let mut shuffle = WeightedShuffle::new("", &weights); shuffle.remove_index(5); shuffle.remove_index(3); shuffle.remove_index(1); @@ -390,6 +398,26 @@ mod tests { assert_eq!(counts, [97, 0, 90862, 0, 0, 0, 9041, 0]); } + #[test] + fn test_weighted_shuffle_negative_overflow() { + const SEED: [u8; 32] = [48u8; 32]; + let weights = [19i64, 23, 7, 0, 0, 23, 3, 0, 5, 0, 19, 29]; + let mut rng = ChaChaRng::from_seed(SEED); + let shuffle = WeightedShuffle::new("", &weights); + assert_eq!( + shuffle.shuffle(&mut rng).collect::>(), + [8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7] + ); + // Negative weights and overflowing ones are treated as zero. + let weights = [19, 23, 7, -57, i64::MAX, 23, 3, i64::MAX, 5, -79, 19, 29]; + let mut rng = ChaChaRng::from_seed(SEED); + let shuffle = WeightedShuffle::new("", &weights); + assert_eq!( + shuffle.shuffle(&mut rng).collect::>(), + [8, 1, 5, 10, 11, 0, 2, 6, 9, 4, 3, 7] + ); + } + #[test] fn test_weighted_shuffle_hard_coded() { let weights = [ @@ -397,7 +425,7 @@ mod tests { ]; let seed = [48u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let mut shuffle = WeightedShuffle::new(&weights).unwrap(); + let mut shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [2, 12, 18, 0, 14, 15, 17, 10, 1, 9, 7, 6, 13, 20, 4, 19, 3, 8, 11, 16, 5] @@ -417,7 +445,7 @@ mod tests { assert_eq!(shuffle.first(&mut rng), Some(4)); let seed = [37u8; 32]; let mut rng = ChaChaRng::from_seed(seed); - let mut shuffle = WeightedShuffle::new(&weights).unwrap(); + let mut shuffle = WeightedShuffle::new("", &weights); assert_eq!( shuffle.clone().shuffle(&mut rng).collect::>(), [19, 3, 15, 14, 6, 10, 17, 18, 9, 2, 4, 1, 0, 7, 8, 20, 12, 13, 16, 5, 11] @@ -447,13 +475,13 @@ mod tests { let mut seed = [0u8; 32]; rng.fill(&mut seed[..]); let mut rng = ChaChaRng::from_seed(seed); - let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle = WeightedShuffle::new("", &weights); let shuffle: Vec<_> = shuffle.shuffle(&mut rng).collect(); let mut rng = ChaChaRng::from_seed(seed); let shuffle_slow = weighted_shuffle_slow(&mut rng, weights.clone()); assert_eq!(shuffle, shuffle_slow); let mut rng = ChaChaRng::from_seed(seed); - let shuffle = WeightedShuffle::new(&weights).unwrap(); + let shuffle = WeightedShuffle::new("", &weights); assert_eq!(shuffle.first(&mut rng), Some(shuffle_slow[0])); } }