use EMA in place of Welford
This commit is contained in:
parent
a25ac1c988
commit
6587dbfa47
|
@ -270,9 +270,8 @@ mod tests {
|
|||
let accumulated_us: u64 = 2000;
|
||||
let accumulated_units: u64 = 200;
|
||||
let count: u32 = 10;
|
||||
// to expect new cost = (mean + 2 * std) of [10, 20] = 25, where
|
||||
// mean = (10+20)/2 = 15; std=5
|
||||
expected_cost = 25;
|
||||
// to expect new cost = (ema + 2 * sqrt(ema_var)) of [10, 20] = ceil(14 + 2 * sqrt(24)) = 24
|
||||
expected_cost = 24;
|
||||
|
||||
execute_timings.details.per_program_timings.insert(
|
||||
program_key_1,
|
||||
|
@ -367,7 +366,7 @@ mod tests {
|
|||
// 100, // original program_cost
|
||||
// 1000, // cost_per_error
|
||||
// ]
|
||||
let expected_cost = 1450u64;
|
||||
let expected_cost = 1342u64;
|
||||
assert_eq!(1, updated_program_costs.len());
|
||||
assert_eq!(
|
||||
Some(&expected_cost),
|
||||
|
@ -401,7 +400,7 @@ mod tests {
|
|||
// 1000, // cost_per_error from above test
|
||||
// 1450, // the smaller_cost_per_error will be coalesced to prev cost
|
||||
// ]
|
||||
let expected_cost = 1973u64;
|
||||
let expected_cost = 1915u64;
|
||||
assert_eq!(1, updated_program_costs.len());
|
||||
assert_eq!(
|
||||
Some(&expected_cost),
|
||||
|
|
|
@ -498,8 +498,8 @@ mod tests {
|
|||
let key1 = Pubkey::new_unique();
|
||||
let cost1 = 100;
|
||||
let cost2 = 200;
|
||||
// updated_cost = (mean + 2*std) = 150 + 2 * 50 = 250
|
||||
let updated_cost = 250;
|
||||
// updated_cost = (ema + 2 * sqrt(ema_var)) = ceil(140 + 2 * sqrt(2400)) = 238
|
||||
let updated_cost = 238;
|
||||
|
||||
let mut cost_model = CostModel::default();
|
||||
|
||||
|
|
|
@ -15,11 +15,15 @@ const OCCURRENCES_WEIGHT: i64 = 100;
|
|||
|
||||
const DEFAULT_CAPACITY: usize = 1024;
|
||||
|
||||
// The coefficient represents the degree of weighting decrease in EMA,
|
||||
// a constant smoothing factor (alpha) between 0 and 1. A higher alpha
|
||||
// discounts older observations faster.
|
||||
const COEFFICIENT: f64 = 0.4;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct AggregatedVarianceStats {
|
||||
count: u64,
|
||||
mean: f64,
|
||||
squared_mean_distance: f64,
|
||||
ema: f64,
|
||||
ema_var: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -57,18 +61,11 @@ impl ExecuteCostTable {
|
|||
|
||||
// returns None if program doesn't exist in table. In this case,
|
||||
// it is advised to call `get_default()` for the default program cost.
|
||||
// using Welford's Algorithm to calculate mean and std:
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
// Program cost is estimated as 2 standard deviations above mean, eg
|
||||
// cost = (mean + 2 * std)
|
||||
pub fn get_cost(&self, key: &Pubkey) -> Option<u64> {
|
||||
let aggregated = self.table.get(key)?;
|
||||
if aggregated.count < 1 {
|
||||
None
|
||||
} else {
|
||||
let variance = aggregated.squared_mean_distance / aggregated.count as f64;
|
||||
Some((aggregated.mean + 2.0 * variance.sqrt()).ceil() as u64)
|
||||
}
|
||||
Some((aggregated.ema + 2.0 * aggregated.ema_var.sqrt()).ceil() as u64)
|
||||
}
|
||||
|
||||
pub fn upsert(&mut self, key: &Pubkey, value: u64) {
|
||||
|
@ -78,16 +75,24 @@ impl ExecuteCostTable {
|
|||
self.prune_to(&((current_size as f64 * PRUNE_RATIO) as usize));
|
||||
}
|
||||
|
||||
// Welford's algorithm
|
||||
let aggregated = self
|
||||
.table
|
||||
.entry(*key)
|
||||
.or_insert_with(AggregatedVarianceStats::default);
|
||||
aggregated.count += 1;
|
||||
let delta = value as f64 - aggregated.mean;
|
||||
aggregated.mean += delta / aggregated.count as f64;
|
||||
let delta_2 = value as f64 - aggregated.mean;
|
||||
aggregated.squared_mean_distance += delta * delta_2;
|
||||
// exponential moving average algorithm
|
||||
// https://en.wikipedia.org/wiki/Moving_average#Exponentially_weighted_moving_variance_and_standard_deviation
|
||||
if self.table.contains_key(key) {
|
||||
let aggregated = self.table.get_mut(key).unwrap();
|
||||
let theta = value as f64 - aggregated.ema;
|
||||
aggregated.ema += theta * COEFFICIENT;
|
||||
aggregated.ema_var =
|
||||
(1.0 - COEFFICIENT) * (aggregated.ema_var + COEFFICIENT * theta * theta)
|
||||
} else {
|
||||
// the starting values
|
||||
self.table.insert(
|
||||
*key,
|
||||
AggregatedVarianceStats {
|
||||
ema: value as f64,
|
||||
ema_var: 0.0,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let (count, timestamp) = self
|
||||
.occurrences
|
||||
|
@ -231,7 +236,7 @@ mod tests {
|
|||
testee.upsert(&key1, cost2);
|
||||
assert_eq!(2, testee.get_count());
|
||||
// expected key1 cost = (ema + 2*sqrt(ema_var)) = ceil(104 + 2*sqrt(24)) = 114; was (mean + 2*std) = (105 + 2*5) = 115 under Welford
|
||||
let expected_cost = 115;
|
||||
let expected_cost = 114;
|
||||
assert_eq!(expected_cost, testee.get_cost(&key1).unwrap());
|
||||
assert_eq!(cost2, testee.get_cost(&key2).unwrap());
|
||||
}
|
||||
|
@ -276,7 +281,7 @@ mod tests {
|
|||
assert_eq!(2, testee.get_count());
|
||||
assert!(testee.get_cost(&key1).is_none());
|
||||
// expected key2 cost = (ema + 2*sqrt(ema_var)) = ceil(106 + 2*sqrt(24)) = 116; was (mean + 2*std) = (105 + 2*5) = 115 under Welford
|
||||
let expected_cost_2 = 115;
|
||||
let expected_cost_2 = 116;
|
||||
assert_eq!(expected_cost_2, testee.get_cost(&key2).unwrap());
|
||||
assert!(testee.get_cost(&key3).is_none());
|
||||
assert_eq!(cost4, testee.get_cost(&key4).unwrap());
|
||||
|
|
Loading…
Reference in New Issue