diff --git a/program/src/avg.rs b/program/src/avg.rs new file mode 100644 index 0000000..4d0acce --- /dev/null +++ b/program/src/avg.rs @@ -0,0 +1,64 @@ +//! utility for calculating time average + +use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; +/// TimeCumulative is value * seconds, used to calculated average +#[derive(Clone, Default, Debug, BorshSerialize, BorshDeserialize, BorshSchema, PartialEq)] +pub struct TimeCumulative { + /// value * seconds + pub cumulative: u128, + + /// last updated_at (unix time) + pub updated_at: u64, +} + +impl TimeCumulative { + /// update accumulates the time * elapsed since last update + pub fn update(&mut self, val: u64, now: u64) { + assert!(now > self.updated_at, "can only update at a later time"); + + if self.updated_at == 0 { + self.cumulative = (val as u128) * (now as u128); + self.updated_at = now; + return; + } + + let elapsed = now - self.updated_at; + self.cumulative = self.cumulative.checked_add((val as u128) * (elapsed as u128)).unwrap(); + self.updated_at = now; + } + + /// sub calculates the time average value of two cumulatives + pub fn sub(&self, before: &Self) -> u64 { + // assert!(b.updated_at > self.updated_at, ""); + let elapsed = self.updated_at.checked_sub(before.updated_at).unwrap(); + let diff = self.cumulative.checked_sub(before.cumulative).unwrap(); + + (diff / (elapsed as u128)) as u64 + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use super::TimeCumulative; + + #[test] + fn test_time_cumulative_averaging() -> Result<()> { + let start = 1611133014; + let mut tc = TimeCumulative::default(); + tc.update(100, start); + + let mut tc2 = tc.clone(); + tc2.update(50, start+10); + + let mut tc3 = tc2.clone(); + tc3.update(10, start+20); + + assert_eq!(50, tc2.sub(&tc)); + assert_eq!(30, tc3.sub(&tc)); + + + Ok(()) + } +} \ No newline at end of file diff --git a/program/src/error.rs b/program/src/error.rs index def28a4..58072f6 100644 --- a/program/src/error.rs +++ b/program/src/error.rs @@ -49,6 +49,9 @@ pub enum Error { /// Max oralces reached #[error("Max oracles reached")] MaxOralcesReached, + /// No valid value submitted + #[error("No valid value submitted")] + NoValidValue, } impl PrintProgramError for Error { @@ -69,6 +72,7 @@ impl PrintProgramError for Error { Error::InsufficientWithdrawable => msg!("Insufficient withdrawable"), Error::AggregatorKeyNotMatch => msg!("Aggregator key not match"), Error::MaxOralcesReached => msg!("Max oracles reached"), + Error::NoValidValue => msg!("No valid value submitted"), } } } diff --git a/program/src/instruction.rs b/program/src/instruction.rs index 45ff95e..1d16976 100644 --- a/program/src/instruction.rs +++ b/program/src/instruction.rs @@ -204,9 +204,9 @@ pub fn submit( #[cfg(test)] mod tests { use hex; - use super::*; - use crate::borsh_utils; use anyhow::Result; + use borsh::{BorshSerialize, BorshDeserialize}; + use super::Instruction; #[test] fn test_serialize_bytes() -> Result<()> { diff --git a/program/src/lib.rs b/program/src/lib.rs index 5af5467..1cdc80d 100644 --- a/program/src/lib.rs +++ b/program/src/lib.rs @@ -10,12 +10,13 @@ pub mod error; pub mod instruction; pub mod processor; pub mod state; +pub mod avg; #[cfg(not(feature = "no-entrypoint"))] pub mod entrypoint; use error::Error; -use state::Aggregator; +use state::{Aggregator, Submission}; /// Get median value from the aggregator account pub fn get_median(aggregator_info: &AccountInfo) -> Result { @@ -24,17 +25,25 @@ pub fn get_median(aggregator_info: &AccountInfo) -> Result { return Err(Error::NotFoundAggregator.into()); } - let submissions = aggregator.submissions; + submissions_median(&aggregator.submissions) +} +/// return the median of oracle submissions +pub fn submissions_median(submissions: &[Submission]) -> Result { let mut values = vec![]; - // if the submission value is 0, maybe the oracle is not initialized - for s in &submissions { - if s.value != 0 { + // filter out uninitialized submissions + for s in submissions { + if s.time > 0 { values.push(s.value); } } + // error if no valid values + if values.is_empty() { + return Err(Error::NoValidValue.into()); + } + // get median value values.sort(); diff --git a/program/src/processor.rs b/program/src/processor.rs index adee877..17b2367 100644 --- a/program/src/processor.rs +++ b/program/src/processor.rs @@ -1,10 +1,6 @@ //! Program state processor -use crate::{ - error::Error, - instruction::{Instruction, PAYMENT_AMOUNT}, - state::{Aggregator, Oracle}, -}; +use crate::{error::Error, instruction::{Instruction, PAYMENT_AMOUNT}, state::{Aggregator, Oracle, Submission}}; use borsh::BorshDeserialize; use solana_program::{ @@ -271,8 +267,16 @@ impl Processor { return Err(Error::SubmissonCooling.into()); } + let now = clock.unix_timestamp; oracle.withdrawable += PAYMENT_AMOUNT; - oracle.next_submit_time = clock.unix_timestamp + aggregator.submit_interval as i64; + oracle.next_submit_time = now + aggregator.submit_interval as i64; + + let cumulative = &mut aggregator.cumulative; + if (now as u64) - cumulative.updated_at >= 1 { + if let Ok(median) = super::submissions_median(&aggregator.submissions) { + cumulative.update(median, now as u64); + } + } // update aggregator Aggregator::pack(aggregator, &mut aggregator_info.data.borrow_mut())?; @@ -350,7 +354,8 @@ impl Processor { #[cfg(test)] mod tests { use super::*; - use crate::{instruction::*, state::Submission}; + use crate::{avg::TimeCumulative, instruction::*, state::Submission}; + use hex::encode; use solana_program::instruction::Instruction; use solana_sdk::account::{ create_account, create_is_signer_account_infos, Account as SolanaAccount, @@ -376,7 +381,9 @@ mod tests { } fn clock_sysvar() -> SolanaAccount { - create_account(&Clock::default(), 42) + let mut clock = Clock::default(); + clock.unix_timestamp = 6666; + create_account(&clock, 42) } fn aggregator_minimum_balance() -> u64 { @@ -659,7 +666,7 @@ mod tests { ) .unwrap(); - // remove an unexist oracle + // remove an oracle that doesn't exist assert_eq!( Err(Error::NotFoundOracle.into()), do_process_instruction( @@ -667,7 +674,7 @@ mod tests { &program_id, &aggregator_key, &aggregator_owner_key, - &Pubkey::default() + &Pubkey::new(&vec![1u8;32]) ), vec![&mut aggregator_account, &mut aggregator_owner_account,] ) @@ -759,7 +766,7 @@ mod tests { &aggregator_key, &oracle_key, &oracle_owner_key, - 1, + 100, ), vec![ &mut aggregator_account, @@ -770,6 +777,20 @@ mod tests { ) .unwrap(); + let aggregator = Aggregator::try_from_slice(&aggregator_account.data).unwrap(); + assert_eq!(aggregator.cumulative.clone(), TimeCumulative { + cumulative: 666600, + updated_at: 6666 + }); + + assert_eq!(aggregator.submissions[0].time, 6666); + assert_eq!(aggregator.submissions[0].value, 100); + + + // println!("aggregator: {:#?}", &aggregator.cumulative); + // println!("aggregator: {:#?}", aggregator.submissions[0]); + // println!("aggregator data: {}", hex::encode(&aggregator_account.data)); + // submission cooling assert_eq!( Err(Error::SubmissonCooling.into()), diff --git a/program/src/state.rs b/program/src/state.rs index b50e5b6..5889c3f 100644 --- a/program/src/state.rs +++ b/program/src/state.rs @@ -1,7 +1,7 @@ //! State transition types use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; -use crate::instruction::MAX_ORACLES; +use crate::{avg::TimeCumulative, instruction::MAX_ORACLES}; use solana_program::{ clock::UnixTimestamp, @@ -26,6 +26,10 @@ pub struct Aggregator { pub is_initialized: bool, /// authority pub owner: [u8; 32], + + /// cumulative + pub cumulative: TimeCumulative, + /// submissions pub submissions: [Submission; MAX_ORACLES], } @@ -39,7 +43,7 @@ impl IsInitialized for Aggregator { impl Sealed for Aggregator {} impl Pack for Aggregator { // 48 is submission packed length - const LEN: usize = 86 + MAX_ORACLES * 48; + const LEN: usize = 110 + MAX_ORACLES * 48; fn pack_into_slice(&self, dst: &mut [u8]) { let data = self.try_to_vec().unwrap(); @@ -123,11 +127,6 @@ mod tests { #[test] fn test_get_packed_len() { - assert_eq!( - Aggregator::get_packed_len(), - borsh_utils::get_packed_len::() - ); - assert_eq!( Oracle::get_packed_len(), borsh_utils::get_packed_len::() @@ -137,6 +136,11 @@ mod tests { Submission::get_packed_len(), borsh_utils::get_packed_len::() ); + + assert_eq!( + Aggregator::get_packed_len(), + borsh_utils::get_packed_len::() + ); } #[test]