diff --git a/runtime/src/accounts_db.rs b/runtime/src/accounts_db.rs index 78cf38620..b03748563 100644 --- a/runtime/src/accounts_db.rs +++ b/runtime/src/accounts_db.rs @@ -4378,17 +4378,22 @@ impl AccountsDb { check_hash: bool, ) -> Result<(Hash, u64), BankHashVerificationError> { if !use_index { - let mut time = Measure::start("collect"); - let combined_maps = self.get_snapshot_storages(slot, Some(ancestors)).0; - time.stop(); + let mut collect_time = Measure::start("collect"); + let (combined_maps, slots) = self.get_snapshot_storages(slot, Some(ancestors)); + collect_time.stop(); + + let mut sort_time = Measure::start("sort_storages"); + let storages = SortedStorages::new_with_slots(&combined_maps, &slots); + sort_time.stop(); let timings = HashStats { - collect_snapshots_us: time.as_us(), + collect_snapshots_us: collect_time.as_us(), + storage_sort_us: sort_time.as_us(), ..HashStats::default() }; Self::calculate_accounts_hash_without_index( - &combined_maps, + &storages, Some(&self.thread_pool_clean), timings, check_hash, @@ -4437,7 +4442,7 @@ impl AccountsDb { let bin_calculator = PubkeyBinCalculator16::new(bins); assert!(bin_range.start < bins && bin_range.end <= bins && bin_range.start < bin_range.end); let mut time = Measure::start("scan all accounts"); - stats.num_snapshot_storage = storage.len(); + stats.num_snapshot_storage = storage.slot_count(); let mismatch_found = AtomicU64::new(0); let result: Vec>> = Self::scan_account_storage_no_bank( @@ -4504,7 +4509,7 @@ impl AccountsDb { // modeled after get_accounts_delta_hash // intended to be faster than calculate_accounts_hash pub fn calculate_accounts_hash_without_index( - storages: &[SnapshotStorage], + storages: &SortedStorages, thread_pool: Option<&ThreadPool>, mut stats: HashStats, check_hash: bool, @@ -4528,11 +4533,6 @@ impl AccountsDb { let mut previous_pass = PreviousPass::default(); let mut final_result = (Hash::default(), 0); - let mut sort_time = Measure::start("sort_storages"); - let storages = SortedStorages::new(storages); - sort_time.stop(); - stats.storage_sort_us = sort_time.as_us(); - for pass in 0..num_scan_passes { let bounds = Range { start: pass * bins_per_pass, @@ -6182,7 +6182,7 @@ pub mod tests { let (storages, _size, _slot_expected) = sample_storage(); let result = AccountsDb::calculate_accounts_hash_without_index( - &storages, + &get_storage_refs(&storages), None, HashStats::default(), false, @@ -6203,7 +6203,7 @@ pub mod tests { }); let sum = raw_expected.iter().map(|item| item.lamports).sum(); let result = AccountsDb::calculate_accounts_hash_without_index( - &storages, + &get_storage_refs(&storages), None, HashStats::default(), false, diff --git a/runtime/src/snapshot_utils.rs b/runtime/src/snapshot_utils.rs index d98edf359..ee31cf1cb 100644 --- a/runtime/src/snapshot_utils.rs +++ b/runtime/src/snapshot_utils.rs @@ -11,6 +11,7 @@ use { snapshot_package::{ AccountsPackage, AccountsPackagePre, AccountsPackageSendError, AccountsPackageSender, }, + sorted_storages::SortedStorages, }, bincode::{config::Options, serialize_into}, bzip2::bufread::BzDecoder, @@ -999,8 +1000,9 @@ pub fn process_accounts_package_pre( let hash = accounts_package.hash; // temporarily remaining here if let Some(expected_hash) = accounts_package.hash_for_testing { + let sorted_storages = SortedStorages::new(&accounts_package.storages); let (hash, lamports) = AccountsDb::calculate_accounts_hash_without_index( - &accounts_package.storages, + &sorted_storages, thread_pool, crate::accounts_hash::HashStats::default(), false, diff --git a/runtime/src/sorted_storages.rs b/runtime/src/sorted_storages.rs index c2ac982db..c1beff0c9 100644 --- a/runtime/src/sorted_storages.rs +++ b/runtime/src/sorted_storages.rs @@ -7,7 +7,7 @@ use std::ops::Range; pub struct SortedStorages<'a> { range: Range, storages: Vec>, - count: usize, + slot_count: usize, } impl<'a> SortedStorages<'a> { @@ -28,35 +28,45 @@ impl<'a> SortedStorages<'a> { &self.range } - pub fn len(&self) -> usize { - self.count - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 + pub fn slot_count(&self) -> usize { + self.slot_count } + // assumptions: + // 1. each SnapshotStorage.!is_empty() + // 2. SnapshotStorage.first().unwrap().get_slot() is unique from all other SnapshotStorage items. pub fn new(source: &'a [SnapshotStorage]) -> Self { - let mut min = Slot::MAX; - let mut max = Slot::MIN; - let mut count = 0; - let mut time = Measure::start("get slot"); let slots = source .iter() .map(|storages| { - count += storages.len(); - if !storages.is_empty() { - storages.first().map(|storage| { - let slot = storage.slot(); - min = std::cmp::min(slot, min); - max = std::cmp::max(slot + 1, max); - slot - }) - } else { - None - } + let first = storages.first(); + assert!(first.is_some(), "SnapshotStorage.is_empty()"); + let storage = first.unwrap(); + storage.slot() // this must be unique. Will be enforced in new_with_slots }) .collect::>(); + Self::new_with_slots(source, &slots) + } + + // source[i] is in slot slots[i] + // assumptions: + // 1. slots vector contains unique slot #s. + // 2. slots and source are the same len + pub fn new_with_slots(source: &'a [SnapshotStorage], slots: &[Slot]) -> Self { + assert_eq!( + source.len(), + slots.len(), + "source and slots are different lengths" + ); + let mut min = Slot::MAX; + let mut max = Slot::MIN; + let slot_count = source.len(); + let mut time = Measure::start("get slot"); + slots.iter().for_each(|slot| { + let slot = *slot; + min = std::cmp::min(slot, min); + max = std::cmp::max(slot + 1, max); + }); time.stop(); let mut time2 = Measure::start("sort"); let range; @@ -75,11 +85,9 @@ impl<'a> SortedStorages<'a> { .iter() .zip(slots) .for_each(|(original_storages, slot)| { - if let Some(slot) = slot { - let index = (slot - min) as usize; - assert!(storages[index].is_none()); - storages[index] = Some(original_storages); - } + let index = (slot - min) as usize; + assert!(storages[index].is_none(), "slots are not unique"); // we should not encounter the same slot twice + storages[index] = Some(original_storages); }); } time2.stop(); @@ -87,7 +95,7 @@ impl<'a> SortedStorages<'a> { Self { range, storages, - count, + slot_count, } } } @@ -102,7 +110,7 @@ pub mod tests { start: min, end: min + len as Slot, }; - let count = source.len(); + let slot_count = source.len(); for (storage, slot) in source { storages[*slot as usize] = Some(*storage); } @@ -110,8 +118,79 @@ pub mod tests { Self { range, storages, - count, + slot_count, } } } + + #[test] + #[should_panic(expected = "SnapshotStorage.is_empty()")] + fn test_sorted_storages_empty() { + SortedStorages::new(&[Vec::new()]); + } + + #[test] + #[should_panic(expected = "slots are not unique")] + fn test_sorted_storages_duplicate_slots() { + SortedStorages::new_with_slots(&[Vec::new(), Vec::new()], &[0, 0]); + } + + #[test] + #[should_panic(expected = "source and slots are different lengths")] + fn test_sorted_storages_mismatched_lengths() { + SortedStorages::new_with_slots(&[Vec::new()], &[0, 0]); + } + + #[test] + fn test_sorted_storages_none() { + let result = SortedStorages::new_with_slots(&[], &[]); + assert_eq!(result.range, Range::default()); + assert_eq!(result.slot_count, 0); + assert_eq!(result.storages.len(), 0); + assert!(result.get(0).is_none()); + } + + #[test] + fn test_sorted_storages_1() { + let vec = vec![]; + let vec_check = vec.clone(); + let slot = 4; + let vecs = [vec]; + let result = SortedStorages::new_with_slots(&vecs, &[slot]); + assert_eq!( + result.range, + Range { + start: slot, + end: slot + 1 + } + ); + assert_eq!(result.slot_count, 1); + assert_eq!(result.storages.len(), 1); + assert_eq!(result.get(slot).unwrap().len(), vec_check.len()); + } + + #[test] + fn test_sorted_storages_2() { + let vec = vec![]; + let vec_check = vec.clone(); + let slots = [4, 7]; + let vecs = [vec.clone(), vec]; + let result = SortedStorages::new_with_slots(&vecs, &slots); + assert_eq!( + result.range, + Range { + start: slots[0], + end: slots[1] + 1, + } + ); + assert_eq!(result.slot_count, 2); + assert_eq!(result.storages.len() as Slot, slots[1] - slots[0] + 1); + assert!(result.get(0).is_none()); + assert!(result.get(3).is_none()); + assert!(result.get(5).is_none()); + assert!(result.get(6).is_none()); + assert!(result.get(8).is_none()); + assert_eq!(result.get(slots[0]).unwrap().len(), vec_check.len()); + assert_eq!(result.get(slots[1]).unwrap().len(), vec_check.len()); + } }