diff --git a/core/Cargo.toml b/core/Cargo.toml
index 2eac2a2c66..de74fa8f09 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -97,6 +97,9 @@ name = "blockstore"
 [[bench]]
 name = "crds_gossip_pull"
 
+[[bench]]
+name = "crds_shards"
+
 [[bench]]
 name = "gen_keys"
 
diff --git a/core/benches/crds_shards.rs b/core/benches/crds_shards.rs
new file mode 100644
index 0000000000..c419d5dce4
--- /dev/null
+++ b/core/benches/crds_shards.rs
@@ -0,0 +1,77 @@
+#![feature(test)]
+
+extern crate test;
+
+use rand::{thread_rng, Rng};
+use solana_core::contact_info::ContactInfo;
+use solana_core::crds::VersionedCrdsValue;
+use solana_core::crds_shards::CrdsShards;
+use solana_core::crds_value::{CrdsData, CrdsValue};
+use solana_sdk::pubkey::Pubkey;
+use solana_sdk::timing::timestamp;
+use test::Bencher;
+
+// Number of leading hash bits used to select a shard; CrdsShards::new
+// allocates 1 << CRDS_SHARDS_BITS shards.
+const CRDS_SHARDS_BITS: u32 = 8;
+
+// Builds a versioned crds value wrapping an unsigned localhost ContactInfo
+// for a pubkey obtained from Pubkey::new_rand().
+fn new_test_crds_value() -> VersionedCrdsValue {
+    let data = CrdsData::ContactInfo(ContactInfo::new_localhost(&Pubkey::new_rand(), timestamp()));
+    VersionedCrdsValue::new(timestamp(), CrdsValue::new_unsigned(data))
+}
+
+// Fills a CrdsShards with num_values generated crds values, then measures
+// CrdsShards::find over random masks restricted to mask_bits leading bits.
+fn bench_crds_shards_find(bencher: &mut Bencher, num_values: usize, mask_bits: u32) {
+    let values: Vec<VersionedCrdsValue> = std::iter::repeat_with(new_test_crds_value)
+        .take(num_values)
+        .collect();
+    let mut shards = CrdsShards::new(CRDS_SHARDS_BITS);
+    for (index, value) in values.iter().enumerate() {
+        assert!(shards.insert(index, value));
+    }
+    let mut rng = thread_rng();
+    bencher.iter(|| {
+        let mask = rng.gen();
+        // Returned so that Bencher::iter consumes the count and the search
+        // cannot be optimized away as dead code.
+        shards.find(mask, mask_bits).count()
+    });
+}
+
+#[bench]
+fn bench_crds_shards_find_0(bencher: &mut Bencher) {
+    bench_crds_shards_find(bencher, 100_000, 0);
+}
+
+#[bench]
+fn bench_crds_shards_find_1(bencher: &mut Bencher) {
+    bench_crds_shards_find(bencher, 100_000, 1);
+}
+
+#[bench]
+fn bench_crds_shards_find_3(bencher: &mut Bencher) {
+    bench_crds_shards_find(bencher, 100_000, 3);
+}
+
+#[bench]
+fn bench_crds_shards_find_5(bencher: &mut Bencher) {
+    bench_crds_shards_find(bencher, 100_000, 5);
+}
+
+#[bench]
+fn bench_crds_shards_find_7(bencher: &mut Bencher) {
+    bench_crds_shards_find(bencher, 100_000, 7);
+}
+
+#[bench]
+fn bench_crds_shards_find_8(bencher: &mut Bencher) {
+    bench_crds_shards_find(bencher, 100_000, 8);
+}
+
+#[bench]
+fn bench_crds_shards_find_9(bencher: &mut Bencher) {
+    bench_crds_shards_find(bencher, 100_000, 9);
+}
diff --git a/core/src/crds.rs b/core/src/crds.rs
index e07e9f328e..cd142bf552 100644
--- a/core/src/crds.rs
+++ b/core/src/crds.rs
@@ -24,22 +24,24 @@
 //! A value is updated to a new version if the labels match, and the value
 //! wallclock is later, or the value hash is greater.
 
-use crate::crds_gossip_pull::CrdsFilter;
+use crate::crds_shards::CrdsShards;
 use crate::crds_value::{CrdsValue, CrdsValueLabel};
 use bincode::serialize;
-use indexmap::map::IndexMap;
+use indexmap::map::{Entry, IndexMap};
 use solana_sdk::hash::{hash, Hash};
 use solana_sdk::pubkey::Pubkey;
 use std::cmp;
 use std::collections::HashMap;
+use std::ops::Index;
+
+const CRDS_SHARDS_BITS: u32 = 8;
 
 #[derive(Clone)]
 pub struct Crds {
     /// Stores the map of labels and values
     pub table: IndexMap<CrdsValueLabel, VersionedCrdsValue>,
     pub num_inserts: usize,
-
-    pub masks: IndexMap<CrdsValueLabel, u64>,
+    pub shards: CrdsShards,
 }
 
 #[derive(PartialEq, Debug)]
@@ -89,7 +91,7 @@ impl Default for Crds {
         Crds {
             table: IndexMap::new(),
             num_inserts: 0,
-            masks: IndexMap::new(),
+            shards: CrdsShards::new(CRDS_SHARDS_BITS),
         }
     }
 }
@@ -123,23 +125,28 @@ impl Crds {
         new_value: VersionedCrdsValue,
     ) -> Result<Option<VersionedCrdsValue>, CrdsError> {
         let label = new_value.value.label();
-        let wallclock = new_value.value.wallclock();
-        let do_insert = self
-            .table
-            .get(&label)
-            .map(|current| new_value > *current)
-            .unwrap_or(true);
-        if do_insert {
-            self.masks.insert(
-                label.clone(),
-                CrdsFilter::hash_as_u64(&new_value.value_hash),
-            );
-            let old = self.table.insert(label, new_value);
-            self.num_inserts += 1;
-            Ok(old)
-        } else {
-            trace!("INSERT FAILED data: {} new.wallclock: {}", label, wallclock,);
-            Err(CrdsError::InsertFailed)
+        match self.table.entry(label) {
+            Entry::Vacant(entry) => {
+                assert!(self.shards.insert(entry.index(), &new_value));
+                entry.insert(new_value);
+                self.num_inserts += 1;
+                Ok(None)
+            }
+            Entry::Occupied(mut entry) if *entry.get() < new_value => {
+                let index = entry.index();
+                assert!(self.shards.remove(index, entry.get()));
+                assert!(self.shards.insert(index, &new_value));
+                self.num_inserts += 1;
+                Ok(Some(entry.insert(new_value)))
+            }
+            _ => {
+                trace!(
+                    "INSERT FAILED data: {} new.wallclock: {}",
+                    new_value.value.label(),
+                    new_value.value.wallclock(),
+                );
+                Err(CrdsError::InsertFailed)
+            }
         }
     }
     pub fn insert(
@@ -200,8 +207,16 @@ impl Crds {
     }
 
     pub fn remove(&mut self, key: &CrdsValueLabel) {
-        self.table.swap_remove(key);
-        self.masks.swap_remove(key);
+        if let Some((index, _, value)) = self.table.swap_remove_full(key) {
+            assert!(self.shards.remove(index, &value));
+            // The previously last element in the table is now moved to the
+            // 'index' position. Shards need to be updated accordingly.
+            if index < self.table.len() {
+                let value = self.table.index(index);
+                assert!(self.shards.remove(self.table.len(), value));
+                assert!(self.shards.insert(index, value));
+            }
+        }
     }
 }
 
@@ -210,6 +225,7 @@ mod test {
     use super::*;
     use crate::contact_info::ContactInfo;
     use crate::crds_value::CrdsData;
+    use rand::{thread_rng, Rng};
 
     #[test]
     fn test_insert() {
@@ -329,6 +345,45 @@ mod test {
         assert_eq!(crds.find_old_labels(3, &set), vec![val.label()]);
     }
 
+    #[test]
+    fn test_crds_shards() {
+        fn check_crds_shards(crds: &Crds) {
+            crds.shards
+                .check(&crds.table.values().cloned().collect::<Vec<_>>());
+        }
+
+        let mut crds = Crds::default();
+        let pubkeys: Vec<_> = std::iter::repeat_with(Pubkey::new_rand).take(256).collect();
+        let mut rng = thread_rng();
+        let mut num_inserts = 0;
+        for _ in 0..4096 {
+            let pubkey = pubkeys[rng.gen_range(0, pubkeys.len())];
+            let value = VersionedCrdsValue::new(
+                rng.gen(), // local_timestamp
+                CrdsValue::new_unsigned(CrdsData::ContactInfo(ContactInfo::new_localhost(
+                    &pubkey,
+                    rng.gen(), // now
+                ))),
+            );
+            if crds.insert_versioned(value).is_ok() {
+                check_crds_shards(&crds);
+                num_inserts += 1;
+            }
+        }
+        assert_eq!(num_inserts, crds.num_inserts);
+        assert!(num_inserts > 700);
+        assert!(crds.table.len() > 200);
+        assert!(num_inserts > crds.table.len());
+        check_crds_shards(&crds);
+        // Remove values one by one and assert that shards stay valid.
+        while !crds.table.is_empty() {
+            let index = rng.gen_range(0, crds.table.len());
+            let key = crds.table.get_index(index).unwrap().0.clone();
+            crds.remove(&key);
+            check_crds_shards(&crds);
+        }
+    }
+
     #[test]
     fn test_remove_staked() {
         let mut crds = Crds::default();
diff --git a/core/src/crds_gossip_pull.rs b/core/src/crds_gossip_pull.rs
index ee5a8569a9..83c10e02bd 100644
--- a/core/src/crds_gossip_pull.rs
+++ b/core/src/crds_gossip_pull.rs
@@ -23,6 +23,7 @@ use std::cmp;
 use std::collections::VecDeque;
 use std::collections::{HashMap, HashSet};
 use std::convert::TryInto;
+use std::ops::Index;
 
 pub const CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS: u64 = 15000;
 // The maximum age of a value received over pull responses
@@ -418,52 +419,44 @@ impl CrdsGossipPull {
         filters: &[(CrdsValue, CrdsFilter)],
         now: u64,
     ) -> Vec<Vec<CrdsValue>> {
-        let mut ret = vec![vec![]; filters.len()];
         let msg_timeout = CRDS_GOSSIP_PULL_CRDS_TIMEOUT_MS;
         let jitter = rand::thread_rng().gen_range(0, msg_timeout / 4);
-        let start = filters.len();
         //skip filters from callers that are too old
         let future = now.saturating_add(msg_timeout);
         let past = now.saturating_sub(msg_timeout);
-        let recent: Vec<_> = filters
+        let mut dropped_requests = 0;
+        let mut total_skipped = 0;
+        let ret = filters
             .iter()
-            .enumerate()
-            .filter(|(_, (caller, _))| caller.wallclock() < future && caller.wallclock() >= past)
+            .map(|(caller, filter)| {
+                let caller_wallclock = caller.wallclock();
+                if caller_wallclock >= future || caller_wallclock < past {
+                    dropped_requests += 1;
+                    return vec![];
+                }
+                let caller_wallclock = caller_wallclock.checked_add(jitter).unwrap_or(0);
+                crds.shards
+                    .find(filter.mask, filter.mask_bits)
+                    .filter_map(|index| {
+                        let item = crds.table.index(index);
+                        debug_assert!(filter.test_mask(&item.value_hash));
+                        //skip values that are too new
+                        if item.value.wallclock() > caller_wallclock {
+                            total_skipped += 1;
+                            None
+                        } else if filter.filter_contains(&item.value_hash) {
+                            None
+                        } else {
+                            Some(item.value.clone())
+                        }
+                    })
+                    .collect()
+            })
             .collect();
         inc_new_counter_info!(
             "gossip_filter_crds_values-dropped_requests",
-            start - recent.len()
+            dropped_requests
         );
-        if recent.is_empty() {
-            return ret;
-        }
-        let mut total_skipped = 0;
-        let mask_ones: Vec<_> = recent
-            .iter()
-            .map(|(_i, (_caller, filter))| (!0u64).checked_shr(filter.mask_bits).unwrap_or(!0u64))
-            .collect();
-        for (label, mask) in crds.masks.iter() {
-            recent
-                .iter()
-                .zip(mask_ones.iter())
-                .for_each(|((i, (caller, filter)), mask_ones)| {
-                    if filter.test_mask_u64(*mask, *mask_ones) {
-                        let item = crds.table.get(label).unwrap();
-
-                        //skip values that are too new
-                        if item.value.wallclock()
-                            > caller.wallclock().checked_add(jitter).unwrap_or_else(|| 0)
-                        {
-                            total_skipped += 1;
-                            return;
-                        }
-
-                        if !filter.filter_contains(&item.value_hash) {
-                            ret[*i].push(item.value.clone());
-                        }
-                    }
-                });
-        }
         inc_new_counter_info!("gossip_filter_crds_values-dropped_values", total_skipped);
         ret
     }
diff --git a/core/src/crds_shards.rs b/core/src/crds_shards.rs
new file mode 100644
index 0000000000..917eb656d2
--- /dev/null
+++ b/core/src/crds_shards.rs
@@ -0,0 +1,252 @@
+use crate::crds::VersionedCrdsValue;
+use crate::crds_gossip_pull::CrdsFilter;
+use indexmap::map::IndexMap;
+use std::cmp::Ordering;
+use std::ops::{Index, IndexMut};
+
+/// Groups crds table indices into `1 << shard_bits` shards keyed by the
+/// leading `shard_bits` bits of each value's hash, so that `find` only has
+/// to visit the shards which can match a requested hash prefix.
+#[derive(Clone)]
+pub struct CrdsShards {
+    // shards[k] includes crds values which the first shard_bits of their hash
+    // value is equal to k. Each shard is a mapping from crds values indices to
+    // their hash value.
+    shards: Vec<IndexMap<usize, u64>>,
+    shard_bits: u32,
+}
+
+impl CrdsShards {
+    /// Creates `1 << shard_bits` empty shards.
+    pub fn new(shard_bits: u32) -> Self {
+        CrdsShards {
+            shards: vec![IndexMap::new(); 1 << shard_bits],
+            shard_bits,
+        }
+    }
+
+    /// Records the crds table `index` of `value` in the shard selected by the
+    /// value's hash. Returns false if that shard already contained `index`.
+    #[must_use]
+    pub fn insert(&mut self, index: usize, value: &VersionedCrdsValue) -> bool {
+        let hash = CrdsFilter::hash_as_u64(&value.value_hash);
+        self.shard_mut(hash).insert(index, hash).is_none()
+    }
+
+    /// Removes `index` from the shard selected by the hash of `value`.
+    /// Returns false if that shard did not contain `index`.
+    #[must_use]
+    pub fn remove(&mut self, index: usize, value: &VersionedCrdsValue) -> bool {
+        let hash = CrdsFilter::hash_as_u64(&value.value_hash);
+        self.shard_mut(hash).swap_remove(&index).is_some()
+    }
+
+    /// Returns indices of all crds values which the first 'mask_bits' of their
+    /// hash value is equal to 'mask'.
+    pub fn find(&self, mask: u64, mask_bits: u32) -> impl Iterator<Item = usize> + '_ {
+        // ones has the low (64 - mask_bits) bits set; or-ing a hash with it
+        // discards the bits that the mask does not cover.
+        let ones = (!0u64).checked_shr(mask_bits).unwrap_or(0);
+        let mask = mask | ones;
+        match self.shard_bits.cmp(&mask_bits) {
+            Ordering::Less => {
+                // The mask is longer than a shard prefix: only one shard can match
+                // and its entries still need to be tested against the full mask.
+                let pred = move |(&index, hash)| {
+                    if hash | ones == mask {
+                        Some(index)
+                    } else {
+                        None
+                    }
+                };
+                Iter::Less(self.shard(mask).iter().filter_map(pred))
+            }
+            // The mask is exactly one shard prefix.
+            Ordering::Equal => Iter::Equal(self.shard(mask).keys().cloned()),
+            Ordering::Greater => {
+                // The mask covers a contiguous run of `count` shards; since the
+                // uncovered bits of mask are all ones, its shard is the last one.
+                let count = 1 << (self.shard_bits - mask_bits);
+                let end = self.shard_index(mask) + 1;
+                Iter::Greater(
+                    self.shards[end - count..end]
+                        .iter()
+                        .flat_map(IndexMap::keys)
+                        .cloned(),
+                )
+            }
+        }
+    }
+
+    // Shard index of a hash: its top shard_bits bits. With shard_bits == 0 the
+    // shift by 64 overflows, checked_shr returns None and the index is 0.
+    #[inline]
+    fn shard_index(&self, hash: u64) -> usize {
+        hash.checked_shr(64 - self.shard_bits).unwrap_or(0) as usize
+    }
+
+    // Shard holding the values whose hash maps to the same index as `hash`.
+    #[inline]
+    fn shard(&self, hash: u64) -> &IndexMap<usize, u64> {
+        let shard_index = self.shard_index(hash);
+        self.shards.index(shard_index)
+    }
+
+    // Mutable counterpart of shard().
+    #[inline]
+    fn shard_mut(&mut self, hash: u64) -> &mut IndexMap<usize, u64> {
+        let shard_index = self.shard_index(hash);
+        self.shards.index_mut(shard_index)
+    }
+
+    // Checks invariants in the shards tables against the crds table.
+    #[cfg(test)]
+    pub fn check(&self, crds: &[VersionedCrdsValue]) {
+        let mut indices: Vec<_> = self
+            .shards
+            .iter()
+            .flat_map(IndexMap::keys)
+            .cloned()
+            .collect();
+        indices.sort_unstable();
+        assert_eq!(indices, (0..crds.len()).collect::<Vec<_>>());
+        for (shard_index, shard) in self.shards.iter().enumerate() {
+            for (&index, &hash) in shard {
+                assert_eq!(hash, CrdsFilter::hash_as_u64(&crds[index].value_hash));
+                assert_eq!(
+                    shard_index as u64,
+                    hash.checked_shr(64 - self.shard_bits).unwrap_or(0)
+                );
+            }
+        }
+    }
+}
+
+// Wrapper for 3 types of iterators we get when comparing shard_bits and
+// mask_bits in find method. This is to avoid Box<dyn Iterator<Item = usize>>
+// which involves dynamic dispatch and is relatively slow.
+enum Iter<R, S, T> {
+    Less(R),
+    Equal(S),
+    Greater(T),
+}
+
+impl<R, S, T> Iterator for Iter<R, S, T>
+where
+    R: Iterator<Item = usize>,
+    S: Iterator<Item = usize>,
+    T: Iterator<Item = usize>,
+{
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self {
+            Self::Greater(iter) => iter.next(),
+            Self::Less(iter) => iter.next(),
+            Self::Equal(iter) => iter.next(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::contact_info::ContactInfo;
+    use crate::crds_value::{CrdsData, CrdsValue};
+    use rand::{thread_rng, Rng};
+    use solana_sdk::pubkey::Pubkey;
+    use solana_sdk::timing::timestamp;
+    use std::collections::HashSet;
+    use std::ops::Index;
+
+    fn new_test_crds_value() -> VersionedCrdsValue {
+        let data =
+            CrdsData::ContactInfo(ContactInfo::new_localhost(&Pubkey::new_rand(), timestamp()));
+        VersionedCrdsValue::new(timestamp(), CrdsValue::new_unsigned(data))
+    }
+
+    // Returns true if the first mask_bits most significant bits of hash is the
+    // same as the given bit mask.
+    fn check_mask(value: &VersionedCrdsValue, mask: u64, mask_bits: u32) -> bool {
+        let hash = CrdsFilter::hash_as_u64(&value.value_hash);
+        let ones = (!0u64).checked_shr(mask_bits).unwrap_or(0u64);
+        (hash | ones) == (mask | ones)
+    }
+
+    // Manual filtering by scanning all the values.
+    fn filter_crds_values(
+        values: &[VersionedCrdsValue],
+        mask: u64,
+        mask_bits: u32,
+    ) -> HashSet<usize> {
+        values
+            .iter()
+            .enumerate()
+            .filter_map(|(index, value)| {
+                if check_mask(value, mask, mask_bits) {
+                    Some(index)
+                } else {
+                    None
+                }
+            })
+            .collect()
+    }
+
+    #[test]
+    fn test_crds_shards_round_trip() {
+        let mut rng = thread_rng();
+        // Generate some random hash and crds value labels.
+        let mut values: Vec<_> = std::iter::repeat_with(new_test_crds_value)
+            .take(4096)
+            .collect();
+        // Insert everything into the crds shards.
+        let mut shards = CrdsShards::new(5);
+        for (index, value) in values.iter().enumerate() {
+            assert!(shards.insert(index, value));
+        }
+        shards.check(&values);
+        // Remove some of the values.
+        for _ in 0..512 {
+            let index = rng.gen_range(0, values.len());
+            let value = values.swap_remove(index);
+            assert!(shards.remove(index, &value));
+            if index < values.len() {
+                let value = values.index(index);
+                assert!(shards.remove(values.len(), value));
+                assert!(shards.insert(index, value));
+            }
+            shards.check(&values);
+        }
+        // Random masks.
+        for _ in 0..10 {
+            let mask = rng.gen();
+            for mask_bits in 0..12 {
+                let mut set = filter_crds_values(&values, mask, mask_bits);
+                for index in shards.find(mask, mask_bits) {
+                    assert!(set.remove(&index));
+                }
+                assert!(set.is_empty());
+            }
+        }
+        // Existing hash values.
+        for (index, value) in values.iter().enumerate() {
+            let mask = CrdsFilter::hash_as_u64(&value.value_hash);
+            let hits: Vec<_> = shards.find(mask, 64).collect();
+            assert_eq!(hits, vec![index]);
+        }
+        // Remove everything.
+        while !values.is_empty() {
+            let index = rng.gen_range(0, values.len());
+            let value = values.swap_remove(index);
+            assert!(shards.remove(index, &value));
+            if index < values.len() {
+                let value = values.index(index);
+                assert!(shards.remove(values.len(), value));
+                assert!(shards.insert(index, value));
+            }
+            if index % 5 == 0 {
+                shards.check(&values);
+            }
+        }
+    }
+}
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 042b0ef740..7a555e39ae 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -29,6 +29,7 @@ pub mod crds_gossip;
 pub mod crds_gossip_error;
 pub mod crds_gossip_pull;
 pub mod crds_gossip_push;
+pub mod crds_shards;
 pub mod crds_value;
 pub mod epoch_slots;
 pub mod fetch_stage;