solana/core/src/crds_shards.rs

233 lines
7.6 KiB
Rust

use crate::crds::VersionedCrdsValue;
use crate::crds_gossip_pull::CrdsFilter;
use indexmap::map::IndexMap;
use std::cmp::Ordering;
use std::ops::{Index, IndexMut};
/// Index over crds values, bucketed by the most significant bits of the u64
/// hash of each value, so that `find` can look up values by hash prefix.
#[derive(Clone)]
pub struct CrdsShards {
// shards[k] holds the crds values whose hash has its first shard_bits
// (most significant) bits equal to k. Each shard maps the index of a crds
// value to the u64 hash of that value.
shards: Vec<IndexMap<usize, u64>>,
// Number of leading hash bits used to select a shard; `new` allocates
// 1 << shard_bits shards.
shard_bits: u32,
}
impl CrdsShards {
    /// Creates an empty table with `1 << shard_bits` shards.
    pub fn new(shard_bits: u32) -> Self {
        let num_shards = 1usize << shard_bits;
        Self {
            shards: vec![IndexMap::new(); num_shards],
            shard_bits,
        }
    }

    /// Records `index` (together with the u64 hash of `value`) in the shard
    /// selected by that hash.
    ///
    /// Returns true if the shard did not already contain `index`.
    pub fn insert(&mut self, index: usize, value: &VersionedCrdsValue) -> bool {
        let hash = CrdsFilter::hash_as_u64(&value.value_hash);
        let previous = self.shard_mut(hash).insert(index, hash);
        previous.is_none()
    }

    /// Drops `index` from the shard selected by the hash of `value`.
    ///
    /// Returns true if the shard contained `index`.
    pub fn remove(&mut self, index: usize, value: &VersionedCrdsValue) -> bool {
        let hash = CrdsFilter::hash_as_u64(&value.value_hash);
        let removed = self.shard_mut(hash).swap_remove(&index);
        removed.is_some()
    }

    /// Returns the indices of all crds values whose hash has its first
    /// 'mask_bits' bits equal to those of 'mask'.
    pub fn find(&self, mask: u64, mask_bits: u32) -> impl Iterator<Item = usize> + '_ {
        // All bits below the first mask_bits are set to one; or-ing them into
        // both sides of a comparison makes it ignore those low bits.
        let ones = (!0u64).checked_shr(mask_bits).unwrap_or(0);
        let target = mask | ones;
        match self.shard_bits.cmp(&mask_bits) {
            Ordering::Less => {
                // The mask is longer than the shard prefix: only one shard
                // can match, but its entries still have to be filtered.
                let shard = self.shard(target);
                Iter::Less(
                    shard
                        .iter()
                        .filter(move |&(_, &hash)| (hash | ones) == target)
                        .map(|(&index, _)| index),
                )
            }
            Ordering::Equal => {
                // The mask selects exactly one shard.
                Iter::Equal(self.shard(target).keys().copied())
            }
            Ordering::Greater => {
                // The mask is shorter than the shard prefix: a contiguous run
                // of shards matches. Since the low bits of target are all
                // ones, its shard index is the last shard of that run.
                let span = 1usize << (self.shard_bits - mask_bits);
                let end = self.shard_index(target) + 1;
                let start = end - span;
                Iter::Greater(
                    self.shards[start..end]
                        .iter()
                        .flat_map(|shard| shard.keys())
                        .copied(),
                )
            }
        }
    }

    /// Maps a hash to the index of its shard, i.e. its first shard_bits bits.
    #[inline]
    fn shard_index(&self, hash: u64) -> usize {
        // checked_shr returns None when shifting by 64 bits (shard_bits == 0),
        // in which case there is a single shard at index zero.
        match hash.checked_shr(64 - self.shard_bits) {
            Some(prefix) => prefix as usize,
            None => 0,
        }
    }

    /// Shared reference to the shard which `hash` belongs to.
    #[inline]
    fn shard(&self, hash: u64) -> &IndexMap<usize, u64> {
        self.shards.index(self.shard_index(hash))
    }

    /// Mutable reference to the shard which `hash` belongs to.
    #[inline]
    fn shard_mut(&mut self, hash: u64) -> &mut IndexMap<usize, u64> {
        let position = self.shard_index(hash);
        self.shards.index_mut(position)
    }

    // Checks invariants in the shards tables against the crds table.
    #[cfg(test)]
    pub fn check(&self, crds: &[VersionedCrdsValue]) {
        // Every index in 0..crds.len() must appear exactly once.
        let mut indices: Vec<usize> = Vec::new();
        for shard in &self.shards {
            indices.extend(shard.keys().copied());
        }
        indices.sort_unstable();
        let expected: Vec<usize> = (0..crds.len()).collect();
        assert_eq!(indices, expected);
        // Every entry must store the right hash and sit in the right shard.
        for (shard_index, shard) in self.shards.iter().enumerate() {
            for (&index, &hash) in shard.iter() {
                let value_hash = CrdsFilter::hash_as_u64(&crds[index].value_hash);
                assert_eq!(hash, value_hash);
                let hash_prefix = hash.checked_shr(64 - self.shard_bits).unwrap_or(0);
                assert_eq!(shard_index as u64, hash_prefix);
            }
        }
    }
}
/// Wrapper for the 3 types of iterators we get when comparing shard_bits and
/// mask_bits in the find method. This is to avoid Box<dyn Iterator<Item =...>>
/// which involves dynamic dispatch and is relatively slow.
enum Iter<R, S, T> {
    /// Iterator used when shard_bits is less than mask_bits.
    Less(R),
    /// Iterator used when shard_bits is equal to mask_bits.
    Equal(S),
    /// Iterator used when shard_bits is greater than mask_bits.
    Greater(T),
}

impl<R, S, T> Iterator for Iter<R, S, T>
where
    R: Iterator<Item = usize>,
    S: Iterator<Item = usize>,
    T: Iterator<Item = usize>,
{
    type Item = usize;

    /// Yields the next item of whichever iterator is wrapped.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Greater(iter) => iter.next(),
            Self::Less(iter) => iter.next(),
            Self::Equal(iter) => iter.next(),
        }
    }

    /// Forwards the bounds of the wrapped iterator. Without this override the
    /// default `(0, None)` is reported, which prevents consumers such as
    /// `collect` from reserving capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Greater(iter) => iter.size_hint(),
            Self::Less(iter) => iter.size_hint(),
            Self::Equal(iter) => iter.size_hint(),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::contact_info::ContactInfo;
    use crate::crds_value::{CrdsData, CrdsValue};
    use rand::{thread_rng, Rng};
    use solana_sdk::timing::timestamp;
    use std::collections::HashSet;
    use std::ops::Index;

    // Builds a crds value wrapping a localhost contact-info for a random
    // pubkey.
    fn new_test_crds_value() -> VersionedCrdsValue {
        let pubkey = solana_sdk::pubkey::new_rand();
        let contact_info = ContactInfo::new_localhost(&pubkey, timestamp());
        let data = CrdsData::ContactInfo(contact_info);
        let local_timestamp = timestamp();
        VersionedCrdsValue::new(local_timestamp, CrdsValue::new_unsigned(data))
    }

    // Returns true if the first mask_bits most significant bits of hash is the
    // same as the given bit mask.
    fn check_mask(value: &VersionedCrdsValue, mask: u64, mask_bits: u32) -> bool {
        let ones = (!0u64).checked_shr(mask_bits).unwrap_or(0u64);
        let hash_prefix = CrdsFilter::hash_as_u64(&value.value_hash) | ones;
        let mask_prefix = mask | ones;
        hash_prefix == mask_prefix
    }

    // Manual filtering by scanning all the values.
    fn filter_crds_values(
        values: &[VersionedCrdsValue],
        mask: u64,
        mask_bits: u32,
    ) -> HashSet<usize> {
        let mut matches = HashSet::new();
        for (index, value) in values.iter().enumerate() {
            if check_mask(value, mask, mask_bits) {
                matches.insert(index);
            }
        }
        matches
    }

    // Removes a randomly chosen entry from both values and shards. Because
    // swap_remove moves the last value into the vacated slot, the moved value
    // is re-registered in the shards under its new index. Returns the index
    // which was removed.
    fn remove_random_value<R: Rng>(
        rng: &mut R,
        values: &mut Vec<VersionedCrdsValue>,
        shards: &mut CrdsShards,
    ) -> usize {
        let index = rng.gen_range(0, values.len());
        let removed = values.swap_remove(index);
        assert!(shards.remove(index, &removed));
        if index < values.len() {
            let moved = values.index(index);
            assert!(shards.remove(values.len(), moved));
            assert!(shards.insert(index, moved));
        }
        index
    }

    #[test]
    fn test_crds_shards_round_trip() {
        let mut rng = thread_rng();
        // Generate some random hash and crds value labels.
        let mut values: Vec<VersionedCrdsValue> =
            (0..4096).map(|_| new_test_crds_value()).collect();
        // Insert everything into the crds shards.
        let mut shards = CrdsShards::new(5);
        for (index, value) in values.iter().enumerate() {
            assert!(shards.insert(index, value));
        }
        shards.check(&values);
        // Remove some of the values.
        for _ in 0..512 {
            remove_random_value(&mut rng, &mut values, &mut shards);
            shards.check(&values);
        }
        // Random masks.
        for _ in 0..10 {
            let mask: u64 = rng.gen();
            for mask_bits in 0..12 {
                let mut expected = filter_crds_values(&values, mask, mask_bits);
                // Each index found must be expected, and reported only once.
                for index in shards.find(mask, mask_bits) {
                    assert!(expected.remove(&index));
                }
                assert!(expected.is_empty());
            }
        }
        // Existing hash values.
        for (index, value) in values.iter().enumerate() {
            let full_hash = CrdsFilter::hash_as_u64(&value.value_hash);
            let hits: Vec<usize> = shards.find(full_hash, 64).collect();
            assert_eq!(hits, vec![index]);
        }
        // Remove everything.
        while !values.is_empty() {
            let index = remove_random_value(&mut rng, &mut values, &mut shards);
            if index % 5 == 0 {
                shards.check(&values);
            }
        }
    }
}