Use pyth-oracle types

This commit is contained in:
Ali Behjati 2023-05-26 20:21:18 +02:00 committed by Reisen
parent 15e35aa300
commit d70119c067
10 changed files with 435 additions and 854 deletions

969
hermes/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -52,9 +52,12 @@ libp2p = { version = "0.42.2", features = [
]} ]}
async-trait = "0.1.68" async-trait = "0.1.68"
solana-client = "=1.15.2"
solana-sdk = "=1.15.2" # We around bound to this version because of pyth-oracle
solana-account-decoder = "=1.15.2" solana-client = "=1.13.3"
solana-sdk = "=1.13.3"
solana-account-decoder = "=1.13.3"
moka = { version = "0.11.0", features = ["future"] } moka = { version = "0.11.0", features = ["future"] }
derive_builder = "0.12.0" derive_builder = "0.12.0"
byteorder = "1.4.3" byteorder = "1.4.3"
@ -62,6 +65,9 @@ serde_qs = { version = "0.12.0", features = ["axum"] }
serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1"} serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1"}
wormhole-sdk = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1" } wormhole-sdk = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1" }
pyth-oracle = { git = "https://github.com/pyth-network/pyth-client", rev = "7d593d87e07a1e2486e7ca21597d664ee72be1ec", features = ["library"] }
strum = { version = "0.24", features = ["derive"] }
[patch.crates-io] [patch.crates-io]
serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1" } serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1" }

View File

@ -64,7 +64,11 @@ pub async fn spawn(rpc_addr: String, store: Store) -> Result<()> {
// FIXME use a channel to get updates from the store // FIXME use a channel to get updates from the store
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
dispatch_updates(state.store.get_price_feed_ids(), state.clone()).await; dispatch_updates(
state.store.get_price_feed_ids().into_iter().collect(),
state.clone(),
)
.await;
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
} }
}); });

View File

@ -30,6 +30,7 @@ use {
}, },
pyth_sdk::PriceIdentifier, pyth_sdk::PriceIdentifier,
serde_qs::axum::QsQuery, serde_qs::axum::QsQuery,
std::collections::HashSet,
}; };
pub enum RestError { pub enum RestError {
@ -59,7 +60,7 @@ impl IntoResponse for RestError {
pub async fn price_feed_ids( pub async fn price_feed_ids(
State(state): State<super::State>, State(state): State<super::State>,
) -> Result<Json<Vec<PriceIdentifier>>, RestError> { ) -> Result<Json<HashSet<PriceIdentifier>>, RestError> {
let price_feeds = state.store.get_price_feed_ids(); let price_feeds = state.store.get_price_feed_ids();
Ok(Json(price_feeds)) Ok(Json(price_feeds))
} }

View File

@ -1,15 +1,13 @@
use { use {
crate::{ crate::{
impl_deserialize_for_hex_string_wrapper, impl_deserialize_for_hex_string_wrapper,
store::types::{ store::types::UnixTimestamp,
PriceFeedMessage,
UnixTimestamp,
},
}, },
derive_more::{ derive_more::{
Deref, Deref,
DerefMut, DerefMut,
}, },
pyth_oracle::PriceFeedMessage,
pyth_sdk::{ pyth_sdk::{
Price, Price,
PriceIdentifier, PriceIdentifier,

View File

@ -22,6 +22,7 @@ use {
}, },
rpc_filter::{ rpc_filter::{
Memcmp, Memcmp,
MemcmpEncodedBytes,
RpcFilterType, RpcFilterType,
}, },
}, },
@ -42,10 +43,11 @@ pub async fn spawn(pythnet_ws_endpoint: String, store: Store) -> Result<()> {
encoding: Some(UiAccountEncoding::Base64Zstd), encoding: Some(UiAccountEncoding::Base64Zstd),
..Default::default() ..Default::default()
}, },
filters: Some(vec![RpcFilterType::Memcmp(Memcmp::new_raw_bytes( filters: Some(vec![RpcFilterType::Memcmp(Memcmp {
0, offset: 0,
b"PAS1".to_vec(), bytes: MemcmpEncodedBytes::Bytes(b"PAS1".to_vec()),
))]), encoding: None,
})]),
with_context: Some(true), with_context: Some(true),
..Default::default() ..Default::default()
}; };

View File

@ -20,7 +20,6 @@ use {
store_wormhole_merkle_verified_message, store_wormhole_merkle_verified_message,
}, },
types::{ types::{
Message,
MessageState, MessageState,
ProofSet, ProofSet,
WormholePayload, WormholePayload,
@ -32,8 +31,12 @@ use {
}, },
derive_builder::Builder, derive_builder::Builder,
moka::future::Cache, moka::future::Cache,
pyth_oracle::Message,
pyth_sdk::PriceIdentifier, pyth_sdk::PriceIdentifier,
std::time::Duration, std::{
collections::HashSet,
time::Duration,
},
wormhole_sdk::Vaa, wormhole_sdk::Vaa,
}; };
@ -137,7 +140,7 @@ impl Store {
.iter() .iter()
.enumerate() .enumerate()
.map(|(idx, raw_message)| { .map(|(idx, raw_message)| {
let message = Message::from_bytes(raw_message)?; let message = Message::try_from_bytes(raw_message)?;
Ok(MessageState::new( Ok(MessageState::new(
message, message,
@ -168,18 +171,15 @@ impl Store {
request_time: RequestTime, request_time: RequestTime,
) -> Result<PriceFeedsWithUpdateData> { ) -> Result<PriceFeedsWithUpdateData> {
let messages = self.storage.retrieve_message_states( let messages = self.storage.retrieve_message_states(
price_ids price_ids,
.iter()
.map(|price_id| price_id.to_bytes())
.collect(),
types::RequestType::Some(vec![MessageType::PriceFeed]),
request_time, request_time,
Some(&|message_type| *message_type == MessageType::PriceFeedMessage),
)?; )?;
let price_feeds = messages let price_feeds = messages
.iter() .iter()
.map(|message_state| match message_state.message { .map(|message_state| match message_state.message {
Message::PriceFeed(price_feed) => Ok(price_feed), Message::PriceFeedMessage(price_feed) => Ok(price_feed),
_ => Err(anyhow!("Invalid message state type")), _ => Err(anyhow!("Invalid message state type")),
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -191,11 +191,7 @@ impl Store {
}) })
} }
pub fn get_price_feed_ids(&self) -> Vec<PriceIdentifier> { pub fn get_price_feed_ids(&self) -> HashSet<PriceIdentifier> {
self.storage self.storage.keys().iter().map(|key| key.price_id).collect()
.keys()
.iter()
.map(|key| PriceIdentifier::new(key.id))
.collect()
} }
} }

View File

@ -1,12 +1,12 @@
use { use {
super::types::{ super::types::{
MessageIdentifier, MessageIdentifier,
MessageKey,
MessageState, MessageState,
MessageType,
RequestTime, RequestTime,
RequestType,
}, },
anyhow::Result, anyhow::Result,
pyth_sdk::PriceIdentifier,
std::sync::Arc, std::sync::Arc,
}; };
@ -23,11 +23,11 @@ pub trait Storage: Send + Sync {
fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()>; fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()>;
fn retrieve_message_states( fn retrieve_message_states(
&self, &self,
ids: Vec<MessageIdentifier>, ids: Vec<PriceIdentifier>,
request_type: RequestType,
request_time: RequestTime, request_time: RequestTime,
filter: Option<&dyn Fn(&MessageType) -> bool>,
) -> Result<Vec<MessageState>>; ) -> Result<Vec<MessageState>>;
fn keys(&self) -> Vec<MessageKey>; fn keys(&self) -> Vec<MessageIdentifier>;
} }
pub type StorageInstance = Arc<Box<dyn Storage>>; pub type StorageInstance = Arc<Box<dyn Storage>>;

View File

@ -1,10 +1,8 @@
use { use {
super::{ super::{
MessageIdentifier, MessageIdentifier,
MessageKey,
MessageState, MessageState,
RequestTime, RequestTime,
RequestType,
Storage, Storage,
StorageInstance, StorageInstance,
}, },
@ -14,15 +12,17 @@ use {
Result, Result,
}, },
dashmap::DashMap, dashmap::DashMap,
pyth_sdk::PriceIdentifier,
std::{ std::{
collections::VecDeque, collections::VecDeque,
sync::Arc, sync::Arc,
}, },
strum::IntoEnumIterator,
}; };
#[derive(Clone)] #[derive(Clone)]
pub struct LocalStorage { pub struct LocalStorage {
cache: Arc<DashMap<MessageKey, VecDeque<MessageState>>>, cache: Arc<DashMap<MessageIdentifier, VecDeque<MessageState>>>,
max_size_per_key: usize, max_size_per_key: usize,
} }
@ -36,7 +36,7 @@ impl LocalStorage {
fn retrieve_message_state( fn retrieve_message_state(
&self, &self,
key: MessageKey, key: MessageIdentifier,
request_time: RequestTime, request_time: RequestTime,
) -> Option<MessageState> { ) -> Option<MessageState> {
match self.cache.get(&key) { match self.cache.get(&key) {
@ -109,19 +109,22 @@ impl Storage for LocalStorage {
fn retrieve_message_states( fn retrieve_message_states(
&self, &self,
ids: Vec<MessageIdentifier>, ids: Vec<PriceIdentifier>,
request_type: RequestType,
request_time: RequestTime, request_time: RequestTime,
filter: Option<&dyn Fn(&MessageType) -> bool>,
) -> Result<Vec<MessageState>> { ) -> Result<Vec<MessageState>> {
// TODO: Should we return an error if any of the ids are not found? // TODO: Should we return an error if any of the ids are not found?
let types: Vec<MessageType> = request_type.into();
ids.into_iter() ids.into_iter()
.flat_map(|id| { .flat_map(|id| {
let request_time = request_time.clone(); let request_time = request_time.clone();
types.iter().map(move |message_type| { let message_types: Vec<MessageType> = match filter {
let key = MessageKey { Some(filter) => MessageType::iter().filter(filter).collect(),
id, None => MessageType::iter().collect(),
type_: message_type.clone(), };
message_types.into_iter().map(move |message_type| {
let key = MessageIdentifier {
price_id: id,
type_: message_type,
}; };
self.retrieve_message_state(key, request_time.clone()) self.retrieve_message_state(key, request_time.clone())
.ok_or(anyhow!("Message not found")) .ok_or(anyhow!("Message not found"))
@ -130,7 +133,7 @@ impl Storage for LocalStorage {
.collect() .collect()
} }
fn keys(&self) -> Vec<MessageKey> { fn keys(&self) -> Vec<MessageIdentifier> {
self.cache.iter().map(|entry| entry.key().clone()).collect() self.cache.iter().map(|entry| entry.key().clone()).collect()
} }
} }

View File

@ -8,6 +8,12 @@ use {
Result, Result,
}, },
borsh::BorshDeserialize, borsh::BorshDeserialize,
pyth_oracle::{
Message,
PriceFeedMessage,
},
pyth_sdk::PriceIdentifier,
strum::EnumIter,
}; };
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -47,22 +53,49 @@ impl WormholePayload {
} }
} }
pub type RawMessage = Vec<u8>;
pub type MessageIdentifier = [u8; 32];
#[derive(Clone, PartialEq, Eq, Debug, Hash)] // TODO: We can use strum on Message enum to derive this.
#[derive(Clone, Debug, Eq, PartialEq, Hash, EnumIter)]
pub enum MessageType { pub enum MessageType {
PriceFeed, PriceFeedMessage,
TwapPrice, TwapMessage,
} }
impl MessageType { // TODO: Move this methods to Message enum
pub fn all() -> Vec<Self> { pub trait MessageExt {
// FIXME: This is a bit brittle, guard it in the future fn type_(&self) -> MessageType;
vec![Self::PriceFeed, Self::TwapPrice] fn id(&self) -> MessageIdentifier;
fn publish_time(&self) -> UnixTimestamp;
}
impl MessageExt for Message {
fn type_(&self) -> MessageType {
match self {
Message::PriceFeedMessage(_) => MessageType::PriceFeedMessage,
Message::TwapMessage(_) => MessageType::TwapMessage,
} }
} }
fn id(&self) -> MessageIdentifier {
MessageIdentifier {
price_id: match self {
Message::PriceFeedMessage(message) => PriceIdentifier::new(message.id),
Message::TwapMessage(message) => PriceIdentifier::new(message.id),
},
type_: self.type_(),
}
}
fn publish_time(&self) -> UnixTimestamp {
match self {
Message::PriceFeedMessage(message) => message.publish_time,
Message::TwapMessage(message) => message.publish_time,
}
}
}
pub type RawMessage = Vec<u8>;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
pub struct WormholeMerkleState { pub struct WormholeMerkleState {
pub digest_proof: Vec<u8>, pub digest_proof: Vec<u8>,
@ -70,9 +103,9 @@ pub struct WormholeMerkleState {
} }
#[derive(Clone, PartialEq, Eq, Debug, Hash)] #[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct MessageKey { pub struct MessageIdentifier {
// -> this is the real message id // -> this is the real message id
pub id: MessageIdentifier, // -> this is price feed id pub price_id: PriceIdentifier,
pub type_: MessageType, pub type_: MessageType,
} }
@ -87,13 +120,11 @@ pub struct ProofSet {
pub wormhole_merkle_proof: WormholeMerkleMessageProof, pub wormhole_merkle_proof: WormholeMerkleMessageProof,
} }
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
pub struct MessageState { pub struct MessageState {
pub publish_time: UnixTimestamp, pub publish_time: UnixTimestamp,
pub slot: Slot, pub slot: Slot,
pub id: MessageIdentifier, pub id: MessageIdentifier,
pub type_: MessageType,
pub message: Message, pub message: Message,
pub raw_message: RawMessage, pub raw_message: RawMessage,
pub proof_set: ProofSet, pub proof_set: ProofSet,
@ -107,19 +138,15 @@ impl MessageState {
} }
} }
pub fn key(&self) -> MessageKey { pub fn key(&self) -> MessageIdentifier {
MessageKey { self.id.clone()
id: self.id,
type_: self.type_.clone(),
}
} }
pub fn new(message: Message, raw_message: RawMessage, proof_set: ProofSet, slot: Slot) -> Self { pub fn new(message: Message, raw_message: RawMessage, proof_set: ProofSet, slot: Slot) -> Self {
Self { Self {
publish_time: message.publish_time(), publish_time: message.publish_time(),
slot, slot,
id: *message.id(), id: message.id(),
type_: message.message_type(),
message, message,
raw_message, raw_message,
proof_set, proof_set,
@ -127,20 +154,6 @@ impl MessageState {
} }
} }
pub enum RequestType {
All,
Some(Vec<MessageType>),
}
impl From<RequestType> for Vec<MessageType> {
fn from(request_type: RequestType) -> Self {
match request_type {
RequestType::All => MessageType::all(),
RequestType::Some(types) => types,
}
}
}
pub type Slot = u64; pub type Slot = u64;
pub type UnixTimestamp = i64; pub type UnixTimestamp = i64;
@ -169,143 +182,6 @@ pub enum Update {
AccumulatorMessages(AccumulatorMessages), AccumulatorMessages(AccumulatorMessages),
} }
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct PriceFeedMessage {
pub id: [u8; 32],
pub price: i64,
pub conf: u64,
pub exponent: i32,
pub publish_time: i64,
pub prev_publish_time: i64,
pub ema_price: i64,
pub ema_conf: u64,
}
impl PriceFeedMessage {
// The size of the serialized message. Note that this is not the same as the size of the struct
// (because of the discriminator & struct padding/alignment).
pub const MESSAGE_SIZE: usize = 1 + 32 + 8 + 8 + 4 + 8 + 8 + 8 + 8;
pub const DISCRIMINATOR: u8 = 0;
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.len() != Self::MESSAGE_SIZE {
return Err(anyhow!("Invalid message length"));
}
let mut id = [0u8; 32];
id.copy_from_slice(&bytes[1..33]);
let price = i64::from_be_bytes(bytes[33..41].try_into()?);
let conf = u64::from_be_bytes(bytes[41..49].try_into()?);
let exponent = i32::from_be_bytes(bytes[49..53].try_into()?);
let publish_time = i64::from_be_bytes(bytes[53..61].try_into()?);
let prev_publish_time = i64::from_be_bytes(bytes[61..69].try_into()?);
let ema_price = i64::from_be_bytes(bytes[69..77].try_into()?);
let ema_conf = u64::from_be_bytes(bytes[77..85].try_into()?);
Ok(Self {
id,
price,
conf,
exponent,
publish_time,
prev_publish_time,
ema_price,
ema_conf,
})
}
}
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct TwapMessage {
pub id: [u8; 32],
pub cumulative_price: i128,
pub cumulative_conf: u128,
pub num_down_slots: u64,
pub exponent: i32,
pub publish_time: i64,
pub prev_publish_time: i64,
pub publish_slot: u64,
}
#[allow(dead_code)]
impl TwapMessage {
// The size of the serialized message. Note that this is not the same as the size of the struct
// (because of the discriminator & struct padding/alignment).
pub const MESSAGE_SIZE: usize = 1 + 32 + 16 + 16 + 8 + 4 + 8 + 8 + 8;
pub const DISCRIMINATOR: u8 = 1;
// FIXME: Use nom or a TLV ser/de library
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.len() != Self::MESSAGE_SIZE {
return Err(anyhow!("Invalid message length"));
}
let mut id = [0u8; 32];
id.copy_from_slice(&bytes[1..33]);
let cumulative_price = i128::from_be_bytes(bytes[33..49].try_into()?);
let cumulative_conf = u128::from_be_bytes(bytes[49..65].try_into()?);
let num_down_slots = u64::from_be_bytes(bytes[65..73].try_into()?);
let exponent = i32::from_be_bytes(bytes[73..77].try_into()?);
let publish_time = i64::from_be_bytes(bytes[77..85].try_into()?);
let prev_publish_time = i64::from_be_bytes(bytes[85..93].try_into()?);
let publish_slot = u64::from_be_bytes(bytes[93..101].try_into()?);
Ok(Self {
id,
cumulative_price,
cumulative_conf,
num_down_slots,
exponent,
publish_time,
prev_publish_time,
publish_slot,
})
}
}
#[derive(Clone, PartialEq, Debug)]
pub enum Message {
PriceFeed(PriceFeedMessage),
TwapPrice(TwapMessage),
}
impl Message {
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
match bytes[0] {
PriceFeedMessage::DISCRIMINATOR => {
Ok(Self::PriceFeed(PriceFeedMessage::from_bytes(bytes)?))
}
TwapMessage::DISCRIMINATOR => Ok(Self::TwapPrice(TwapMessage::from_bytes(bytes)?)),
_ => Err(anyhow!("Invalid message discriminator")),
}
}
pub fn message_type(&self) -> MessageType {
match self {
Self::PriceFeed(_) => MessageType::PriceFeed,
Self::TwapPrice(_) => MessageType::TwapPrice,
}
}
pub fn id(&self) -> &[u8; 32] {
match self {
Self::PriceFeed(msg) => &msg.id,
Self::TwapPrice(msg) => &msg.id,
}
}
pub fn publish_time(&self) -> i64 {
match self {
Self::PriceFeed(msg) => msg.publish_time,
Self::TwapPrice(msg) => msg.publish_time,
}
}
}
pub struct PriceFeedsWithUpdateData { pub struct PriceFeedsWithUpdateData {
pub price_feeds: Vec<PriceFeedMessage>, pub price_feeds: Vec<PriceFeedMessage>,
pub wormhole_merkle_update_data: Vec<Vec<u8>>, pub wormhole_merkle_update_data: Vec<Vec<u8>>,