diff --git a/runtime/src/accounts_db.rs b/runtime/src/accounts_db.rs
index 332d0c6c21..733013a6a2 100644
--- a/runtime/src/accounts_db.rs
+++ b/runtime/src/accounts_db.rs
@@ -20,7 +20,7 @@ use crate::{
     accounts_cache::{AccountsCache, CachedAccount, SlotCache},
-    accounts_hash::{AccountsHash, CalculateHashIntermediate, HashStats},
+    accounts_hash::{AccountsHash, CalculateHashIntermediate, HashStats, PreviousPass},
     accounts_index::{
         AccountIndex, AccountsIndex, AccountsIndexRootsStats, Ancestors, IndexKey, IsCached,
         SlotList, SlotSlice, ZeroLamport,
@@ -54,7 +54,7 @@ use std::{
     collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet},
     convert::TryFrom,
     io::{Error as IoError, Result as IoResult},
-    ops::RangeBounds,
+    ops::{Range, RangeBounds},
     path::{Path, PathBuf},
     sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
     sync::{Arc, Mutex, MutexGuard, RwLock},
@@ -3705,9 +3705,11 @@ impl AccountsDb {
         storage: &[SnapshotStorage],
         mut stats: &mut crate::accounts_hash::HashStats,
         bins: usize,
+        bin_range: &Range<usize>,
     ) -> Vec<Vec<Vec<CalculateHashIntermediate>>> {
         let max_plus_1 = std::u8::MAX as usize + 1;
         assert!(bins <= max_plus_1 && bins > 0);
+        assert!(bin_range.start < bins && bin_range.end <= bins && bin_range.start < bin_range.end);
         let mut time = Measure::start("scan all accounts");
         stats.num_snapshot_storage = storage.len();
         let result: Vec<Vec<Vec<CalculateHashIntermediate>>> = Self::scan_account_storage_no_bank(
@@ -3716,6 +3718,12 @@ impl AccountsDb {
             |loaded_account: LoadedAccount,
              accum: &mut Vec<Vec<CalculateHashIntermediate>>,
              slot: Slot| {
+                let pubkey = *loaded_account.pubkey();
+                let pubkey_to_bin_index = pubkey.as_ref()[0] as usize * bins / max_plus_1;
+                if !bin_range.contains(&pubkey_to_bin_index) {
+                    return;
+                }
+
                 let version = loaded_account.write_version();
                 let raw_lamports = loaded_account.lamports();
                 let zero_raw_lamports = raw_lamports == 0;
@@ -3729,7 +3737,6 @@ impl AccountsDb {
                     )
                 };

-                let pubkey = *loaded_account.pubkey();
                 let source_item = CalculateHashIntermediate::new(
                     version,
                     *loaded_account.loaded_hash(),
@@ -3737,12 +3744,11 @@ impl AccountsDb {
                     slot,
                     pubkey,
                 );
-                let rng_index = pubkey.as_ref()[0] as usize * bins / max_plus_1;
                 let max = accum.len();
                 if max == 0 {
                     accum.extend(vec![Vec::new(); bins]);
                 }
-                accum[rng_index].push(source_item);
+                accum[pubkey_to_bin_index].push(source_item);
             },
         );
         time.stop();
@@ -3759,14 +3765,46 @@ impl AccountsDb {
         let scan_and_hash = || {
             let mut stats = HashStats::default();
             // When calculating hashes, it is helpful to break the pubkeys found into bins based on the pubkey value.
+            // More bins means smaller vectors to sort, copy, etc.
             const PUBKEY_BINS_FOR_CALCULATING_HASHES: usize = 64;
-            let result = Self::scan_snapshot_stores(
-                storages,
-                &mut stats,
-                PUBKEY_BINS_FOR_CALCULATING_HASHES,
-            );
-            AccountsHash::rest_of_hash_calculation(result, &mut stats)
+
+            // # of passes should be a function of the total # of accounts that are active.
+            // higher passes = slower total time, lower dynamic memory usage
+            // lower passes = faster total time, higher dynamic memory usage
+            // passes=2 cuts dynamic memory usage in approximately half.
+            let num_scan_passes: usize = 2;
+
+            let bins_per_pass = PUBKEY_BINS_FOR_CALCULATING_HASHES / num_scan_passes;
+            assert_eq!(
+                bins_per_pass * num_scan_passes,
+                PUBKEY_BINS_FOR_CALCULATING_HASHES
+            ); // evenly divisible
+            let mut previous_pass = PreviousPass::default();
+            let mut final_result = (Hash::default(), 0);
+
+            for pass in 0..num_scan_passes {
+                let bounds = Range {
+                    start: pass * bins_per_pass,
+                    end: (pass + 1) * bins_per_pass,
+                };
+
+                let result = Self::scan_snapshot_stores(
+                    storages,
+                    &mut stats,
+                    PUBKEY_BINS_FOR_CALCULATING_HASHES,
+                    &bounds,
+                );
+
+                let (hash, lamports, for_next_pass) = AccountsHash::rest_of_hash_calculation(
+                    result,
+                    &mut stats,
+                    pass == num_scan_passes - 1,
+                    previous_pass,
+                );
+                previous_pass = for_next_pass;
+                final_result = (hash, lamports);
+            }
+            final_result
         };
         if let Some(thread_pool) = thread_pool {
             thread_pool.install(scan_and_hash)
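For intuition, here is a minimal standalone sketch of the bin/pass arithmetic that `scan_snapshot_stores` and `calculate_accounts_hash_without_index` now share. This is illustrative only, not code from this patch; `demo_pass_ranges` and `bin_of` are hypothetical names:

```rust
// Minimal sketch of the bin/pass partitioning (assumed helper, not patch code).
fn demo_pass_ranges() {
    const PUBKEY_BINS_FOR_CALCULATING_HASHES: usize = 64;
    let max_plus_1 = u8::MAX as usize + 1; // 256
    let num_scan_passes = 2;
    let bins_per_pass = PUBKEY_BINS_FOR_CALCULATING_HASHES / num_scan_passes; // 32

    // Same formula as the patch: scale the pubkey's first byte (0..=255)
    // down to a bin index in 0..bins.
    let bin_of =
        |first_byte: u8| first_byte as usize * PUBKEY_BINS_FOR_CALCULATING_HASHES / max_plus_1;

    for pass in 0..num_scan_passes {
        let bounds = pass * bins_per_pass..(pass + 1) * bins_per_pass;
        // pass 0 owns bins 0..32 (first byte 0x00..=0x7f),
        // pass 1 owns bins 32..64 (first byte 0x80..=0xff).
        assert!(bounds.contains(&bin_of(pass as u8 * 128)));
    }
}
```

Accounts whose bin falls outside the current pass's range are skipped at the top of the scan closure, which is what bounds the memory any single pass holds.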
@@ -5015,13 +5053,49 @@ pub mod tests {
     #[should_panic(expected = "assertion failed: bins <= max_plus_1 && bins > 0")]
     fn test_accountsdb_scan_snapshot_stores_illegal_bins2() {
         let mut stats = HashStats::default();
-        AccountsDb::scan_snapshot_stores(&[], &mut stats, 257);
+        let bounds = Range { start: 0, end: 0 };
+
+        AccountsDb::scan_snapshot_stores(&[], &mut stats, 257, &bounds);
     }
     #[test]
     #[should_panic(expected = "assertion failed: bins <= max_plus_1 && bins > 0")]
     fn test_accountsdb_scan_snapshot_stores_illegal_bins() {
         let mut stats = HashStats::default();
-        AccountsDb::scan_snapshot_stores(&[], &mut stats, 0);
+        let bounds = Range { start: 0, end: 0 };
+
+        AccountsDb::scan_snapshot_stores(&[], &mut stats, 0, &bounds);
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "bin_range.start < bins && bin_range.end <= bins &&\\n bin_range.start < bin_range.end"
+    )]
+    fn test_accountsdb_scan_snapshot_stores_illegal_range_start() {
+        let mut stats = HashStats::default();
+        let bounds = Range { start: 2, end: 2 };
+
+        AccountsDb::scan_snapshot_stores(&[], &mut stats, 2, &bounds);
+    }
+    #[test]
+    #[should_panic(
+        expected = "bin_range.start < bins && bin_range.end <= bins &&\\n bin_range.start < bin_range.end"
+    )]
+    fn test_accountsdb_scan_snapshot_stores_illegal_range_end() {
+        let mut stats = HashStats::default();
+        let bounds = Range { start: 1, end: 3 };
+
+        AccountsDb::scan_snapshot_stores(&[], &mut stats, 2, &bounds);
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "bin_range.start < bins && bin_range.end <= bins &&\\n bin_range.start < bin_range.end"
+    )]
+    fn test_accountsdb_scan_snapshot_stores_illegal_range_inverse() {
+        let mut stats = HashStats::default();
+        let bounds = Range { start: 1, end: 0 };
+
+        AccountsDb::scan_snapshot_stores(&[], &mut stats, 2, &bounds);
     }

     fn sample_storages_and_accounts() -> (SnapshotStorages, Vec<CalculateHashIntermediate>) {
@@ -5108,11 +5182,28 @@ pub mod tests {
         let bins = 1;
         let mut stats = HashStats::default();
-        let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
+
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 0,
+                end: bins,
+            },
+        );
         assert_eq!(result, vec![vec![raw_expected.clone()]]);

         let bins = 2;
-        let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 0,
+                end: bins,
+            },
+        );
         let mut expected = vec![Vec::new(); bins];
         expected[0].push(raw_expected[0].clone());
         expected[0].push(raw_expected[1].clone());
@@ -5121,7 +5212,15 @@ pub mod tests {
         assert_eq!(result, vec![expected]);

         let bins = 4;
-        let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 0,
+                end: bins,
+            },
+        );
         let mut expected = vec![Vec::new(); bins];
         expected[0].push(raw_expected[0].clone());
         expected[1].push(raw_expected[1].clone());
@@ -5130,7 +5229,15 @@ pub mod tests {
         assert_eq!(result, vec![expected]);

         let bins = 256;
-        let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 0,
+                end: bins,
+            },
+        );
         let mut expected = vec![Vec::new(); bins];
         expected[0].push(raw_expected[0].clone());
         expected[127].push(raw_expected[1].clone());
@@ -5151,13 +5258,126 @@ pub mod tests {
         storages[0].splice(0..0, vec![arc; MAX_ITEMS_PER_CHUNK]);

         let mut stats = HashStats::default();
-        let result = AccountsDb::scan_snapshot_stores(&storages, &mut stats, bins);
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 0,
+                end: bins,
+            },
+        );
         assert_eq!(result.len(), 2); // 2 chunks
         assert_eq!(result[0].len(), 0); // nothing found in first slots
         assert_eq!(result[1].len(), bins);
         assert_eq!(result[1], vec![raw_expected]);
     }

+    #[test]
+    fn test_accountsdb_scan_snapshot_stores_binning() {
+        let mut stats = HashStats::default();
+        let (mut storages, raw_expected) = sample_storages_and_accounts();
+
+        // just the first bin of 2
+        let bins = 2;
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 0,
+                end: bins / 2,
+            },
+        );
+        let mut expected = vec![Vec::new(); bins];
+        expected[0].push(raw_expected[0].clone());
+        expected[0].push(raw_expected[1].clone());
+        assert_eq!(result, vec![expected]);
+
+        // just the second bin of 2
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 1,
+                end: bins,
+            },
+        );
+
+        let mut expected = vec![Vec::new(); bins];
+        expected[bins - 1].push(raw_expected[2].clone());
+        expected[bins - 1].push(raw_expected[3].clone());
+        assert_eq!(result, vec![expected]);
+
+        // 1 bin at a time of 4
+        let bins = 4;
+        for bin in 0..bins {
+            let result = AccountsDb::scan_snapshot_stores(
+                &storages,
+                &mut stats,
+                bins,
+                &Range {
+                    start: bin,
+                    end: bin + 1,
+                },
+            );
+            let mut expected = vec![Vec::new(); bins];
+            expected[bin].push(raw_expected[bin].clone());
+            assert_eq!(result, vec![expected]);
+        }
+
+        let bins = 256;
+        let bin_locations = vec![0, 127, 128, 255];
+        for bin in 0..bins {
+            let result = AccountsDb::scan_snapshot_stores(
+                &storages,
+                &mut stats,
+                bins,
+                &Range {
+                    start: bin,
+                    end: bin + 1,
+                },
+            );
+            let mut expected = vec![];
+            if let Some(index) = bin_locations.iter().position(|&r| r == bin) {
+                expected = vec![Vec::new(); bins];
+                expected[bin].push(raw_expected[index].clone());
+            }
+            assert_eq!(result, vec![expected]);
+        }
+
+        // enough stores to get to 2nd chunk
+        // range is for only 1 bin out of 256.
+        let bins = 256;
+        let (_temp_dirs, paths) = get_temp_accounts_paths(1).unwrap();
+        let slot_expected: Slot = 0;
+        let size: usize = 123;
+        let data = AccountStorageEntry::new(&paths[0], slot_expected, 0, size as u64);
+
+        let arc = Arc::new(data);
+
+        const MAX_ITEMS_PER_CHUNK: usize = 5_000;
+        storages[0].splice(0..0, vec![arc; MAX_ITEMS_PER_CHUNK]);
+
+        let mut stats = HashStats::default();
+        let result = AccountsDb::scan_snapshot_stores(
+            &storages,
+            &mut stats,
+            bins,
+            &Range {
+                start: 127,
+                end: 128,
+            },
+        );
+        assert_eq!(result.len(), 2); // 2 chunks
+        assert_eq!(result[0].len(), 0); // nothing found in first slots
+        let mut expected = vec![Vec::new(); bins];
+        expected[127].push(raw_expected[1].clone());
+        assert_eq!(result[1].len(), bins);
+        assert_eq!(result[1], expected);
+    }
+
     #[test]
     fn test_accountsdb_calculate_accounts_hash_without_index_simple() {
         solana_logger::setup();
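The struct that makes the passes composable is `PreviousPass`, defined at the top of the accounts_hash.rs changes that follow. A simplified model of its carry-over contract may help (assumed names, real hashing elided; this is a sketch, not the patch's code):

```rust
// Illustrative sketch (not patch code) of the pass-to-pass carry-over.
struct PreviousPassSketch {
    reduced_hashes: Vec<Vec<[u8; 32]>>, // one partial result per full 16^3 group
    remaining_unhashed: Vec<[u8; 32]>,  // tail of < 16^3 leaves, carried forward
    lamports: u64,                      // running capitalization total
}

fn carry_forward(
    mut prev: PreviousPassSketch,
    mut new_leaves: Vec<[u8; 32]>,
    new_lamports: u64,
    target_fanout: usize, // 16^3 = 4096 in the patch
) -> PreviousPassSketch {
    // the previous pass's tail sorts first, so prepend it to this pass's leaves
    let mut all = std::mem::take(&mut prev.remaining_unhashed);
    all.append(&mut new_leaves);

    // keep only whole target_fanout groups for this pass; bank the rest
    let cut = all.len() - all.len() % target_fanout;
    prev.remaining_unhashed = all.split_off(cut);

    // each whole group would be reduced 3 merkle levels into 1 hash and
    // pushed onto reduced_hashes; the actual hashing is elided here
    let _groups_reduced_this_pass = all.len() / target_fanout;

    prev.lamports += new_lamports;
    prev
}
```

The last pass then merkles the accumulated `reduced_hashes` (plus any final tail) down to the single accounts hash, and reports the summed lamports.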
diff --git a/runtime/src/accounts_hash.rs b/runtime/src/accounts_hash.rs
index 9af37532f1..03d8d81f86 100644
--- a/runtime/src/accounts_hash.rs
+++ b/runtime/src/accounts_hash.rs
@@ -11,6 +11,13 @@ use std::{convert::TryInto, sync::Mutex};
 pub const ZERO_RAW_LAMPORTS_SENTINEL: u64 = std::u64::MAX;
 pub const MERKLE_FANOUT: usize = 16;

+#[derive(Default, Debug)]
+pub struct PreviousPass {
+    pub reduced_hashes: Vec<Vec<Hash>>,
+    pub remaining_unhashed: Vec<Hash>,
+    pub lamports: u64,
+}
+
 #[derive(Debug, Default)]
 pub struct HashStats {
     pub scan_time_total_us: u64,
@@ -197,8 +204,9 @@ impl AccountsHash {
             MERKLE_FANOUT,
             None,
             |start: usize| cumulative_offsets.get_slice(&hashes, start),
+            None,
         );
-        (result, hash_total)
+        (result.0, hash_total)
     }

     pub fn compute_merkle_root(hashes: Vec<(Pubkey, Hash)>, fanout: usize) -> Hash {
@@ -259,48 +267,85 @@ impl AccountsHash {
         }
     }

-    // This function is designed to allow hashes to be located in multiple, perhaps multiply deep vecs.
-    // The caller provides a function to return a slice from the source data.
-    pub fn compute_merkle_root_from_slices<'a, F>(
+    fn calculate_three_level_chunks(
         total_hashes: usize,
         fanout: usize,
         max_levels_per_pass: Option<usize>,
-        get_hashes: F,
-    ) -> Hash
-    where
-        F: Fn(usize) -> &'a [Hash] + std::marker::Sync,
-    {
-        if total_hashes == 0 {
-            return Hasher::default().result();
-        }
-
-        let mut time = Measure::start("time");
-
+        specific_level_count: Option<usize>,
+    ) -> (usize, usize, bool) {
         const THREE_LEVEL_OPTIMIZATION: usize = 3; // this '3' is dependent on the code structure below where we manually unroll
         let target = fanout.pow(THREE_LEVEL_OPTIMIZATION as u32);

         // Only use the 3 level optimization if we have at least 4 levels of data.
         // Otherwise, we'll be serializing a parallel operation.
         let threshold = target * fanout;
-        let three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
+        let mut three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
             && total_hashes >= threshold;
-        let num_hashes_per_chunk = if three_level { target } else { fanout };
+        if three_level {
+            if let Some(specific_level_count_value) = specific_level_count {
+                three_level = specific_level_count_value >= THREE_LEVEL_OPTIMIZATION;
+            }
+        }
+        let (num_hashes_per_chunk, levels_hashed) = if three_level {
+            (target, THREE_LEVEL_OPTIMIZATION)
+        } else {
+            (fanout, 1)
+        };
+        (num_hashes_per_chunk, levels_hashed, three_level)
+    }
+
+    // This function is designed to allow hashes to be located in multiple, perhaps multiply deep vecs.
+    // The caller provides a function to return a slice from the source data.
+    pub fn compute_merkle_root_from_slices<'a, F>(
+        total_hashes: usize,
+        fanout: usize,
+        max_levels_per_pass: Option<usize>,
+        get_hash_slice_starting_at_index: F,
+        specific_level_count: Option<usize>,
+    ) -> (Hash, Vec<Hash>)
+    where
+        F: Fn(usize) -> &'a [Hash] + std::marker::Sync,
+    {
+        if total_hashes == 0 {
+            return (Hasher::default().result(), vec![]);
+        }
+
+        let mut time = Measure::start("time");
+
+        let (num_hashes_per_chunk, levels_hashed, three_level) = Self::calculate_three_level_chunks(
+            total_hashes,
+            fanout,
+            max_levels_per_pass,
+            specific_level_count,
+        );

         let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);

         // initial fetch - could return entire slice
-        let data: &[Hash] = get_hashes(0);
+        let data: &[Hash] = get_hash_slice_starting_at_index(0);
         let data_len = data.len();

         let result: Vec<_> = (0..chunks)
             .into_par_iter()
             .map(|i| {
+                // summary:
+                // this closure computes 1 or 3 levels of merkle tree (all chunks will be 1 or all will be 3)
+                // for a subset (our chunk) of the input data [start_index..end_index]
+
+                // index into get_hash_slice_starting_at_index where this chunk's range begins
                 let start_index = i * num_hashes_per_chunk;
+                // index into get_hash_slice_starting_at_index where this chunk's range ends
                 let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);

+                // will compute the final result for this closure
                 let mut hasher = Hasher::default();
+
+                // index into 'data' where we are currently pulling data
+                // if we exhaust our data, then we will request a new slice, and data_index resets to 0, the beginning of the new slice
                 let mut data_index = start_index;
+                // source data, which we may refresh when we exhaust
                 let mut data = data;
+                // len of the source data
                 let mut data_len = data_len;

                 if !three_level {
@@ -308,8 +353,8 @@ impl AccountsHash {
                     // The result of this loop is a single hash value from fanout input hashes.
                     for i in start_index..end_index {
                         if data_index >= data_len {
-                            // fetch next slice
-                            data = get_hashes(i);
+                            // we exhausted our data, fetch next slice starting at i
+                            data = get_hash_slice_starting_at_index(i);
                             data_len = data.len();
                             data_index = 0;
                         }
@@ -318,7 +363,43 @@ impl AccountsHash {
                     }
                 } else {
                     // hash 3 levels of fanout simultaneously.
+                    // This codepath produces 1 hash value for between 1..=fanout^3 input hashes.
+                    // It is equivalent to running the normal merkle tree calculation 3 iterations on the input.
+                    //
+                    // big idea:
+                    // merkle trees usually reduce the input vector by a factor of fanout with each iteration
+                    //  example with fanout 2:
+                    //  start:     [0,1,2,3,4,5,6,7]    in our case: [...16M...] or really, 1B
+                    //  iteration0 [.5, 2.5, 4.5, 6.5]  [... 1M...]
+                    //  iteration1 [1.5, 5.5]           [...65k...]
+                    //  iteration2 3.5                  [...4k... ]
+                    // So iteration 0 consumes N elements, hashes them in groups of 'fanout' and produces a vector of N/fanout elements
+                    //   and the process repeats until there is only 1 hash left.
+                    //
+                    // With the three_level code path, we make each chunk we iterate of size fanout^3 (4096)
+                    // So, the input could be 16M hashes and the output will be 4k hashes, or N/fanout^3
+                    // The goal is to reduce the amount of data that has to be constructed and held in memory.
+                    // When we know we have enough hashes, then, in 1 pass, we hash 3 levels simultaneously, storing far fewer intermediate hashes.
+                    //
+                    // Now, some details:
                     // The result of this loop is a single hash value from fanout^3 input hashes.
+                    // concepts:
+                    // what we're conceptually hashing: "raw_hashes"[start_index..end_index]
+                    //   example: [a,b,c,d,e,f]
+                    //   but... hashes[] may really be multiple vectors that are pieced together.
+                    //   example: [[a,b],[c],[d,e,f]]
+                    //   get_hash_slice_starting_at_index(any_index) abstracts that and returns a slice starting at raw_hashes[any_index..]
+                    //   such that the end of get_hash_slice_starting_at_index may be <, >, or = end_index
+                    //   example: get_hash_slice_starting_at_index(1) returns [b]
+                    //            get_hash_slice_starting_at_index(3) returns [d,e,f]
+                    // This code is basically 3 iterations of merkle tree hashing occurring simultaneously.
+                    // The first fanout raw hashes are hashed in hasher_k. This is iteration0
+                    // Once hasher_k has hashed fanout hashes, hasher_k's result hash is hashed in hasher_j and then discarded
+                    // hasher_k then starts over fresh and hashes the next fanout raw hashes. This is iteration0 again for a new set of data.
+                    // Once hasher_j has hashed fanout hashes (from k), hasher_j's result hash is hashed in hasher and then discarded
+                    // Once hasher has hashed fanout hashes (from j), then the result of hasher is the hash for fanout^3 raw hashes.
+                    // If there are < fanout^3 hashes, then this code stops when it runs out of raw hashes and returns whatever it hashed.
+                    // This is always how the very last elements work in a merkle tree.
                     let mut i = start_index;
                     while i < end_index {
                         let mut hasher_j = Hasher::default();
@@ -327,8 +408,8 @@ impl AccountsHash {
                             let end = std::cmp::min(end_index - i, fanout);
                             for _k in 0..end {
                                 if data_index >= data_len {
-                                    // fetch next slice
-                                    data = get_hashes(i);
+                                    // we exhausted our data, fetch next slice starting at i
+                                    data = get_hash_slice_starting_at_index(i);
                                     data_len = data.len();
                                     data_index = 0;
                                 }
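The unrolled loop that the comment block above describes can be modeled in a few lines. This is a sketch under the assumption of a generic `combine` step; the real code streams leaves through `hasher_k`/`hasher_j`/`hasher` and never materializes the intermediate `level1`/`level2` vectors:

```rust
// Standalone model of the manually unrolled 3-level reduction (illustrative
// only; `combine` stands in for solana_sdk's Hasher).
fn hash_three_levels<F: Fn(&[u64]) -> u64 + Copy>(
    leaves: &[u64],
    fanout: usize,
    combine: F,
) -> u64 {
    assert!(!leaves.is_empty() && leaves.len() <= fanout.pow(3));
    let level1: Vec<u64> = leaves.chunks(fanout).map(combine).collect(); // hasher_k outputs
    let level2: Vec<u64> = level1.chunks(fanout).map(combine).collect(); // hasher_j outputs
    combine(&level2) // hasher's single output for up to fanout^3 leaves
}
```

One call reduces up to fanout^3 leaves to a single value, exactly equivalent to three iterations of the plain per-level merkle loop, which is why `levels_hashed` is 3 on this path.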
@@ -351,13 +432,47 @@ impl AccountsHash {
         time.stop();
         debug!("hashing {} {}", total_hashes, time);

-        if result.len() == 1 {
-            result[0]
+        if let Some(mut specific_level_count_value) = specific_level_count {
+            specific_level_count_value -= levels_hashed;
+            if specific_level_count_value == 0 {
+                (Hash::default(), result)
+            } else {
+                assert!(specific_level_count_value > 0);
+                // We did not hash the number of levels required by 'specific_level_count', so repeat
+                Self::compute_merkle_root_from_slices_recurse(
+                    result,
+                    fanout,
+                    max_levels_per_pass,
+                    Some(specific_level_count_value),
+                )
+            }
         } else {
-            Self::compute_merkle_root_recurse(result, fanout)
+            (
+                if result.len() == 1 {
+                    result[0]
+                } else {
+                    Self::compute_merkle_root_recurse(result, fanout)
+                },
+                vec![], // no intermediate results needed by caller
+            )
         }
     }

+    pub fn compute_merkle_root_from_slices_recurse(
+        hashes: Vec<Hash>,
+        fanout: usize,
+        max_levels_per_pass: Option<usize>,
+        specific_level_count: Option<usize>,
+    ) -> (Hash, Vec<Hash>) {
+        Self::compute_merkle_root_from_slices(
+            hashes.len(),
+            fanout,
+            max_levels_per_pass,
+            |start| &hashes[start..],
+            specific_level_count,
+        )
+    }
+
     pub fn accumulate_account_hashes(mut hashes: Vec<(Pubkey, Hash)>) -> Hash {
         Self::sort_hashes_by_pubkey(&mut hashes);
@@ -405,7 +520,7 @@ impl AccountsHash {
         }
         flatten_time.stop();
         stats.flatten_time_total_us += flatten_time.as_us();
-        stats.unreduced_entries = raw_len;
+        stats.unreduced_entries += raw_len;
         data_by_pubkey
     }
@@ -569,50 +684,120 @@ impl AccountsHash {
         (result, sum)
     }

-    fn flatten_hashes_and_hash(
-        hashes: Vec<Vec<Vec<Hash>>>,
-        fanout: usize,
-        stats: &mut HashStats,
-    ) -> Hash {
-        let mut hash_time = Measure::start("flat2");
-
-        let offsets = CumulativeOffsets::from_raw_2d(&hashes);
-
-        let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&hashes, start) };
-        let hash = AccountsHash::compute_merkle_root_from_slices(
-            offsets.total_count,
-            fanout,
-            None,
-            get_slice,
-        );
-        hash_time.stop();
-        stats.hash_time_total_us += hash_time.as_us();
-        stats.hash_total = offsets.total_count;
-
-        hash
-    }
-
     // input:
     // vec: unordered, created by parallelism
     // vec: [0..bins] - where bins are pubkey ranges
-    // vec: [..] - items which fin in the containing bin, unordered within this vec
+    // vec: [..] - items which fit in the containing bin, unordered within this vec
     // so, assumption is middle vec is bins sorted by pubkey
     pub fn rest_of_hash_calculation(
         data_sections_by_pubkey: Vec<Vec<Vec<CalculateHashIntermediate>>>,
         mut stats: &mut HashStats,
-    ) -> (Hash, u64) {
+        is_last_pass: bool,
+        mut previous_state: PreviousPass,
+    ) -> (Hash, u64, PreviousPass) {
         let outer = Self::flatten_hash_intermediate(data_sections_by_pubkey, &mut stats);

         let sorted_data_by_pubkey = Self::sort_hash_intermediate(outer, &mut stats);

-        let (hashes, total_lamports) =
+        let (mut hashes, mut total_lamports) =
             Self::de_dup_and_eliminate_zeros(sorted_data_by_pubkey, &mut stats);

-        let hash = Self::flatten_hashes_and_hash(hashes, MERKLE_FANOUT, &mut stats);
+        total_lamports += previous_state.lamports;

-        stats.log();
+        if !previous_state.remaining_unhashed.is_empty() {
+            // these items were not hashed last iteration because they didn't divide evenly
+            hashes.insert(0, vec![previous_state.remaining_unhashed]);
+            previous_state.remaining_unhashed = Vec::new();
+        }

-        (hash, total_lamports)
+        let mut next_pass = PreviousPass::default();
+        let cumulative = CumulativeOffsets::from_raw_2d(&hashes);
+        let mut hash_total = cumulative.total_count;
+        stats.hash_total += hash_total;
+        next_pass.reduced_hashes = previous_state.reduced_hashes;
+
+        const TARGET_FANOUT_LEVEL: usize = 3;
+        let target_fanout = MERKLE_FANOUT.pow(TARGET_FANOUT_LEVEL as u32);
+
+        if !is_last_pass {
+            next_pass.lamports = total_lamports;
+            total_lamports = 0;
+
+            // Save hashes that don't evenly hash. They will be combined with hashes from the next pass.
+            let left_over_hashes = hash_total % target_fanout;
+
+            // move tail hashes that don't evenly hash into a 1d vector for next time
+            let mut i = hash_total - left_over_hashes;
+            while i < hash_total {
+                let data = cumulative.get_slice_2d(&hashes, i);
+                next_pass.remaining_unhashed.extend(data);
+                i += data.len();
+            }
+
+            hash_total -= left_over_hashes; // this is enough to cause the hashes at the end of the data set to be ignored
+        }
+
+        // if we have raw hashes to process and
+        //   we are not the last pass (we already modded against target_fanout) OR
+        //   we have previously surpassed target_fanout and hashed some already to the target_fanout level. In that case, we know
+        //     we need to hash whatever is left here to the target_fanout level.
+        if hash_total != 0 && (!is_last_pass || !next_pass.reduced_hashes.is_empty()) {
+            let mut hash_time = Measure::start("hash");
+            let partial_hashes = Self::compute_merkle_root_from_slices(
+                hash_total, // note this does not include the ones that didn't divide evenly, unless we're in the last iteration
+                MERKLE_FANOUT,
+                Some(TARGET_FANOUT_LEVEL),
+                |start| cumulative.get_slice_2d(&hashes, start),
+                Some(TARGET_FANOUT_LEVEL),
+            )
+            .1;
+            hash_time.stop();
+            stats.hash_time_total_us += hash_time.as_us();
+            next_pass.reduced_hashes.push(partial_hashes);
+        }
+
+        let no_progress = is_last_pass && next_pass.reduced_hashes.is_empty() && !hashes.is_empty();
+        if no_progress {
+            // we never made partial progress, so hash everything now
+            hashes.into_iter().for_each(|v| {
+                v.into_iter().for_each(|v| {
+                    if !v.is_empty() {
+                        next_pass.reduced_hashes.push(v);
+                    }
+                });
+            });
+        }
+
+        let hash = if is_last_pass {
+            let cumulative = CumulativeOffsets::from_raw(&next_pass.reduced_hashes);
+
+            let hash = if cumulative.total_count == 1 && !no_progress {
+                // all the passes resulted in a single hash, that means we're done; there were <= fanout^3 total hashes
+                cumulative.get_slice(&next_pass.reduced_hashes, 0)[0]
+            } else {
+                let mut hash_time = Measure::start("hash");
+                // hash all the rest and combine and hash until we have only 1 hash left
+                let (hash, _) = Self::compute_merkle_root_from_slices(
+                    cumulative.total_count,
+                    MERKLE_FANOUT,
+                    None,
+                    |start| cumulative.get_slice(&next_pass.reduced_hashes, start),
+                    None,
+                );
+                hash_time.stop();
+                stats.hash_time_total_us += hash_time.as_us();
+                hash
+            };
+            next_pass.reduced_hashes = Vec::new();
+            hash
+        } else {
+            Hash::default()
+        };

+        if is_last_pass {
+            stats.log();
+        }
+        (hash, total_lamports, next_pass)
     }
 }
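The `%`/`-` bookkeeping in `rest_of_hash_calculation` is easy to misread, so here is a worked example with hypothetical numbers (not from the patch) of what a non-final pass reduces versus what it banks for the next pass:

```rust
// Worked example of the tail split on a non-final pass.
fn demo_tail_split() {
    const MERKLE_FANOUT: usize = 16;
    let target_fanout = MERKLE_FANOUT.pow(3); // 4096
    let hash_total: usize = 5_000; // hypothetical de-duped hash count this pass

    let left_over_hashes = hash_total % target_fanout; // 904 go to remaining_unhashed
    let hashed_now = hash_total - left_over_hashes; // 4096 are reduced this pass
    assert_eq!(hashed_now / target_fanout, 1); // exactly 1 entry lands in reduced_hashes
}
```

Because the input is globally sorted by pubkey and each pass covers a disjoint, ascending bin range, the banked tail can simply be prepended to the next pass's sorted output without re-sorting.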
@@ -662,6 +847,8 @@ pub mod tests {
         let result = AccountsHash::rest_of_hash_calculation(
             vec![vec![account_maps.clone()]],
             &mut HashStats::default(),
+            true,
+            PreviousPass::default(),
         );
         let expected_hash = Hash::from_str("8j9ARGFv4W2GfML7d3sVJK2MePwrikqYnu6yqer28cCa").unwrap();
         assert_eq!((result.0, result.1), (expected_hash, 88));
@@ -675,6 +862,8 @@ pub mod tests {
         let result = AccountsHash::rest_of_hash_calculation(
             vec![vec![account_maps.clone()]],
             &mut HashStats::default(),
+            true,
+            PreviousPass::default(),
         );
         let expected_hash = Hash::from_str("EHv9C5vX7xQjjMpsJMzudnDTzoTSRwYkqLzY8tVMihGj").unwrap();
         assert_eq!((result.0, result.1), (expected_hash, 108));
@@ -688,11 +877,302 @@ pub mod tests {
         let result = AccountsHash::rest_of_hash_calculation(
             vec![vec![account_maps]],
             &mut HashStats::default(),
+            true,
+            PreviousPass::default(),
         );
         let expected_hash = Hash::from_str("7NNPg5A8Xsg1uv4UFm6KZNwsipyyUnmgCrznP6MBWoBZ").unwrap();
         assert_eq!((result.0, result.1), (expected_hash, 118));
     }

+    #[test]
+    fn test_accountsdb_multi_pass_rest_of_hash_calculation() {
+        solana_logger::setup();
+
+        // passes:
+        // 0: empty, NON-empty, empty, empty final
+        // 1: NON-empty, empty final
+        // 2: NON-empty, empty, empty final
+        for pass in 0..3 {
+            let mut account_maps: Vec<CalculateHashIntermediate> = Vec::new();
+
+            let key = Pubkey::new(&[11u8; 32]);
+            let hash = Hash::new(&[1u8; 32]);
+            let val = CalculateHashIntermediate::new(0, hash, 88, Slot::default(), key);
+            account_maps.push(val);
+
+            // 2nd key - zero lamports, so will be removed
+            let key = Pubkey::new(&[12u8; 32]);
+            let hash = Hash::new(&[2u8; 32]);
+            let val = CalculateHashIntermediate::new(
+                0,
+                hash,
+                ZERO_RAW_LAMPORTS_SENTINEL,
+                Slot::default(),
+                key,
+            );
+            account_maps.push(val);
+
+            let mut previous_pass = PreviousPass::default();
+
+            if pass == 0 {
+                // first pass that is not last and is empty
+                let result = AccountsHash::rest_of_hash_calculation(
+                    vec![vec![vec![]]],
+                    &mut HashStats::default(),
+                    false, // not last pass
+                    previous_pass,
+                );
+                assert_eq!(result.0, Hash::default());
+                assert_eq!(result.1, 0);
+                previous_pass = result.2;
+                assert_eq!(previous_pass.remaining_unhashed.len(), 0);
+                assert_eq!(previous_pass.reduced_hashes.len(), 0);
+                assert_eq!(previous_pass.lamports, 0);
+            }
+
+            let result = AccountsHash::rest_of_hash_calculation(
+                vec![vec![account_maps.clone()]],
+                &mut HashStats::default(),
+                false, // not last pass
+                previous_pass,
+            );
+
+            assert_eq!(result.0, Hash::default());
+            assert_eq!(result.1, 0);
+            let mut previous_pass = result.2;
+            assert_eq!(previous_pass.remaining_unhashed, vec![account_maps[0].hash]);
+            assert_eq!(previous_pass.reduced_hashes.len(), 0);
+            assert_eq!(previous_pass.lamports, account_maps[0].lamports);
+
+            let expected_hash =
+                Hash::from_str("8j9ARGFv4W2GfML7d3sVJK2MePwrikqYnu6yqer28cCa").unwrap();
+            if pass == 2 {
+                let result = AccountsHash::rest_of_hash_calculation(
+                    vec![vec![vec![]]],
+                    &mut HashStats::default(),
+                    false,
+                    previous_pass,
+                );
+
+                previous_pass = result.2;
+                assert_eq!(previous_pass.remaining_unhashed, vec![account_maps[0].hash]);
+                assert_eq!(previous_pass.reduced_hashes.len(), 0);
+                assert_eq!(previous_pass.lamports, account_maps[0].lamports);
+            }
+
+            let result = AccountsHash::rest_of_hash_calculation(
+                vec![vec![vec![]]],
+                &mut HashStats::default(),
+                true, // finally, last pass
+                previous_pass,
+            );
+            let previous_pass = result.2;
+
+            assert_eq!(previous_pass.remaining_unhashed.len(), 0);
+            assert_eq!(previous_pass.reduced_hashes.len(), 0);
+            assert_eq!(previous_pass.lamports, 0);
+
+            assert_eq!((result.0, result.1), (expected_hash, 88));
+        }
+    }
+
+    #[test]
+    fn test_accountsdb_multi_pass_rest_of_hash_calculation_partial() {
+        solana_logger::setup();
+
+        let mut account_maps: Vec<CalculateHashIntermediate> = Vec::new();
+
+        let key = Pubkey::new(&[11u8; 32]);
+        let hash = Hash::new(&[1u8; 32]);
+        let val = CalculateHashIntermediate::new(0, hash, 88, Slot::default(), key);
+        account_maps.push(val);
+
+        let key = Pubkey::new(&[12u8; 32]);
+        let hash = Hash::new(&[2u8; 32]);
+        let val = CalculateHashIntermediate::new(0, hash, 20, Slot::default(), key);
+        account_maps.push(val);
+
+        let result = AccountsHash::rest_of_hash_calculation(
+            vec![vec![vec![account_maps[0].clone()]]],
+            &mut HashStats::default(),
+            false, // not last pass
+            PreviousPass::default(),
+        );
+
+        assert_eq!(result.0, Hash::default());
+        assert_eq!(result.1, 0);
+        let previous_pass = result.2;
+        assert_eq!(previous_pass.remaining_unhashed, vec![account_maps[0].hash]);
+        assert_eq!(previous_pass.reduced_hashes.len(), 0);
+        assert_eq!(previous_pass.lamports, account_maps[0].lamports);
+
+        let result = AccountsHash::rest_of_hash_calculation(
+            vec![vec![vec![account_maps[1].clone()]]],
+            &mut HashStats::default(),
+            false, // not last pass
+            previous_pass,
+        );
+
+        assert_eq!(result.0, Hash::default());
+        assert_eq!(result.1, 0);
+        let previous_pass = result.2;
+        assert_eq!(
+            previous_pass.remaining_unhashed,
+            vec![account_maps[0].hash, account_maps[1].hash]
+        );
+        assert_eq!(previous_pass.reduced_hashes.len(), 0);
+        let total_lamports_expected = account_maps[0].lamports + account_maps[1].lamports;
+        assert_eq!(previous_pass.lamports, total_lamports_expected);
+
+        let result = AccountsHash::rest_of_hash_calculation(
+            vec![vec![vec![]]],
+            &mut HashStats::default(),
+            true,
+            previous_pass,
+        );
+
+        let previous_pass = result.2;
+        assert_eq!(previous_pass.remaining_unhashed.len(), 0);
+        assert_eq!(previous_pass.reduced_hashes.len(), 0);
+        assert_eq!(previous_pass.lamports, 0);
+
+        let expected_hash = AccountsHash::compute_merkle_root(
+            account_maps
+                .iter()
+                .map(|a| (a.pubkey, a.hash))
+                .collect::<Vec<_>>(),
+            MERKLE_FANOUT,
+        );
+
+        assert_eq!(
+            (result.0, result.1),
+            (expected_hash, total_lamports_expected)
+        );
+    }
+
+    #[test]
+    fn test_accountsdb_multi_pass_rest_of_hash_calculation_partial_hashes() {
+        solana_logger::setup();
+
+        let mut account_maps: Vec<CalculateHashIntermediate> = Vec::new();
+
+        const TARGET_FANOUT_LEVEL: usize = 3;
+        let target_fanout = MERKLE_FANOUT.pow(TARGET_FANOUT_LEVEL as u32);
+        let mut total_lamports_expected = 0;
+        let plus1 = target_fanout + 1;
+        for i in 0..plus1 * 2 {
+            let lamports = (i + 1) as u64;
+            total_lamports_expected += lamports;
+            let key = Pubkey::new_unique();
+            let hash = Hash::new_unique();
+            let val = CalculateHashIntermediate::new(0, hash, lamports, Slot::default(), key);
+            account_maps.push(val);
+        }
+
+        let chunk = account_maps[0..plus1].to_vec();
+        let mut sorted = chunk.clone();
+        sorted.sort_by(AccountsHash::compare_two_hash_entries);
+
+        // first 4097 hashes (1 left over)
+        let result = AccountsHash::rest_of_hash_calculation(
+            vec![vec![chunk]],
+            &mut HashStats::default(),
+            false, // not last pass
+            PreviousPass::default(),
+        );
+
+        assert_eq!(result.0, Hash::default());
+        assert_eq!(result.1, 0);
+        let previous_pass = result.2;
+        let left_over_1 = sorted[plus1 - 1].hash;
+        assert_eq!(previous_pass.remaining_unhashed, vec![left_over_1]);
+        assert_eq!(previous_pass.reduced_hashes.len(), 1);
+        let expected_hash = AccountsHash::compute_merkle_root(
+            sorted[0..target_fanout]
+                .iter()
+                .map(|a| (a.pubkey, a.hash))
+                .collect::<Vec<_>>(),
+            MERKLE_FANOUT,
+        );
+        assert_eq!(previous_pass.reduced_hashes[0], vec![expected_hash]);
+        assert_eq!(
+            previous_pass.lamports,
+            account_maps[0..plus1]
+                .iter()
+                .map(|i| i.lamports)
+                .sum::<u64>()
+        );
+
+        let chunk = account_maps[plus1..plus1 * 2].to_vec();
+        let mut sorted2 = chunk.clone();
+        sorted2.sort_by(AccountsHash::compare_two_hash_entries);
+
+        let mut with_left_over = vec![left_over_1];
+        with_left_over.extend(sorted2[0..plus1 - 2].to_vec().into_iter().map(|i| i.hash));
+        let expected_hash2 = AccountsHash::compute_merkle_root(
+            with_left_over[0..target_fanout]
+                .iter()
+                .map(|a| (Pubkey::default(), *a))
+                .collect::<Vec<_>>(),
+            MERKLE_FANOUT,
+        );
+
+        // second 4097 hashes (2 left over)
+        let result = AccountsHash::rest_of_hash_calculation(
+            vec![vec![chunk]],
+            &mut HashStats::default(),
+            false, // not last pass
+            previous_pass,
+        );
+
+        assert_eq!(result.0, Hash::default());
+        assert_eq!(result.1, 0);
+        let previous_pass = result.2;
+        assert_eq!(
+            previous_pass.remaining_unhashed,
+            vec![sorted2[plus1 - 2].hash, sorted2[plus1 - 1].hash]
+        );
+        assert_eq!(previous_pass.reduced_hashes.len(), 2);
+        assert_eq!(
+            previous_pass.reduced_hashes,
+            vec![vec![expected_hash], vec![expected_hash2]]
+        );
+        assert_eq!(
+            previous_pass.lamports,
+            account_maps[0..plus1 * 2]
+                .iter()
+                .map(|i| i.lamports)
+                .sum::<u64>()
+        );
+
+        let result = AccountsHash::rest_of_hash_calculation(
+            vec![vec![vec![]]],
+            &mut HashStats::default(),
+            true,
+            previous_pass,
+        );
+
+        let previous_pass = result.2;
+        assert_eq!(previous_pass.remaining_unhashed.len(), 0);
+        assert_eq!(previous_pass.reduced_hashes.len(), 0);
+        assert_eq!(previous_pass.lamports, 0);
+
+        let mut combined = sorted;
+        combined.extend(sorted2);
+        let expected_hash = AccountsHash::compute_merkle_root(
+            combined
+                .iter()
+                .map(|a| (a.pubkey, a.hash))
+                .collect::<Vec<_>>(),
+            MERKLE_FANOUT,
+        );
+
+        assert_eq!(
+            (result.0, result.1),
+            (expected_hash, total_lamports_expected)
+        );
+    }
+
     #[test]
     fn test_accountsdb_de_dup_accounts_zero_chunks() {
         let (hashes, lamports) =
@@ -923,40 +1403,6 @@ pub mod tests {
         }
     }

-    #[test]
-    fn test_accountsdb_flatten_hashes_and_hash() {
-        solana_logger::setup();
-        const COUNT: usize = 4;
-        let hashes: Vec<_> = (0..COUNT)
-            .into_iter()
-            .map(|i| Hash::new(&[(i) as u8; 32]))
-            .collect();
-        let expected =
-            AccountsHash::compute_merkle_root_loop(hashes.clone(), MERKLE_FANOUT, |i| *i);
-
-        assert_eq!(
-            AccountsHash::flatten_hashes_and_hash(
-                vec![vec![hashes.clone()]],
-                MERKLE_FANOUT,
-                &mut HashStats::default()
-            ),
-            expected,
-        );
-        for in_first in 1..COUNT - 1 {
-            assert_eq!(
-                AccountsHash::flatten_hashes_and_hash(
-                    vec![vec![
-                        hashes.clone()[0..in_first].to_vec(),
-                        hashes.clone()[in_first..COUNT].to_vec()
-                    ]],
-                    MERKLE_FANOUT,
-                    &mut HashStats::default()
-                ),
-                expected
-            );
-        }
-    }
-
     #[test]
     fn test_sort_hash_intermediate() {
         solana_logger::setup();
@@ -1367,25 +1813,23 @@ pub mod tests {
         let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
         let result = AccountsHash::compute_merkle_root(temp, fanout);
         let reduced: Vec<_> = hashes.clone();
-        let result2 =
-            AccountsHash::compute_merkle_root_from_slices(hashes.len(), fanout, None, |start| {
-                &reduced[start..]
-            });
-        assert_eq!(result, result2, "len: {}", hashes.len());
-
-        let result2 =
-            AccountsHash::compute_merkle_root_from_slices(hashes.len(), fanout, Some(1), |start| {
-                &reduced[start..]
-            });
-        assert_eq!(result, result2, "len: {}", hashes.len());
-
-        let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
-        let result2 = AccountsHash::flatten_hashes_and_hash(
-            vec![reduced2],
+        let result2 = AccountsHash::compute_merkle_root_from_slices(
+            hashes.len(),
             fanout,
-            &mut HashStats::default(),
+            None,
+            |start| &reduced[start..],
+            None,
         );
-        assert_eq!(result, result2, "len: {}", hashes.len());
+        assert_eq!(result, result2.0, "len: {}", hashes.len());
+
+        let result2 = AccountsHash::compute_merkle_root_from_slices(
+            hashes.len(),
+            fanout,
+            Some(1),
+            |start| &reduced[start..],
+            None,
        );
+        assert_eq!(result, result2.0, "len: {}", hashes.len());

         let max = std::cmp::min(reduced.len(), fanout * 2);
         for left in 0..max {
@@ -1394,9 +1838,17 @@ pub mod tests {
                     vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
                     vec![reduced[right..].to_vec()],
                 ];
-                let result2 =
-                    AccountsHash::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
-                assert_eq!(result, result2);
+                let offsets = CumulativeOffsets::from_raw_2d(&src);
+
+                let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&src, start) };
+                let result2 = AccountsHash::compute_merkle_root_from_slices(
+                    offsets.total_count,
+                    fanout,
+                    None,
+                    get_slice,
+                    None,
+                );
+                assert_eq!(result, result2.0);
             }
         }
         result
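Taken together, `compute_merkle_root_from_slices` now runs in two modes selected by `specific_level_count`: reduce exactly that many levels and hand the partial hashes back, or reduce all the way to a single root. A self-contained model of that contract (illustrative; `reduce_level` is a stand-in combine step, not a real hash):

```rust
// Stand-in for one merkle level: combine each group of `fanout` values.
fn reduce_level(hashes: &[u64], fanout: usize) -> Vec<u64> {
    hashes.chunks(fanout).map(|c| c.iter().sum()).collect()
}

/// Some(n): hash exactly n levels and return the partials (root slot unused).
/// None: hash to a single root and return no partials.
fn model(
    mut hashes: Vec<u64>,
    fanout: usize,
    specific_level_count: Option<usize>,
) -> (u64, Vec<u64>) {
    match specific_level_count {
        Some(n) => {
            for _ in 0..n {
                hashes = reduce_level(&hashes, fanout);
            }
            (u64::default(), hashes) // caller keeps the partials
        }
        None => {
            while hashes.len() > 1 {
                hashes = reduce_level(&hashes, fanout);
            }
            (hashes.first().copied().unwrap_or_default(), vec![])
        }
    }
}
```

In the patch, non-final passes call the real function with `Some(TARGET_FANOUT_LEVEL)` and stash the partials in `PreviousPass::reduced_hashes`; the last pass calls it with `None` over those accumulated partials to produce the final accounts hash.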