From 6587dbfa4731687d3a7e6bc9a6718d6160c13973 Mon Sep 17 00:00:00 2001 From: Tao Zhu Date: Wed, 2 Feb 2022 19:56:16 -0600 Subject: [PATCH] use EMA in place of Welford --- core/src/cost_update_service.rs | 9 +++--- runtime/src/cost_model.rs | 4 +-- runtime/src/execute_cost_table.rs | 51 +++++++++++++++++-------------- 3 files changed, 34 insertions(+), 30 deletions(-) diff --git a/core/src/cost_update_service.rs b/core/src/cost_update_service.rs index 900cce52cb..a7f1571aca 100644 --- a/core/src/cost_update_service.rs +++ b/core/src/cost_update_service.rs @@ -270,9 +270,8 @@ mod tests { let accumulated_us: u64 = 2000; let accumulated_units: u64 = 200; let count: u32 = 10; - // to expect new cost = (mean + 2 * std) of [10, 20] = 25, where - // mean = (10+20)/2 = 15; std=5 - expected_cost = 25; + // to expect new cost = (mean + 2 * std) + expected_cost = 24; execute_timings.details.per_program_timings.insert( program_key_1, @@ -367,7 +366,7 @@ mod tests { // 100, // original program_cost // 1000, // cost_per_error // ] - let expected_cost = 1450u64; + let expected_cost = 1342u64; assert_eq!(1, updated_program_costs.len()); assert_eq!( Some(&expected_cost), @@ -401,7 +400,7 @@ mod tests { // 1000, // cost_per_error from above test // 1450, // the smaller_cost_per_error will be coalesced to prev cost // ] - let expected_cost = 1973u64; + let expected_cost = 1915u64; assert_eq!(1, updated_program_costs.len()); assert_eq!( Some(&expected_cost), diff --git a/runtime/src/cost_model.rs b/runtime/src/cost_model.rs index b500480294..534dae0470 100644 --- a/runtime/src/cost_model.rs +++ b/runtime/src/cost_model.rs @@ -498,8 +498,8 @@ mod tests { let key1 = Pubkey::new_unique(); let cost1 = 100; let cost2 = 200; - // updated_cost = (mean + 2*std) = 150 + 2 * 50 = 250 - let updated_cost = 250; + // updated_cost = (mean + 2*std) + let updated_cost = 238; let mut cost_model = CostModel::default(); diff --git a/runtime/src/execute_cost_table.rs b/runtime/src/execute_cost_table.rs index 3b01f0fedb..c779164609 100644 --- a/runtime/src/execute_cost_table.rs +++ b/runtime/src/execute_cost_table.rs @@ -15,11 +15,15 @@ const OCCURRENCES_WEIGHT: i64 = 100; const DEFAULT_CAPACITY: usize = 1024; +// The coefficient represents the degree of weighting decrease in EMA, +// a constant smoothing factor between 0 and 1. A higher alpha +// discounts older observations faster. +const COEFFICIENT: f64 = 0.4; + #[derive(Debug, Default)] struct AggregatedVarianceStats { - count: u64, - mean: f64, - squared_mean_distance: f64, + ema: f64, + ema_var: f64, } #[derive(Debug)] @@ -57,18 +61,11 @@ impl ExecuteCostTable { // returns None if program doesn't exist in table. In this case, // it is advised to call `get_default()` for default program costdefault/ - // using Welford's Algorithm to calculate mean and std: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm // Program cost is estimated as 2 standard deviations above mean, eg // cost = (mean + 2 * std) pub fn get_cost(&self, key: &Pubkey) -> Option { let aggregated = self.table.get(key)?; - if aggregated.count < 1 { - None - } else { - let variance = aggregated.squared_mean_distance / aggregated.count as f64; - Some((aggregated.mean + 2.0 * variance.sqrt()).ceil() as u64) - } + Some((aggregated.ema + 2.0 * aggregated.ema_var.sqrt()).ceil() as u64) } pub fn upsert(&mut self, key: &Pubkey, value: u64) { @@ -78,16 +75,24 @@ impl ExecuteCostTable { self.prune_to(&((current_size as f64 * PRUNE_RATIO) as usize)); } - // Welford's algorithm - let aggregated = self - .table - .entry(*key) - .or_insert_with(AggregatedVarianceStats::default); - aggregated.count += 1; - let delta = value as f64 - aggregated.mean; - aggregated.mean += delta / aggregated.count as f64; - let delta_2 = value as f64 - aggregated.mean; - aggregated.squared_mean_distance += delta * delta_2; + // exponential moving average algorithm + // https://en.wikipedia.org/wiki/Moving_average#Exponentially_weighted_moving_variance_and_standard_deviation + if self.table.contains_key(key) { + let aggregated = self.table.get_mut(key).unwrap(); + let theta = value as f64 - aggregated.ema; + aggregated.ema += theta * COEFFICIENT; + aggregated.ema_var = + (1.0 - COEFFICIENT) * (aggregated.ema_var + COEFFICIENT * theta * theta) + } else { + // the starting values + self.table.insert( + *key, + AggregatedVarianceStats { + ema: value as f64, + ema_var: 0.0, + }, + ); + } let (count, timestamp) = self .occurrences @@ -231,7 +236,7 @@ mod tests { testee.upsert(&key1, cost2); assert_eq!(2, testee.get_count()); // expected key1 cost = (mean + 2*std) = (105 + 2*5) = 115 - let expected_cost = 115; + let expected_cost = 114; assert_eq!(expected_cost, testee.get_cost(&key1).unwrap()); assert_eq!(cost2, testee.get_cost(&key2).unwrap()); } @@ -276,7 +281,7 @@ mod tests { assert_eq!(2, testee.get_count()); assert!(testee.get_cost(&key1).is_none()); // expected key2 cost = (mean + 2*std) = (105 + 2*5) = 115 - let expected_cost_2 = 115; + let expected_cost_2 = 116; assert_eq!(expected_cost_2, testee.get_cost(&key2).unwrap()); assert!(testee.get_cost(&key3).is_none()); assert_eq!(cost4, testee.get_cost(&key4).unwrap());