use EMA in place of Welford

This commit is contained in:
Tao Zhu 2022-02-02 19:56:16 -06:00 committed by Tao Zhu
parent a25ac1c988
commit 6587dbfa47
3 changed files with 34 additions and 30 deletions

View File

@ -270,9 +270,8 @@ mod tests {
let accumulated_us: u64 = 2000;
let accumulated_units: u64 = 200;
let count: u32 = 10;
// to expect new cost = (mean + 2 * std) of [10, 20] = 25, where
// mean = (10+20)/2 = 15; std=5
expected_cost = 25;
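// to expect new cost = (ema + 2 * std) of [10, 20] = ceil(14 + 2 * sqrt(24)) = 24, where
// ema = 10 + 0.4 * (20 - 10) = 14; ema_var = (1 - 0.4) * (0 + 0.4 * 10^2) = 24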
expected_cost = 24;
execute_timings.details.per_program_timings.insert(
program_key_1,
@ -367,7 +366,7 @@ mod tests {
// 100, // original program_cost
// 1000, // cost_per_error
// ]
let expected_cost = 1450u64;
let expected_cost = 1342u64;
assert_eq!(1, updated_program_costs.len());
assert_eq!(
Some(&expected_cost),
@ -401,7 +400,7 @@ mod tests {
// 1000, // cost_per_error from above test
// 1342, // the smaller_cost_per_error will be coalesced to prev cost
// ]
let expected_cost = 1973u64;
let expected_cost = 1915u64;
assert_eq!(1, updated_program_costs.len());
assert_eq!(
Some(&expected_cost),

View File

@ -498,8 +498,8 @@ mod tests {
let key1 = Pubkey::new_unique();
let cost1 = 100;
let cost2 = 200;
// updated_cost = (mean + 2*std) = 150 + 2 * 50 = 250
let updated_cost = 250;
// updated_cost = (ema + 2*std) = 140 + 2 * sqrt(2400) ~= 237.98, rounded up to 238
let updated_cost = 238;
let mut cost_model = CostModel::default();

View File

@ -15,11 +15,15 @@ const OCCURRENCES_WEIGHT: i64 = 100;
const DEFAULT_CAPACITY: usize = 1024;
// The coefficient (alpha) represents the degree of weighting decrease in EMA,
// a constant smoothing factor between 0 and 1. A higher alpha
// discounts older observations faster.
const COEFFICIENT: f64 = 0.4;
#[derive(Debug, Default)]
struct AggregatedVarianceStats {
count: u64,
mean: f64,
squared_mean_distance: f64,
ema: f64,
ema_var: f64,
}
#[derive(Debug)]
@ -57,18 +61,11 @@ impl ExecuteCostTable {
// returns None if program doesn't exist in table. In this case,
// it is advised to call `get_default()` for default program cost.
// using Welford's Algorithm to calculate mean and std:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// Program cost is estimated as 2 standard deviations above the
// exponential moving average, e.g. cost = (ema + 2 * std)
pub fn get_cost(&self, key: &Pubkey) -> Option<u64> {
let aggregated = self.table.get(key)?;
if aggregated.count < 1 {
None
} else {
let variance = aggregated.squared_mean_distance / aggregated.count as f64;
Some((aggregated.mean + 2.0 * variance.sqrt()).ceil() as u64)
}
Some((aggregated.ema + 2.0 * aggregated.ema_var.sqrt()).ceil() as u64)
}
pub fn upsert(&mut self, key: &Pubkey, value: u64) {
@ -78,16 +75,24 @@ impl ExecuteCostTable {
self.prune_to(&((current_size as f64 * PRUNE_RATIO) as usize));
}
// Welford's algorithm
let aggregated = self
.table
.entry(*key)
.or_insert_with(AggregatedVarianceStats::default);
aggregated.count += 1;
let delta = value as f64 - aggregated.mean;
aggregated.mean += delta / aggregated.count as f64;
let delta_2 = value as f64 - aggregated.mean;
aggregated.squared_mean_distance += delta * delta_2;
// exponential moving average algorithm
// https://en.wikipedia.org/wiki/Moving_average#Exponentially_weighted_moving_variance_and_standard_deviation
if self.table.contains_key(key) {
let aggregated = self.table.get_mut(key).unwrap();
let theta = value as f64 - aggregated.ema;
aggregated.ema += theta * COEFFICIENT;
aggregated.ema_var =
(1.0 - COEFFICIENT) * (aggregated.ema_var + COEFFICIENT * theta * theta)
} else {
// the starting values
self.table.insert(
*key,
AggregatedVarianceStats {
ema: value as f64,
ema_var: 0.0,
},
);
}
let (count, timestamp) = self
.occurrences
@ -231,7 +236,7 @@ mod tests {
testee.upsert(&key1, cost2);
assert_eq!(2, testee.get_count());
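// expected key1 cost = (ema + 2*std) = (104 + 2*sqrt(24)) ~= 113.8, rounded up to 114
// (assumes key1 was upserted with 100 and then 110 -- confirm against the test setup)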
let expected_cost = 115;
let expected_cost = 114;
assert_eq!(expected_cost, testee.get_cost(&key1).unwrap());
assert_eq!(cost2, testee.get_cost(&key2).unwrap());
}
@ -276,7 +281,7 @@ mod tests {
assert_eq!(2, testee.get_count());
assert!(testee.get_cost(&key1).is_none());
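// expected key2 cost = (ema + 2*std) = (106 + 2*sqrt(24)) ~= 115.8, rounded up to 116
// (assumes key2 was upserted with 110 and then 100 -- confirm against the test setup)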
let expected_cost_2 = 115;
let expected_cost_2 = 116;
assert_eq!(expected_cost_2, testee.get_cost(&key2).unwrap());
assert!(testee.get_cost(&key3).is_none());
assert_eq!(cost4, testee.get_cost(&key4).unwrap());