feat(hermes): add sse endpoint (#1425)

* add initial sse code

* fix typo

* add more error handling

* fix formatting

* revert import format

* add error handling for nonexistent price feeds in the middle of sub

* refactor

* format

* add comment

* Update hermes/src/api/sse.rs

Co-authored-by: Reisen <Reisen@users.noreply.github.com>

* refactor

* bump

---------

Co-authored-by: Reisen <Reisen@users.noreply.github.com>
This commit is contained in:
Daniel Chew 2024-04-11 11:04:27 +09:00 committed by GitHub
parent e1f9783062
commit 3c5a913a80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 189 additions and 6 deletions

8
hermes/Cargo.lock generated
View File

@ -1796,7 +1796,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermes"
version = "0.5.3"
version = "0.5.4"
dependencies = [
"anyhow",
"async-trait",
@ -1839,6 +1839,7 @@ dependencies = [
"solana-sdk",
"strum",
"tokio",
"tokio-stream",
"tonic",
"tonic-build",
"tower-http",
@ -5188,9 +5189,9 @@ dependencies = [
[[package]]
name = "termcolor"
version = "1.4.1"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755"
dependencies = [
"winapi-util",
]
@ -5385,6 +5386,7 @@ dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]

View File

@ -1,6 +1,6 @@
[package]
name = "hermes"
version = "0.5.3"
version = "0.5.4"
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
edition = "2021"
@ -42,6 +42,7 @@ serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhol
sha3 = { version = "0.10.4" }
strum = { version = "0.24.1", features = ["derive"] }
tokio = { version = "1.26.0", features = ["full"] }
tokio-stream = { version = "0.1.15", features = ["full"] }
tonic = { version = "0.10.1", features = ["tls"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = { version = "0.1.37", features = ["log"] }

View File

@ -23,6 +23,7 @@ use {
mod doc_examples;
mod metrics_middleware;
mod rest;
mod sse;
pub mod types;
mod ws;
@ -143,6 +144,10 @@ pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
.route("/api/latest_price_feeds", get(rest::latest_price_feeds))
.route("/api/latest_vaas", get(rest::latest_vaas))
.route("/api/price_feed_ids", get(rest::price_feed_ids))
.route(
"/v2/updates/price/stream",
get(sse::price_stream_sse_handler),
)
.route("/v2/updates/price/latest", get(rest::latest_price_updates))
.route(
"/v2/updates/price/:publish_time",

View File

@ -21,6 +21,7 @@ mod price_feed_ids;
mod ready;
mod v2;
pub use {
get_price_feed::*,
get_vaa::*,
@ -38,6 +39,7 @@ pub use {
},
};
#[derive(Debug)]
pub enum RestError {
BenchmarkPriceNotUnique,
UpdateDataNotFound,

173
hermes/src/api/sse.rs Normal file
View File

@ -0,0 +1,173 @@
use {
crate::{
aggregate::{
AggregationEvent,
RequestTime,
},
api::{
rest::{
verify_price_ids_exist,
RestError,
},
types::{
BinaryPriceUpdate,
EncodingType,
ParsedPriceUpdate,
PriceIdInput,
PriceUpdate,
},
ApiState,
},
},
anyhow::Result,
axum::{
extract::State,
response::sse::{
Event,
KeepAlive,
Sse,
},
},
futures::Stream,
pyth_sdk::PriceIdentifier,
serde::Deserialize,
serde_qs::axum::QsQuery,
std::convert::Infallible,
tokio::sync::broadcast,
tokio_stream::{
wrappers::BroadcastStream,
StreamExt as _,
},
utoipa::IntoParams,
};
#[derive(Debug, Deserialize, IntoParams)]
#[into_params(parameter_in = Query)]
pub struct StreamPriceUpdatesQueryParams {
/// Get the most recent price update for this set of price feed ids.
///
/// This parameter can be provided multiple times to retrieve multiple price updates,
/// for example see the following query string:
///
/// ```
/// ?ids[]=a12...&ids[]=b4c...
/// ```
#[param(rename = "ids[]")]
#[param(example = "e62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43")]
ids: Vec<PriceIdInput>,
/// If true, include the parsed price update in the `parsed` field of each returned feed.
#[serde(default)]
encoding: EncodingType,
/// If true, include the parsed price update in the `parsed` field of each returned feed.
#[serde(default = "default_true")]
parsed: bool,
}
fn default_true() -> bool {
true
}
#[utoipa::path(
get,
path = "/v2/updates/price/stream",
responses(
(status = 200, description = "Price updates retrieved successfully", body = PriceUpdate),
(status = 404, description = "Price ids not found", body = String)
),
params(StreamPriceUpdatesQueryParams)
)]
/// SSE route handler for streaming price updates.
pub async fn price_stream_sse_handler(
State(state): State<ApiState>,
QsQuery(params): QsQuery<StreamPriceUpdatesQueryParams>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError> {
let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();
verify_price_ids_exist(&state, &price_ids).await?;
// Clone the update_tx receiver to listen for new price updates
let update_rx: broadcast::Receiver<AggregationEvent> = state.update_tx.subscribe();
// Convert the broadcast receiver into a Stream
let stream = BroadcastStream::new(update_rx);
let sse_stream = stream.then(move |message| {
let state_clone = state.clone(); // Clone again to use inside the async block
let price_ids_clone = price_ids.clone(); // Clone again for use inside the async block
async move {
match message {
Ok(event) => {
match handle_aggregation_event(
event,
state_clone,
price_ids_clone,
params.encoding,
params.parsed,
)
.await
{
Ok(price_update) => Ok(Event::default().json_data(price_update).unwrap()),
Err(e) => Ok(error_event(e)),
}
}
Err(e) => Ok(error_event(e)),
}
}
});
Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
async fn handle_aggregation_event(
event: AggregationEvent,
state: ApiState,
mut price_ids: Vec<PriceIdentifier>,
encoding: EncodingType,
parsed: bool,
) -> Result<PriceUpdate> {
// We check for available price feed ids to ensure that the price feed ids provided exists since price feeds can be removed.
let available_price_feed_ids = crate::aggregate::get_price_feed_ids(&*state.state).await;
price_ids.retain(|price_feed_id| available_price_feed_ids.contains(price_feed_id));
let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
&*state.state,
&price_ids,
RequestTime::AtSlot(event.slot()),
)
.await?;
let price_update_data = price_feeds_with_update_data.update_data;
let encoded_data: Vec<String> = price_update_data
.into_iter()
.map(|data| encoding.encode_str(&data))
.collect();
let binary_price_update = BinaryPriceUpdate {
encoding,
data: encoded_data,
};
let parsed_price_updates: Option<Vec<ParsedPriceUpdate>> = if parsed {
Some(
price_feeds_with_update_data
.price_feeds
.into_iter()
.map(|price_feed| price_feed.into())
.collect(),
)
} else {
None
};
Ok(PriceUpdate {
binary: binary_price_update,
parsed: parsed_price_updates,
})
}
fn error_event<E: std::fmt::Debug>(e: E) -> Event {
Event::default()
.event("error")
.data(format!("Error receiving update: {:?}", e))
}

View File

@ -28,14 +28,14 @@ mod state;
lazy_static! {
/// A static exit flag to indicate to running threads that we're shutting down. This is used to
/// gracefully shutdown the application.
/// gracefully shut down the application.
///
/// We make this global based on the fact the:
/// - The `Sender` side does not rely on any async runtime.
/// - Exit logic doesn't really require carefully threading this value through the app.
/// - The `Receiver` side of a watch channel performs the detection based on if the change
/// happened after the subscribe, so it means all listeners should always be notified
/// currectly.
/// correctly.
pub static ref EXIT: watch::Sender<bool> = watch::channel(false).0;
}