192 lines
6.2 KiB
Rust
192 lines
6.2 KiB
Rust
|
use {
|
||
|
itertools::Itertools,
|
||
|
lru::LruCache,
|
||
|
solana_sdk::pubkey::Pubkey,
|
||
|
std::{cmp::Reverse, collections::HashMap},
|
||
|
};
|
||
|
|
||
|
// For each origin, tracks which nodes have sent messages from that origin and
|
||
|
// their respective score in terms of timeliness of delivered messages.
|
||
|
pub(crate) struct ReceivedCache(LruCache</*origin/owner:*/ Pubkey, ReceivedCacheEntry>);
|
||
|
|
||
|
#[derive(Clone, Default)]
|
||
|
struct ReceivedCacheEntry {
|
||
|
nodes: HashMap<Pubkey, /*score:*/ usize>,
|
||
|
num_upserts: usize,
|
||
|
}
|
||
|
|
||
|
impl ReceivedCache {
|
||
|
// Minimum number of upserts before a cache entry can be pruned.
|
||
|
const MIN_NUM_UPSERTS: usize = 20;
|
||
|
|
||
|
pub(crate) fn new(capacity: usize) -> Self {
|
||
|
Self(LruCache::new(capacity))
|
||
|
}
|
||
|
|
||
|
pub(crate) fn record(&mut self, origin: Pubkey, node: Pubkey, num_dups: usize) {
|
||
|
match self.0.get_mut(&origin) {
|
||
|
Some(entry) => entry.record(node, num_dups),
|
||
|
None => {
|
||
|
let mut entry = ReceivedCacheEntry::default();
|
||
|
entry.record(node, num_dups);
|
||
|
self.0.put(origin, entry);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub(crate) fn prune(
|
||
|
&mut self,
|
||
|
pubkey: &Pubkey, // This node.
|
||
|
origin: Pubkey, // CRDS value owner.
|
||
|
stake_threshold: f64,
|
||
|
min_ingress_nodes: usize,
|
||
|
stakes: &HashMap<Pubkey, u64>,
|
||
|
) -> impl Iterator<Item = Pubkey> {
|
||
|
match self.0.peek_mut(&origin) {
|
||
|
None => None,
|
||
|
Some(entry) if entry.num_upserts < Self::MIN_NUM_UPSERTS => None,
|
||
|
Some(entry) => Some(
|
||
|
std::mem::take(entry)
|
||
|
.prune(pubkey, &origin, stake_threshold, min_ingress_nodes, stakes)
|
||
|
.filter(move |node| node != &origin),
|
||
|
),
|
||
|
}
|
||
|
.into_iter()
|
||
|
.flatten()
|
||
|
}
|
||
|
|
||
|
// Only for tests and simulations.
|
||
|
pub(crate) fn mock_clone(&self) -> Self {
|
||
|
let mut cache = LruCache::new(self.0.cap());
|
||
|
for (&origin, entry) in self.0.iter().rev() {
|
||
|
cache.put(origin, entry.clone());
|
||
|
}
|
||
|
Self(cache)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl ReceivedCacheEntry {
|
||
|
// Limit how big the cache can get if it is spammed
|
||
|
// with old messages with random pubkeys.
|
||
|
const CAPACITY: usize = 50;
|
||
|
// Threshold for the number of duplicates before which a message
|
||
|
// is counted as timely towards node's score.
|
||
|
const NUM_DUPS_THRESHOLD: usize = 2;
|
||
|
|
||
|
fn record(&mut self, node: Pubkey, num_dups: usize) {
|
||
|
if num_dups == 0 {
|
||
|
self.num_upserts = self.num_upserts.saturating_add(1);
|
||
|
}
|
||
|
// If the message has been timely enough increment node's score.
|
||
|
if num_dups < Self::NUM_DUPS_THRESHOLD {
|
||
|
let score = self.nodes.entry(node).or_default();
|
||
|
*score = score.saturating_add(1);
|
||
|
} else if self.nodes.len() < Self::CAPACITY {
|
||
|
// Ensure that node is inserted into the cache for later pruning.
|
||
|
let _ = self.nodes.entry(node).or_default();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn prune(
|
||
|
self,
|
||
|
pubkey: &Pubkey, // This node.
|
||
|
origin: &Pubkey, // CRDS value owner.
|
||
|
stake_threshold: f64,
|
||
|
min_ingress_nodes: usize,
|
||
|
stakes: &HashMap<Pubkey, u64>,
|
||
|
) -> impl Iterator<Item = Pubkey> {
|
||
|
debug_assert!((0.0..=1.0).contains(&stake_threshold));
|
||
|
debug_assert!(self.num_upserts >= ReceivedCache::MIN_NUM_UPSERTS);
|
||
|
// Enforce a minimum aggregate ingress stake; see:
|
||
|
// https://github.com/solana-labs/solana/issues/3214
|
||
|
let min_ingress_stake = {
|
||
|
let stake = stakes.get(pubkey).min(stakes.get(origin));
|
||
|
(stake.copied().unwrap_or_default() as f64 * stake_threshold) as u64
|
||
|
};
|
||
|
self.nodes
|
||
|
.into_iter()
|
||
|
.map(|(node, score)| {
|
||
|
let stake = stakes.get(&node).copied().unwrap_or_default();
|
||
|
(node, score, stake)
|
||
|
})
|
||
|
.sorted_unstable_by_key(|&(_, score, stake)| Reverse((score, stake)))
|
||
|
.scan(0u64, |acc, (node, _score, stake)| {
|
||
|
let old = *acc;
|
||
|
*acc = acc.saturating_add(stake);
|
||
|
Some((node, old))
|
||
|
})
|
||
|
.skip(min_ingress_nodes)
|
||
|
.skip_while(move |&(_, stake)| stake < min_ingress_stake)
|
||
|
.map(|(node, _stake)| node)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[cfg(test)]
|
||
|
mod tests {
|
||
|
use {
|
||
|
super::*,
|
||
|
std::{collections::HashSet, iter::repeat_with},
|
||
|
};
|
||
|
|
||
|
#[test]
|
||
|
fn test_received_cache() {
|
||
|
let mut cache = ReceivedCache::new(/*capacity:*/ 100);
|
||
|
let pubkey = Pubkey::new_unique();
|
||
|
let origin = Pubkey::new_unique();
|
||
|
let records = vec![
|
||
|
vec![3, 1, 7, 5],
|
||
|
vec![7, 6, 5, 2],
|
||
|
vec![2, 0, 0, 2],
|
||
|
vec![3, 5, 0, 6],
|
||
|
vec![6, 2, 6, 2],
|
||
|
];
|
||
|
let nodes: Vec<_> = repeat_with(Pubkey::new_unique)
|
||
|
.take(records.len())
|
||
|
.collect();
|
||
|
for (node, records) in nodes.iter().zip(records) {
|
||
|
for (num_dups, k) in records.into_iter().enumerate() {
|
||
|
for _ in 0..k {
|
||
|
cache.record(origin, *node, num_dups);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
assert_eq!(cache.0.get(&origin).unwrap().num_upserts, 21);
|
||
|
let scores: HashMap<Pubkey, usize> = [
|
||
|
(nodes[0], 4),
|
||
|
(nodes[1], 13),
|
||
|
(nodes[2], 2),
|
||
|
(nodes[3], 8),
|
||
|
(nodes[4], 8),
|
||
|
]
|
||
|
.into_iter()
|
||
|
.collect();
|
||
|
assert_eq!(cache.0.get(&origin).unwrap().nodes, scores);
|
||
|
let stakes = [
|
||
|
(nodes[0], 6),
|
||
|
(nodes[1], 1),
|
||
|
(nodes[2], 5),
|
||
|
(nodes[3], 3),
|
||
|
(nodes[4], 7),
|
||
|
(pubkey, 9),
|
||
|
(origin, 9),
|
||
|
]
|
||
|
.into_iter()
|
||
|
.collect();
|
||
|
let prunes: HashSet<Pubkey> = [nodes[0], nodes[2], nodes[3]].into_iter().collect();
|
||
|
assert_eq!(
|
||
|
cache
|
||
|
.mock_clone()
|
||
|
.prune(&pubkey, origin, 0.5, 2, &stakes)
|
||
|
.collect::<HashSet<_>>(),
|
||
|
prunes
|
||
|
);
|
||
|
let prunes: HashSet<Pubkey> = [nodes[0], nodes[2]].into_iter().collect();
|
||
|
assert_eq!(
|
||
|
cache
|
||
|
.prune(&pubkey, origin, 1.0, 0, &stakes)
|
||
|
.collect::<HashSet<_>>(),
|
||
|
prunes
|
||
|
);
|
||
|
}
|
||
|
}
|