From af7b7f143628d9dfdf75ea464be3ebb4c7a3653b Mon Sep 17 00:00:00 2001 From: HaoranYi Date: Fri, 16 Jun 2023 08:40:23 -0500 Subject: [PATCH] Refactor reward block calculation fn (#32167) refactor reward block calculation fn Co-authored-by: HaoranYi --- runtime/src/bank.rs | 11 +++++------ runtime/src/bank/tests.rs | 26 +++++++++++--------------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/runtime/src/bank.rs b/runtime/src/bank.rs index f43140634a..bfc05fad91 100644 --- a/runtime/src/bank.rs +++ b/runtime/src/bank.rs @@ -1528,7 +1528,8 @@ impl Bank { #[allow(dead_code)] /// Calculate the number of blocks required to distribute rewards to all stake accounts. - fn get_reward_distribution_num_blocks(&self, total_stake_accounts: usize) -> u64 { + fn get_reward_distribution_num_blocks(&self, rewards: &StakeRewards) -> u64 { + let total_stake_accounts = rewards.len(); if self.epoch_schedule.warmup && self.epoch < self.first_normal_epoch() { 1 } else { @@ -1544,9 +1545,8 @@ impl Bank { #[allow(dead_code)] /// Return the total number of blocks in reward interval (including both calculation and crediting). - fn get_reward_total_num_blocks(&self, total_stake_accounts: usize) -> u64 { - self.get_reward_calculation_num_blocks() - + self.get_reward_distribution_num_blocks(total_stake_accounts) + fn get_reward_total_num_blocks(&self, rewards: &StakeRewards) -> u64 { + self.get_reward_calculation_num_blocks() + self.get_reward_distribution_num_blocks(rewards) } #[allow(dead_code)] @@ -2642,8 +2642,7 @@ impl Bank { ) .unwrap_or_default(); - let num_partitions = - self.get_reward_distribution_num_blocks(stake_rewards.stake_rewards.len()); + let num_partitions = self.get_reward_distribution_num_blocks(&stake_rewards.stake_rewards); let stake_rewards_by_partition = hash_rewards_into_partitions( std::mem::take(&mut stake_rewards.stake_rewards), &self.parent_hash(), diff --git a/runtime/src/bank/tests.rs b/runtime/src/bank/tests.rs index 131c458f04..ed6940f7b5 100644 --- a/runtime/src/bank/tests.rs +++ b/runtime/src/bank/tests.rs @@ -13196,14 +13196,11 @@ fn test_get_reward_distribution_num_blocks_normal() { .map(|_| StakeReward::new_random()) .collect::>(); - assert_eq!( - bank.get_reward_distribution_num_blocks(stake_rewards.len()), - 2 - ); + assert_eq!(bank.get_reward_distribution_num_blocks(&stake_rewards), 2); assert_eq!(bank.get_reward_calculation_num_blocks(), 1); assert_eq!( - bank.get_reward_total_num_blocks(stake_rewards.len()), - bank.get_reward_distribution_num_blocks(stake_rewards.len()) + bank.get_reward_total_num_blocks(&stake_rewards), + bank.get_reward_distribution_num_blocks(&stake_rewards) + bank.get_reward_calculation_num_blocks(), ); } @@ -13224,14 +13221,11 @@ fn test_get_reward_distribution_num_blocks_cap() { .map(|_| StakeReward::new_random()) .collect::>(); - assert_eq!( - bank.get_reward_distribution_num_blocks(stake_rewards.len()), - 1 - ); + assert_eq!(bank.get_reward_distribution_num_blocks(&stake_rewards), 1); assert_eq!(bank.get_reward_calculation_num_blocks(), 1); assert_eq!( - bank.get_reward_total_num_blocks(stake_rewards.len()), - bank.get_reward_distribution_num_blocks(stake_rewards.len()) + bank.get_reward_total_num_blocks(&stake_rewards), + bank.get_reward_distribution_num_blocks(&stake_rewards) + bank.get_reward_calculation_num_blocks(), ); } @@ -13243,11 +13237,13 @@ fn test_get_reward_distribution_num_blocks_warmup() { let (genesis_config, _mint_keypair) = create_genesis_config(1_000_000 * LAMPORTS_PER_SOL); let bank = Bank::new_for_tests(&genesis_config); - assert_eq!(bank.get_reward_distribution_num_blocks(0), 1); + let rewards = vec![]; + assert_eq!(bank.get_reward_distribution_num_blocks(&rewards), 1); assert_eq!(bank.get_reward_calculation_num_blocks(), 1); assert_eq!( - bank.get_reward_total_num_blocks(0), - bank.get_reward_distribution_num_blocks(0) + bank.get_reward_calculation_num_blocks(), + bank.get_reward_total_num_blocks(&rewards), + bank.get_reward_distribution_num_blocks(&rewards) + + bank.get_reward_calculation_num_blocks(), ); }