refactor(hermes): state->price_feed_metadata downcasting

This commit is contained in:
Reisen 2024-04-10 14:03:59 +00:00 committed by Reisen
parent d1c5d93c8e
commit ce4019b63f
6 changed files with 133 additions and 74 deletions

View File

@ -426,7 +426,7 @@ where
pub async fn is_ready(state: &State) -> bool {
let metadata = state.aggregate_state.read().await;
let price_feeds_metadata = state.price_feeds_metadata.read().await;
let price_feeds_metadata = state.price_feed_meta.data.read().await;
let has_completed_recently = match metadata.latest_completed_update_at.as_ref() {
Some(latest_completed_update_time) => {
@ -456,7 +456,7 @@ mod test {
super::*,
crate::{
api::types::PriceFeedMetadata,
price_feeds_metadata::store_price_feeds_metadata,
price_feeds_metadata::PriceFeedMeta,
state::test::setup_state,
},
futures::future::join_all,
@ -809,15 +809,13 @@ mod test {
// Add a dummy price feeds metadata
store_price_feeds_metadata(
&state,
&[PriceFeedMetadata {
state
.store_price_feeds_metadata(&[PriceFeedMetadata {
id: PriceIdentifier::new([100; 32]),
attributes: Default::default(),
}],
)
.await
.unwrap();
}])
.await
.unwrap();
// Check the state is ready
assert!(is_ready(&state).await);

View File

@ -26,15 +26,27 @@ mod rest;
pub mod types;
mod ws;
#[derive(Clone)]
pub struct ApiState {
pub state: Arc<State>,
pub struct ApiState<S = State> {
pub state: Arc<S>,
pub ws: Arc<ws::WsState>,
pub metrics: Arc<metrics_middleware::Metrics>,
pub update_tx: Sender<AggregationEvent>,
}
impl ApiState {
/// Manually implement `Clone` as the derive macro will try and slap `Clone` on
/// `State` which should not be Clone.
impl<S> Clone for ApiState<S> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
ws: self.ws.clone(),
metrics: self.metrics.clone(),
update_tx: self.update_tx.clone(),
}
}
}
impl ApiState<State> {
pub fn new(
state: Arc<State>,
ws_whitelist: Vec<IpNet>,

View File

@ -6,8 +6,9 @@ use {
AssetType,
PriceFeedMetadata,
},
ApiState,
},
price_feeds_metadata::get_price_feeds_metadata,
price_feeds_metadata::PriceFeedMeta,
},
anyhow::Result,
axum::{
@ -46,19 +47,23 @@ pub struct PriceFeedsMetadataQueryParams {
PriceFeedsMetadataQueryParams
)
)]
pub async fn price_feeds_metadata(
State(state): State<crate::api::ApiState>,
pub async fn price_feeds_metadata<S>(
State(state): State<ApiState<S>>,
QsQuery(params): QsQuery<PriceFeedsMetadataQueryParams>,
) -> Result<Json<Vec<PriceFeedMetadata>>, RestError> {
let price_feeds_metadata =
get_price_feeds_metadata(&state.state, params.query, params.asset_type)
.await
.map_err(|e| {
tracing::warn!("RPC connection error: {}", e);
RestError::RpcConnectionError {
message: format!("RPC connection error: {}", e),
}
})?;
) -> Result<Json<Vec<PriceFeedMetadata>>, RestError>
where
S: PriceFeedMeta,
{
let state = &state.state;
let price_feeds_metadata = state
.get_price_feeds_metadata(params.query, params.asset_type)
.await
.map_err(|e| {
tracing::warn!("RPC connection error: {}", e);
RestError::RpcConnectionError {
message: format!("RPC connection error: {}", e),
}
})?;
Ok(Json(price_feeds_metadata))
}

View File

@ -17,7 +17,7 @@ use {
GuardianSetData,
},
price_feeds_metadata::{
store_price_feeds_metadata,
PriceFeedMeta,
DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
},
state::State,
@ -353,13 +353,18 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
}
pub async fn fetch_and_store_price_feeds_metadata(
state: &State,
pub async fn fetch_and_store_price_feeds_metadata<S>(
state: &S,
mapping_address: &Pubkey,
rpc_client: &RpcClient,
) -> Result<Vec<PriceFeedMetadata>> {
) -> Result<Vec<PriceFeedMetadata>>
where
S: PriceFeedMeta,
{
let price_feeds_metadata = fetch_price_feeds_metadata(mapping_address, rpc_client).await?;
store_price_feeds_metadata(state, &price_feeds_metadata).await?;
state
.store_price_feeds_metadata(&price_feeds_metadata)
.await?;
Ok(price_feeds_metadata)
}

View File

@ -7,49 +7,88 @@ use {
state::State,
},
anyhow::Result,
tokio::sync::RwLock,
};
pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600;
pub async fn retrieve_price_feeds_metadata(state: &State) -> Result<Vec<PriceFeedMetadata>> {
let price_feeds_metadata = state.price_feeds_metadata.read().await;
Ok(price_feeds_metadata.clone())
pub struct PriceFeedMetaState {
pub data: RwLock<Vec<PriceFeedMetadata>>,
}
pub async fn store_price_feeds_metadata(
state: &State,
price_feeds_metadata: &[PriceFeedMetadata],
) -> Result<()> {
let mut price_feeds_metadata_write_guard = state.price_feeds_metadata.write().await;
*price_feeds_metadata_write_guard = price_feeds_metadata.to_vec();
Ok(())
impl PriceFeedMetaState {
pub fn new() -> Self {
Self {
data: RwLock::new(Vec::new()),
}
}
}
/// Allow downcasting State into CacheState for functions that depend on the `Cache` service.
impl<'a> From<&'a State> for &'a PriceFeedMetaState {
fn from(state: &'a State) -> &'a PriceFeedMetaState {
&state.price_feed_meta
}
}
pub async fn get_price_feeds_metadata(
state: &State,
query: Option<String>,
asset_type: Option<AssetType>,
) -> Result<Vec<PriceFeedMetadata>> {
let mut price_feeds_metadata = retrieve_price_feeds_metadata(state).await?;
pub trait PriceFeedMeta {
async fn retrieve_price_feeds_metadata(&self) -> Result<Vec<PriceFeedMetadata>>;
async fn store_price_feeds_metadata(
&self,
price_feeds_metadata: &[PriceFeedMetadata],
) -> Result<()>;
async fn get_price_feeds_metadata(
&self,
query: Option<String>,
asset_type: Option<AssetType>,
) -> Result<Vec<PriceFeedMetadata>>;
}
// Filter by query if provided
if let Some(query_str) = &query {
price_feeds_metadata.retain(|feed| {
feed.attributes.get("symbol").map_or(false, |symbol| {
symbol.to_lowercase().contains(&query_str.to_lowercase())
})
});
impl<T> PriceFeedMeta for T
where
for<'a> &'a T: Into<&'a PriceFeedMetaState>,
T: Sync,
{
async fn retrieve_price_feeds_metadata(&self) -> Result<Vec<PriceFeedMetadata>> {
let price_feeds_metadata = self.into().data.read().await;
Ok(price_feeds_metadata.clone())
}
// Filter by asset_type if provided
if let Some(asset_type) = &asset_type {
price_feeds_metadata.retain(|feed| {
feed.attributes.get("asset_type").map_or(false, |type_str| {
type_str.to_lowercase() == asset_type.to_string().to_lowercase()
})
});
async fn store_price_feeds_metadata(
&self,
price_feeds_metadata: &[PriceFeedMetadata],
) -> Result<()> {
let mut price_feeds_metadata_write_guard = self.into().data.write().await;
*price_feeds_metadata_write_guard = price_feeds_metadata.to_vec();
Ok(())
}
Ok(price_feeds_metadata)
async fn get_price_feeds_metadata(
&self,
query: Option<String>,
asset_type: Option<AssetType>,
) -> Result<Vec<PriceFeedMetadata>> {
let mut price_feeds_metadata = self.retrieve_price_feeds_metadata().await?;
// Filter by query if provided
if let Some(query_str) = &query {
price_feeds_metadata.retain(|feed| {
feed.attributes.get("symbol").map_or(false, |symbol| {
symbol.to_lowercase().contains(&query_str.to_lowercase())
})
});
}
// Filter by asset_type if provided
if let Some(asset_type) = &asset_type {
price_feeds_metadata.retain(|feed| {
feed.attributes.get("asset_type").map_or(false, |type_str| {
type_str.to_lowercase() == asset_type.to_string().to_lowercase()
})
});
}
Ok(price_feeds_metadata)
}
}

View File

@ -10,8 +10,8 @@ use {
AggregateState,
AggregationEvent,
},
api::types::PriceFeedMetadata,
network::wormhole::GuardianSet,
price_feeds_metadata::PriceFeedMetaState,
},
prometheus_client::registry::Registry,
reqwest::Url,
@ -38,6 +38,9 @@ pub struct State {
/// State for the `Benchmarks` service for looking up historical updates.
pub benchmarks: BenchmarksState,
/// State for the `PriceFeedMeta` service for looking up metadata related to Pyth price feeds.
pub price_feed_meta: PriceFeedMetaState,
/// Sequence numbers of lately observed Vaas. Store uses this set
/// to ignore the previously observed Vaas as a performance boost.
pub observed_vaa_seqs: RwLock<BTreeSet<u64>>,
@ -53,9 +56,6 @@ pub struct State {
/// Metrics registry
pub metrics_registry: RwLock<Registry>,
/// Price feeds metadata
pub price_feeds_metadata: RwLock<Vec<PriceFeedMetadata>>,
}
impl State {
@ -66,14 +66,14 @@ impl State {
) -> Arc<Self> {
let mut metrics_registry = Registry::default();
Arc::new(Self {
cache: CacheState::new(cache_size),
benchmarks: BenchmarksState::new(benchmarks_endpoint),
observed_vaa_seqs: RwLock::new(Default::default()),
guardian_set: RwLock::new(Default::default()),
api_update_tx: update_tx,
aggregate_state: RwLock::new(AggregateState::new(&mut metrics_registry)),
metrics_registry: RwLock::new(metrics_registry),
price_feeds_metadata: RwLock::new(Default::default()),
cache: CacheState::new(cache_size),
benchmarks: BenchmarksState::new(benchmarks_endpoint),
price_feed_meta: PriceFeedMetaState::new(),
observed_vaa_seqs: RwLock::new(Default::default()),
guardian_set: RwLock::new(Default::default()),
api_update_tx: update_tx,
aggregate_state: RwLock::new(AggregateState::new(&mut metrics_registry)),
metrics_registry: RwLock::new(metrics_registry),
})
}
}