diff --git a/hermes/.gitignore b/hermes/.gitignore index c3515d11..d291c12b 100644 --- a/hermes/.gitignore +++ b/hermes/.gitignore @@ -7,4 +7,4 @@ src/network/p2p.proto tools/ # Ignore Wormhole cloned repo -wormhole/ +wormhole*/ diff --git a/hermes/Cargo.lock b/hermes/Cargo.lock index cb7ecc2f..e162bad3 100644 --- a/hermes/Cargo.lock +++ b/hermes/Cargo.lock @@ -2010,6 +2010,7 @@ dependencies = [ "ethabi", "futures", "hex", + "humantime", "lazy_static", "libc", "libp2p", @@ -2030,6 +2031,7 @@ dependencies = [ "serde_qs", "serde_wormhole", "sha256", + "sha3 0.10.8", "solana-account-decoder", "solana-client", "solana-sdk", @@ -4216,6 +4218,7 @@ dependencies = [ "bincode", "borsh", "bytemuck", + "byteorder", "fast-math", "hex", "rustc_version 0.4.0", @@ -4223,6 +4226,7 @@ dependencies = [ "serde_wormhole", "sha3 0.10.8", "slow_primes", + "thiserror", "wormhole-sdk", ] diff --git a/hermes/Cargo.toml b/hermes/Cargo.toml index 0e91995e..a4c00033 100644 --- a/hermes/Cargo.toml +++ b/hermes/Cargo.toml @@ -69,6 +69,8 @@ pyth-oracle = { git = "https://github.com/pyth-network/pyth-client", rev = "7d59 strum = { version = "0.24", features = ["derive"] } ethabi = { version = "18.0.0", features = ["serde"] } +sha3 = "0.10.4" +humantime = "2.1.0" [patch.crates-io] serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1" } diff --git a/hermes/build.rs b/hermes/build.rs index 570c12c2..d582733a 100644 --- a/hermes/build.rs +++ b/hermes/build.rs @@ -1,63 +1,79 @@ use std::{ env, path::PathBuf, - process::Command, + process::{ + Command, + Stdio, + }, }; fn main() { let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); let out_var = env::var("OUT_DIR").unwrap(); - // Clone the Wormhole repository, which we need to access the protobuf definitions for Wormhole - // P2P message types. + // Download the Wormhole repository at a certain tag, which we need to access the protobuf definitions + // for Wormhole P2P message types. // - // TODO: This is ugly and costly, and requires git. Instead of this we should have our own tool + // TODO: This is ugly. Instead of this we should have our own tool // build process that can generate protobuf definitions for this and other user cases. For now // this is easy and works and matches upstream Wormhole's `Makefile`. - let _ = Command::new("git") + + const WORMHOLE_VERSION: &str = "2.18.1"; + + let wh_curl = Command::new("curl") .args([ - "clone", - "https://github.com/wormhole-foundation/wormhole", - "wormhole", + "-s", + "-L", + format!("https://github.com/wormhole-foundation/wormhole/archive/refs/tags/v{WORMHOLE_VERSION}.tar.gz").as_str(), ]) + .stdout(Stdio::piped()) + .spawn() + .expect("failed to download wormhole archive"); + + let _ = Command::new("tar") + .args(["xvz"]) + .stdin(Stdio::from(wh_curl.stdout.unwrap())) .output() - .expect("failed to execute process"); + .expect("failed to extract wormhole archive"); // Move the tools directory to the root of the repo because that's where the build script // expects it to be, paths get hardcoded into the binaries. let _ = Command::new("mv") - .args(["wormhole/tools", "tools"]) + .args([ + format!("wormhole-{WORMHOLE_VERSION}/tools").as_str(), + "tools", + ]) .output() - .expect("failed to execute process"); + .expect("failed to move wormhole tools directory"); // Move the protobuf definitions to the src/network directory, we don't have to do this // but it is more intuitive when debugging. 
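+    // (Note: like the curl | tar pipeline above, the steps below shell out to external
+    // tools such as `mv` and `chmod`, so this build script effectively assumes a
+    // Unix-like environment, matching the upstream Wormhole Makefile it mirrors.)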
let _ = Command::new("mv") .args([ - "wormhole/proto/gossip/v1/gossip.proto", + format!("wormhole-{WORMHOLE_VERSION}/proto/gossip/v1/gossip.proto").as_str(), "src/network/p2p.proto", ]) .output() - .expect("failed to execute process"); + .expect("failed to move wormhole protobuf definitions"); // Build the protobuf compiler. let _ = Command::new("./build.sh") .current_dir("tools") .output() - .expect("failed to execute process"); + .expect("failed to run protobuf compiler build script"); // Make the protobuf compiler executable. let _ = Command::new("chmod") .args(["+x", "tools/bin/*"]) .output() - .expect("failed to execute process"); + .expect("failed to make protofuf compiler executable"); // Generate the protobuf definitions. See buf.gen.yaml to see how we rename the module for our // particular use case. let _ = Command::new("./tools/bin/buf") .args(["generate", "--path", "src"]) .output() - .expect("failed to execute process"); + .expect("failed to generate protobuf definitions"); // Build the Go library. let mut cmd = Command::new("go"); diff --git a/hermes/shell.nix b/hermes/shell.nix index d4a21111..55c79db8 100644 --- a/hermes/shell.nix +++ b/hermes/shell.nix @@ -7,7 +7,7 @@ with pkgs; mkShell { clang llvmPackages.libclang nettle - openssl + openssl_1_1 pkgconfig rustup systemd diff --git a/hermes/src/api.rs b/hermes/src/api.rs index 62d9c3b9..70a2f2c9 100644 --- a/hermes/src/api.rs +++ b/hermes/src/api.rs @@ -7,6 +7,10 @@ use { Router, }, std::sync::Arc, + tokio::{ + signal, + sync::mpsc::Receiver, + }, }; mod rest; @@ -15,12 +19,12 @@ mod ws; #[derive(Clone)] pub struct State { - pub store: Store, + pub store: Arc, pub ws: Arc, } impl State { - pub fn new(store: Store) -> Self { + pub fn new(store: Arc) -> Self { Self { store, ws: Arc::new(ws::WsState::new()), @@ -32,7 +36,7 @@ impl State { /// /// Currently this is based on Axum due to the simplicity and strong ecosystem support for the /// packages they are based on (tokio & hyper). -pub async fn spawn(store: Store, rpc_addr: String) -> Result<()> { +pub async fn run(store: Arc, mut update_rx: Receiver<()>, rpc_addr: String) -> Result<()> { let state = State::new(store); // Initialize Axum Router. Note the type here is a `Router` due to the use of the @@ -50,28 +54,31 @@ pub async fn spawn(store: Store, rpc_addr: String) -> Result<()> { .with_state(state.clone()); - // Binds the axum's server to the configured address and port. This is a blocking call and will - // not return until the server is shutdown. - tokio::spawn(async move { - // FIXME handle errors properly - axum::Server::bind(&rpc_addr.parse().unwrap()) - .serve(app.into_make_service()) - .await - .unwrap(); - }); - // Call dispatch updates to websocket every 1 seconds // FIXME use a channel to get updates from the store tokio::spawn(async move { loop { - dispatch_updates( - state.store.get_price_feed_ids().into_iter().collect(), - state.clone(), - ) - .await; - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + // Panics if the update channel is closed, which should never happen. + // If it happens we have no way to recover, so we just panic. + update_rx + .recv() + .await + .expect("state update channel is closed"); + + dispatch_updates(state.clone()).await; } }); + // Binds the axum's server to the configured address and port. This is a blocking call and will + // not return until the server is shutdown. + axum::Server::try_bind(&rpc_addr.parse()?)? 
+ .serve(app.into_make_service()) + .with_graceful_shutdown(async { + signal::ctrl_c() + .await + .expect("Ctrl-c signal handler failed."); + }) + .await?; + Ok(()) } diff --git a/hermes/src/api/rest.rs b/hermes/src/api/rest.rs index e1c164e0..4dd1b3a0 100644 --- a/hermes/src/api/rest.rs +++ b/hermes/src/api/rest.rs @@ -36,6 +36,7 @@ use { pub enum RestError { UpdateDataNotFound, CcipUpdateDataNotFound, + InvalidCCIPInput, } impl IntoResponse for RestError { @@ -54,6 +55,9 @@ impl IntoResponse for RestError { (StatusCode::BAD_GATEWAY, "CCIP update data not found").into_response() } + RestError::InvalidCCIPInput => { + (StatusCode::BAD_REQUEST, "Invalid CCIP input").into_response() + } } } } @@ -113,7 +117,7 @@ pub async fn latest_price_feeds( .price_feeds .into_iter() .map(|price_feed| { - RpcPriceFeed::from_price_feed_message(price_feed, params.verbose, params.binary) + RpcPriceFeed::from_price_feed_update(price_feed, params.verbose, params.binary) }) .collect(), )) @@ -156,7 +160,8 @@ pub async fn get_vaa( .price_feeds .get(0) .ok_or(RestError::UpdateDataNotFound)? - .publish_time; // TODO: This should never happen. + .price_feed + .publish_time; Ok(Json(GetVaaResponse { vaa, publish_time })) } @@ -179,8 +184,16 @@ pub async fn get_vaa_ccip( State(state): State, QsQuery(params): QsQuery, ) -> Result, RestError> { - let price_id: PriceIdentifier = PriceIdentifier::new(params.data[0..32].try_into().unwrap()); - let publish_time = UnixTimestamp::from_be_bytes(params.data[32..40].try_into().unwrap()); + let price_id: PriceIdentifier = PriceIdentifier::new( + params.data[0..32] + .try_into() + .map_err(|_| RestError::InvalidCCIPInput)?, + ); + let publish_time = UnixTimestamp::from_be_bytes( + params.data[32..40] + .try_into() + .map_err(|_| RestError::InvalidCCIPInput)?, + ); let price_feeds_with_update_data = state .store diff --git a/hermes/src/api/types.rs b/hermes/src/api/types.rs index 6fdd214d..03fd159d 100644 --- a/hermes/src/api/types.rs +++ b/hermes/src/api/types.rs @@ -1,17 +1,25 @@ use { crate::{ impl_deserialize_for_hex_string_wrapper, - store::types::UnixTimestamp, + store::types::{ + PriceFeedUpdate, + Slot, + UnixTimestamp, + }, + }, + base64::{ + engine::general_purpose::STANDARD as base64_standard_engine, + Engine as _, }, derive_more::{ Deref, DerefMut, }, - pyth_oracle::PriceFeedMessage, pyth_sdk::{ Price, PriceIdentifier, }, + wormhole_sdk::Chain, }; @@ -34,8 +42,8 @@ type Base64String = String; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct RpcPriceFeedMetadata { + pub slot: Slot, pub emitter_chain: u16, - pub sequence_number: u64, pub price_service_receive_time: UnixTimestamp, } @@ -54,11 +62,13 @@ pub struct RpcPriceFeed { impl RpcPriceFeed { // TODO: Use a Encoding type to have None, Base64, and Hex variants instead of binary flag. // TODO: Use a Verbosity type to define None, or Full instead of verbose flag. - pub fn from_price_feed_message( - price_feed_message: PriceFeedMessage, - _verbose: bool, - _binary: bool, + pub fn from_price_feed_update( + price_feed_update: PriceFeedUpdate, + verbose: bool, + binary: bool, ) -> Self { + let price_feed_message = price_feed_update.price_feed; + Self { id: PriceIdentifier::new(price_feed_message.id), price: Price { @@ -73,16 +83,14 @@ impl RpcPriceFeed { expo: price_feed_message.exponent, publish_time: price_feed_message.publish_time, }, - // FIXME: Handle verbose flag properly. 
- // metadata: verbose.then_some(RpcPriceFeedMetadata { - // emitter_chain: price_feed_message.emitter_chain, - // sequence_number: price_feed_message.sequence_number, - // price_service_receive_time: price_feed_message.receive_time, - // }), - metadata: None, - // FIXME: The vaa is wrong, fix it - // vaa: binary.then_some(base64_standard_engine.encode(message_state.proof_set.wormhole_merkle_proof.vaa)), - vaa: None, + metadata: verbose.then_some(RpcPriceFeedMetadata { + emitter_chain: Chain::Pythnet.into(), + price_service_receive_time: price_feed_update.received_at, + slot: price_feed_update.slot, + }), + vaa: binary.then_some( + base64_standard_engine.encode(price_feed_update.wormhole_merkle_update_data), + ), } } } diff --git a/hermes/src/api/ws.rs b/hermes/src/api/ws.rs index a7a8a0db..0ac347ea 100644 --- a/hermes/src/api/ws.rs +++ b/hermes/src/api/ws.rs @@ -7,7 +7,10 @@ use { types::RequestTime, Store, }, - anyhow::Result, + anyhow::{ + anyhow, + Result, + }, axum::{ extract::{ ws::{ @@ -23,6 +26,7 @@ use { futures::{ future::join_all, stream::{ + iter, SplitSink, SplitStream, }, @@ -37,9 +41,12 @@ use { std::{ collections::HashMap, pin::Pin, - sync::atomic::{ - AtomicUsize, - Ordering, + sync::{ + atomic::{ + AtomicUsize, + Ordering, + }, + Arc, }, time::Duration, }, @@ -63,7 +70,7 @@ async fn websocket_handler(stream: WebSocket, state: super::State) { // TODO: Use a configured value for the buffer size or make it const static // TODO: Use redis stream to source the updates instead of a channel - let (tx, rx) = mpsc::channel::>(1000); + let (tx, rx) = mpsc::channel::<()>(1000); ws_state.subscribers.insert(id, tx); @@ -81,8 +88,8 @@ pub type SubscriberId = usize; pub struct Subscriber { id: SubscriberId, closed: bool, - store: Store, - update_rx: mpsc::Receiver>, + store: Arc, + update_rx: mpsc::Receiver<()>, receiver: SplitStream, sender: SplitSink, price_feeds_with_config: HashMap, @@ -93,8 +100,8 @@ pub struct Subscriber { impl Subscriber { pub fn new( id: SubscriberId, - store: Store, - update_rx: mpsc::Receiver>, + store: Arc, + update_rx: mpsc::Receiver<()>, receiver: SplitStream, sender: SplitSink, ) -> Self { @@ -114,7 +121,7 @@ impl Subscriber { pub async fn run(&mut self) { while !self.closed { if let Err(e) = self.handle_next().await { - log::error!("Subscriber {}: Error handling next message: {}", self.id, e); + log::warn!("Subscriber {}: Error handling next message: {}", self.id, e); break; } } @@ -122,11 +129,11 @@ impl Subscriber { async fn handle_next(&mut self) -> Result<()> { tokio::select! { - maybe_update_feed_ids = self.update_rx.recv() => { - let update_feed_ids = maybe_update_feed_ids.ok_or_else(|| { - anyhow::anyhow!("Update channel closed.") - })?; - self.handle_price_feeds_update(update_feed_ids).await?; + maybe_update_feeds = self.update_rx.recv() => { + if maybe_update_feeds.is_none() { + return Err(anyhow!("Update channel closed. This should never happen. Closing connection.")); + }; + self.handle_price_feeds_update().await?; }, maybe_message_or_err = self.receiver.next() => { match maybe_message_or_err { @@ -140,7 +147,7 @@ impl Subscriber { }, _ = &mut self.ping_interval_future => { if !self.responded_to_ping { - log::debug!("Subscriber {} did not respond to ping, closing connection.", self.id); + log::debug!("Subscriber {} did not respond to ping. 
Closing connection.", self.id); self.closed = true; return Ok(()); } @@ -153,37 +160,32 @@ impl Subscriber { Ok(()) } - async fn handle_price_feeds_update( - &mut self, - price_feed_ids: Vec, - ) -> Result<()> { - for price_feed_id in price_feed_ids { - if let Some(config) = self.price_feeds_with_config.get(&price_feed_id) { + async fn handle_price_feeds_update(&mut self) -> Result<()> { + let messages = self + .price_feeds_with_config + .iter() + .map(|(price_feed_id, config)| { let price_feeds_with_update_data = self .store - .get_price_feeds_with_update_data(vec![price_feed_id], RequestTime::Latest)?; - let price_feed = *price_feeds_with_update_data + .get_price_feeds_with_update_data(vec![*price_feed_id], RequestTime::Latest)?; + let price_feed = price_feeds_with_update_data .price_feeds - .iter() - .find(|price_feed| price_feed.id == price_feed_id.to_bytes()) + .into_iter() + .next() .ok_or_else(|| { anyhow::anyhow!("Price feed {} not found.", price_feed_id.to_string()) })?; - let price_feed = RpcPriceFeed::from_price_feed_message( - price_feed, - config.verbose, - config.binary, - ); - // Feed does not flush the message and will allow us - // to send multiple messages in a single flush. - self.sender - .feed(Message::Text(serde_json::to_string( - &ServerMessage::PriceUpdate { price_feed }, - )?)) - .await?; - } - } - self.sender.flush().await?; + let price_feed = + RpcPriceFeed::from_price_feed_update(price_feed, config.verbose, config.binary); + + Ok(Message::Text(serde_json::to_string( + &ServerMessage::PriceUpdate { price_feed }, + )?)) + }) + .collect::>>()?; + self.sender + .send_all(&mut iter(messages.into_iter().map(Ok))) + .await?; Ok(()) } @@ -253,16 +255,15 @@ impl Subscriber { } } -pub async fn dispatch_updates(update_feed_ids: Vec, state: super::State) { +pub async fn dispatch_updates(state: super::State) { let ws_state = state.ws.clone(); - let update_feed_ids_ref = &update_feed_ids; let closed_subscribers: Vec> = join_all( ws_state .subscribers .iter_mut() .map(|subscriber| async move { - match subscriber.send(update_feed_ids_ref.clone()).await { + match subscriber.send(()).await { Ok(_) => None, Err(e) => { log::debug!("Error sending update to subscriber: {}", e); @@ -289,7 +290,7 @@ pub struct PriceFeedClientConfig { pub struct WsState { pub subscriber_counter: AtomicUsize, - pub subscribers: DashMap>>, + pub subscribers: DashMap>, } impl WsState { diff --git a/hermes/src/config.rs b/hermes/src/config.rs index 7529553f..2d871b15 100644 --- a/hermes/src/config.rs +++ b/hermes/src/config.rs @@ -44,5 +44,21 @@ pub enum Options { /// The address to bind the API server to. 
#[structopt(long, default_value = "127.0.0.1:33999")] api_addr: SocketAddr, + + /// Ethereum RPC endpoint + #[structopt(long, env = "ETH_RPC_ENDPOINT")] + eth_rpc_endpoint: String, + + /// Wormhole contract address on Ethereum + #[structopt( + long, + env = "WORMHOLE_CONTRACT_ADDRESS", + default_value = "0x98f3c9e6E3fAce36bAAd05FE09d375Ef1464288B" + )] + wormhole_eth_contract_address: String, + + /// Ethereum RPC polling duration + #[structopt(long, env = "ETH_POLLING_INTERVAL", default_value = "10s")] + eth_polling_interval: humantime::Duration, }, } diff --git a/hermes/src/main.rs b/hermes/src/main.rs index 2dc57151..9f4c1c72 100644 --- a/hermes/src/main.rs +++ b/hermes/src/main.rs @@ -26,12 +26,15 @@ async fn init() -> Result<()> { wh_bootstrap_addrs, wh_listen_addrs, api_addr, + eth_rpc_endpoint, + wormhole_eth_contract_address, + eth_polling_interval, } => { - log::info!("Running Hermes..."); - let store = Store::new_with_local_cache(1000); + // A channel to emit state updates to api + let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000); - // FIXME: Instead of spawing threads separately, we should handle all their - // errors properly. + log::info!("Running Hermes..."); + let store = Store::new_with_local_cache(update_tx, 1000); // Spawn the P2P layer. log::info!("Starting P2P server on {:?}", wh_listen_addrs); @@ -44,26 +47,25 @@ async fn init() -> Result<()> { .await?; // Spawn the Ethereum guardian set watcher + log::info!( + "Starting Ethereum guardian set watcher using {}", + eth_rpc_endpoint + ); network::ethereum::spawn( store.clone(), - "https://rpc.ankr.com/eth".to_owned(), - "0x98f3c9e6E3fAce36bAAd05FE09d375Ef1464288B".to_owned(), + eth_rpc_endpoint, + wormhole_eth_contract_address, + eth_polling_interval.into(), ) - .await; - - // Spawn the RPC server. - log::info!("Starting RPC server on {}", api_addr); - - // TODO: Add max size to the config - api::spawn(store.clone(), api_addr.to_string()).await?; + .await?; // Spawn the Pythnet listener - // TODO: Exit the thread when it gets spawned log::info!("Starting Pythnet listener using {}", pythnet_ws_endpoint); network::pythnet::spawn(store.clone(), pythnet_ws_endpoint).await?; - // Wait on Ctrl+C similar to main. - tokio::signal::ctrl_c().await?; + // Run the RPC server and wait for it to shutdown gracefully. + log::info!("Starting RPC server on {}", api_addr); + api::run(store.clone(), update_rx, api_addr.to_string()).await?; } } diff --git a/hermes/src/network/ethereum.rs b/hermes/src/network/ethereum.rs index 2dc51788..96c27a68 100644 --- a/hermes/src/network/ethereum.rs +++ b/hermes/src/network/ethereum.rs @@ -17,6 +17,11 @@ use { }, ethabi::Function, reqwest::Client, + std::{ + sync::Arc, + time::Duration, + }, + tokio::time::Instant, wormhole_sdk::GuardianAddress, }; @@ -49,17 +54,25 @@ async fn query( let res = res .get("result") - .ok_or(anyhow!("Invalid RPC Response, 'result' not found"))? + .ok_or(anyhow!( + "Invalid RPC Response, 'result' not found. {:?}", + res + ))? .as_str() - .ok_or(anyhow!("Invalid result"))?; + .ok_or(anyhow!("Invalid result. 
{:?}", res))?; - let res = hex::decode(&res[2..]).unwrap(); + let res = hex::decode(&res[2..])?; let res = method.decode_output(&res)?; Ok(res) } -async fn run(store: Store, rpc_endpoint: String, wormhole_contract: String) -> Result<()> { +async fn run( + store: Arc, + rpc_endpoint: String, + wormhole_contract: String, + polling_interval: Duration, +) -> Result { loop { let get_current_index_method = serde_json::from_str::( r#"{"inputs":[],"name":"getCurrentGuardianSetIndex","outputs":[{"internalType":"uint32","name":"","type":"uint32"}], @@ -126,32 +139,49 @@ async fn run(store: Store, rpc_endpoint: String, wormhole_contract: String) -> R log::info!("Guardian set: {:?}", guardian_set); - store - .update_guardian_set( - guardian_set - .into_iter() - .map(|address| GuardianAddress(address.0)) - .collect(), - ) - .await; + let store = store.clone(); + tokio::spawn(async move { + store + .update_guardian_set( + guardian_set + .into_iter() + .map(|address| GuardianAddress(address.0)) + .collect(), + ) + .await; + }); - tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + tokio::time::sleep(polling_interval).await; } } -pub async fn spawn(store: Store, rpc_endpoint: String, wormhole_contract: String) { +pub async fn spawn( + store: Arc, + rpc_endpoint: String, + wormhole_contract: String, + polling_interval: Duration, +) -> Result<()> { tokio::spawn(async move { loop { - if let Err(e) = run( + let current_time = Instant::now(); + + if let Err(ref e) = run( store.clone(), rpc_endpoint.clone(), wormhole_contract.clone(), + polling_interval, ) .await { - log::error!("Error in ethereum network: {}", e); - // TODO: Add a backoff here. + log::error!("Error in Ethereum network listener: {:?}", e); + } + + if current_time.elapsed() < Duration::from_secs(30) { + log::error!("Ethereum network listener is restarting too quickly. Sleeping for 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; } } }); + + Ok(()) } diff --git a/hermes/src/network/p2p.go b/hermes/src/network/p2p.go index 3f06d5e4..6b256a10 100644 --- a/hermes/src/network/p2p.go +++ b/hermes/src/network/p2p.go @@ -28,7 +28,9 @@ import "C" import ( "context" "fmt" + "os" "strings" + "time" "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/crypto" @@ -44,6 +46,7 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" + libp2pquicreuse "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" ) //export RegisterObservationCallback @@ -52,11 +55,24 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, li bootstrapAddrs := strings.Split(C.GoString(bootstrap_addrs), ",") listenAddrs := strings.Split(C.GoString(listen_addrs), ",") - go func() { + var startTime int64 + var recoverRerun func() + + routine := func() { + defer recoverRerun() + + // Record the current time + startTime = time.Now().UnixNano() + ctx := context.Background() // Setup base network configuration. priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) + if err != nil { + err := fmt.Errorf("Failed to generate key pair: %w", err) + fmt.Println(err) + return + } // Setup libp2p Connection Manager. 
	mgr, err := connmgr.NewConnManager(
@@ -76,6 +92,28 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, listen_addrs *C.char) {
 			libp2p.Identity(priv),
 			libp2p.ListenAddrStrings(listenAddrs...),
 			libp2p.Security(libp2ptls.ID, libp2ptls.New),
+			// Disable Reuse because upon panic, the Close() call on the p2p reactor does not properly clean up the
+			// open ports (they are kept around for re-use; this seems to be a libp2p bug in the reuse `gc()` call,
+			// which can be found here:
+			//
+			//   https://github.com/libp2p/go-libp2p/blob/master/p2p/transport/quicreuse/reuse.go#L97
+			//
+			// By disabling this we get correct Close() behaviour.
+			//
+			// IMPORTANT: Normally re-use allows libp2p to dial on the same port that is used to listen for traffic,
+			// and by disabling this, dialing uses a random high port (32768-60999), which causes the nodes that we
+			// connect to by dialing (instead of them connecting to us) to respond on the high range port instead
+			// of the specified Dial port. This requires firewalls to be configured to allow (UDP 32768-60999), which
+			// should be specified in our documentation.
+			//
+			// The best way to securely enable this range is via the conntrack module, which can statefully allow
+			// UDP packets only when a sent UDP packet is present in the conntrack table. This rule looks roughly
+			// like this:
+			//
+			//   iptables -A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT
+			//
+			// which is a standard rule in many firewall configurations (RELATED is the key flag).
+			libp2p.QUICReuse(libp2pquicreuse.NewConnManager, libp2pquicreuse.DisableReuseport()),
 			libp2p.Transport(libp2pquic.NewTransport),
 			libp2p.ConnectionManager(mgr),
 			libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
@@ -107,6 +145,8 @@
 			return
 		}
 
+		defer h.Close()
+
 		topic := fmt.Sprintf("%s/%s", networkID, "broadcast")
 		ps, err := pubsub.NewGossipSub(ctx, h)
 		if err != nil {
@@ -122,6 +162,8 @@
 			return
 		}
 
+		defer th.Close()
+
 		sub, err := th.Subscribe()
 		if err != nil {
 			err := fmt.Errorf("Failed to subscribe topic: %w", err)
@@ -129,6 +171,8 @@
 			return
 		}
 
+		defer sub.Cancel()
+
 		for {
 			for {
 				select {
@@ -160,7 +204,25 @@
 				}
 			}
 		}
-	}()
+	}
+
+	recoverRerun = func() {
+		// Print the error, if any, and re-run the routine.
+		if err := recover(); err != nil {
+			fmt.Fprintf(os.Stderr, "p2p.go error: %v\n", err)
+		}
+
+		// Sleep for 1 second if the time elapsed is less than 30 seconds
+		// to avoid spamming the network with requests.
+		elapsed := time.Duration(time.Now().UnixNano() - startTime)
+		if elapsed < 30*time.Second {
+			time.Sleep(1 * time.Second)
+		}
+
+		go routine()
+	}
+
+	go routine()
 }
 
 func main() {
diff --git a/hermes/src/network/p2p.rs b/hermes/src/network/p2p.rs
index c74ca19d..7d1d5402 100644
--- a/hermes/src/network/p2p.rs
+++ b/hermes/src/network/p2p.rs
@@ -26,6 +26,7 @@ use {
             Receiver,
             Sender,
         },
+        Arc,
         Mutex,
     },
 },
@@ -70,8 +71,14 @@ lazy_static::lazy_static! {
 
 extern "C" fn proxy(o: ObservationC) {
     // Create a fixed slice from the pointer and length.
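+    // SAFETY: the Go side is expected to pass a pointer valid for `o.vaa_len` bytes for
+    // the duration of this call; we copy the bytes out immediately with `to_owned` so
+    // nothing borrows from the FFI buffer after the callback returns.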
    let vaa = unsafe { std::slice::from_raw_parts(o.vaa, o.vaa_len) }.to_owned();
 
-    // FIXME: Remove unwrap
-    if let Err(e) = OBSERVATIONS.0.lock().unwrap().send(vaa) {
+    // The chances of the mutex getting poisoned are very low, and if it happens
+    // there is no way for us to recover from it.
+    if let Err(e) = OBSERVATIONS
+        .0
+        .lock()
+        .expect("Cannot acquire p2p channel lock")
+        .send(vaa)
+    {
         log::error!("Failed to send observation: {}", e);
     }
 }
@@ -116,40 +123,46 @@ pub fn bootstrap(
 
 // Spawn's the P2P layer as a separate thread via Go.
 pub async fn spawn(
-    store: Store,
+    store: Arc<Store>,
     network_id: String,
     wh_bootstrap_addrs: Vec<Multiaddr>,
     wh_listen_addrs: Vec<Multiaddr>,
 ) -> Result<()> {
-    bootstrap(network_id, wh_bootstrap_addrs, wh_listen_addrs)?;
+    std::thread::spawn(|| bootstrap(network_id, wh_bootstrap_addrs, wh_listen_addrs).unwrap());
 
-    // Listen in the background for new VAA's from the p2p layer
-    // and update the state accordingly.
     tokio::spawn(async move {
+        // Listen in the background for new VAAs from the p2p layer
+        // and update the state accordingly.
        loop {
-            let vaa_bytes = {
+            let vaa_bytes = tokio::task::spawn_blocking(|| {
                let observation = OBSERVATIONS.1.lock();
-
                let observation = match observation {
                    Ok(observation) => observation,
                    Err(e) => {
-                        log::error!("Failed to lock observation channel: {}", e);
-                        return;
+                        // This should never happen, but if it does, we want to panic and crash
+                        // as it is not recoverable.
+                        panic!("Failed to lock p2p observation channel: {e}");
                    }
                };
 
                match observation.recv() {
                    Ok(vaa_bytes) => vaa_bytes,
                    Err(e) => {
-                        log::error!("Failed to receive observation: {}", e);
-                        return;
+                        // This should never happen, but if it does, we want to panic and crash
+                        // as it is not recoverable.
+                        panic!("Failed to receive p2p observation: {e}");
                    }
                }
-            };
+            })
+            .await
+            .unwrap();
 
-            if let Err(e) = store.store_update(Update::Vaa(vaa_bytes)).await {
-                log::error!("Failed to process VAA: {:?}", e);
-            }
+            let store = store.clone();
+            tokio::spawn(async move {
+                if let Err(e) = store.store_update(Update::Vaa(vaa_bytes)).await {
+                    log::error!("Failed to process VAA: {:?}", e);
+                }
+            });
        }
    });
diff --git a/hermes/src/network/pythnet.rs b/hermes/src/network/pythnet.rs
index b2fbc43b..60fbc1e8 100644
--- a/hermes/src/network/pythnet.rs
+++ b/hermes/src/network/pythnet.rs
@@ -10,7 +10,10 @@ use {
         },
         Store,
     },
-    anyhow::Result,
+    anyhow::{
+        anyhow,
+        Result,
+    },
     borsh::BorshDeserialize,
     futures::stream::StreamExt,
     solana_account_decoder::UiAccountEncoding,
@@ -32,9 +35,14 @@ use {
         pubkey::Pubkey,
         system_program,
     },
+    std::{
+        sync::Arc,
+        time::Duration,
+    },
+    tokio::time::Instant,
 };
 
-pub async fn spawn(store: Store, pythnet_ws_endpoint: String) -> Result<()> {
+pub async fn run(store: Arc<Store>, pythnet_ws_endpoint: String) -> Result {
     let client = PubsubClient::new(pythnet_ws_endpoint.as_ref()).await?;
 
     let config = RpcProgramAccountsConfig {
@@ -56,40 +64,73 @@
         .await?;
 
     loop {
-        let update = notif.next().await;
-        log::debug!("Received Pythnet update: {:?}", update);
-
-        if let Some(update) = update {
-            let account: Account = update.value.account.decode().unwrap();
-            log::debug!("Received Accumulator update: {:?}", account);
-
-            let accumulator_messages = AccumulatorMessages::try_from_slice(&account.data);
-            match accumulator_messages {
-                Ok(accumulator_messages) => {
-                    let (candidate, _) = Pubkey::find_program_address(
-                        &[
-                            b"AccumulatorState",
-                            &accumulator_messages.ring_index().to_be_bytes(),
], - &system_program::id(), - ); - - if candidate.to_string() == update.value.pubkey { - store - .store_update(Update::AccumulatorMessages(accumulator_messages)) - .await?; - } else { - log::error!( - "Failed to verify the messages public key: {:?} != {:?}", - candidate, - update.value.pubkey - ); + match notif.next().await { + Some(update) => { + let account: Account = match update.value.account.decode() { + Some(account) => account, + None => { + log::error!("Failed to decode account from update: {:?}", update); + continue; } - } - Err(err) => { - log::error!("Failed to parse AccumulatorMessages: {:?}", err); - } - }; + }; + + let accumulator_messages = AccumulatorMessages::try_from_slice(&account.data); + match accumulator_messages { + Ok(accumulator_messages) => { + let (candidate, _) = Pubkey::find_program_address( + &[ + b"AccumulatorState", + &accumulator_messages.ring_index().to_be_bytes(), + ], + &system_program::id(), + ); + + if candidate.to_string() == update.value.pubkey { + let store = store.clone(); + tokio::spawn(async move { + if let Err(err) = store + .store_update(Update::AccumulatorMessages(accumulator_messages)) + .await + { + log::error!("Failed to store accumulator messages: {:?}", err); + } + }); + } else { + log::error!( + "Failed to verify the messages public key: {:?} != {:?}", + candidate, + update.value.pubkey + ); + } + } + Err(err) => { + log::error!("Failed to parse AccumulatorMessages: {:?}", err); + } + }; + } + None => { + return Err(anyhow!("Pythnet network listener terminated")); + } } } } + + +pub async fn spawn(store: Arc, pythnet_ws_endpoint: String) -> Result<()> { + tokio::spawn(async move { + loop { + let current_time = Instant::now(); + + if let Err(ref e) = run(store.clone(), pythnet_ws_endpoint.clone()).await { + log::error!("Error in Pythnet network listener: {:?}", e); + } + + if current_time.elapsed() < Duration::from_secs(30) { + log::error!("Pythnet network listener is restarting too quickly. 
Sleeping for 1s"); + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + }); + + Ok(()) +} diff --git a/hermes/src/store.rs b/hermes/src/store.rs index 41b5b17c..3bef4ea3 100644 --- a/hermes/src/store.rs +++ b/hermes/src/store.rs @@ -1,13 +1,11 @@ use { self::{ - proof::wormhole_merkle::{ - construct_update_data, - WormholeMerkleProof, - }, + proof::wormhole_merkle::construct_update_data, storage::StorageInstance, types::{ AccumulatorMessages, MessageType, + PriceFeedUpdate, PriceFeedsWithUpdateData, RequestTime, Slot, @@ -22,11 +20,9 @@ use { types::{ MessageState, ProofSet, + UnixTimestamp, }, - wormhole::{ - parse_and_verify_vaa, - WormholePayload, - }, + wormhole::parse_and_verify_vaa, }, anyhow::{ anyhow, @@ -36,12 +32,24 @@ use { moka::future::Cache, pyth_oracle::Message, pyth_sdk::PriceIdentifier, + pythnet_sdk::payload::v1::{ + WormholeMerkleRoot, + WormholeMessage, + WormholePayload, + }, std::{ collections::HashSet, sync::Arc, - time::Duration, + time::{ + Duration, + SystemTime, + UNIX_EPOCH, + }, + }, + tokio::sync::{ + mpsc::Sender, + RwLock, }, - tokio::sync::RwLock, wormhole_sdk::{ Address, Chain, @@ -58,30 +66,27 @@ pub mod wormhole; #[builder(derive(Debug), pattern = "immutable")] pub struct AccumulatorState { pub accumulator_messages: AccumulatorMessages, - pub wormhole_merkle_proof: WormholeMerkleProof, + pub wormhole_merkle_proof: (WormholeMerkleRoot, Vec), } -#[derive(Clone)] pub struct Store { pub storage: StorageInstance, pub pending_accumulations: Cache, - pub guardian_set: Arc>>>, + pub guardian_set: RwLock>>, + pub update_tx: Sender<()>, } impl Store { - pub fn new_with_local_cache(max_size_per_key: usize) -> Self { - // TODO: Should we return an Arc? Although we are currently safe to be cloned without - // an Arc but it is easily to miss and cause a bug. 
- Self { - storage: storage::local_storage::LocalStorage::new_instance( - max_size_per_key, - ), + pub fn new_with_local_cache(update_tx: Sender<()>, max_size_per_key: usize) -> Arc { + Arc::new(Self { + storage: storage::local_storage::LocalStorage::new_instance(max_size_per_key), pending_accumulations: Cache::builder() .max_capacity(10_000) .time_to_live(Duration::from_secs(60 * 5)) .build(), // FIXME: Make this configurable - guardian_set: Arc::new(RwLock::new(None)), - } + guardian_set: RwLock::new(None), + update_tx, + }) } /// Stores the update data in the store @@ -103,30 +108,19 @@ impl Store { return Ok(()); // Ignore VAA from other emitters } - let payload = WormholePayload::try_from_bytes(body.payload, &vaa_bytes)?; - - match payload { + match WormholeMessage::try_from_bytes(body.payload)?.payload { WormholePayload::Merkle(proof) => { - log::info!( - "Storing merkle proof for slot {:?}: {:?}", - proof.slot, - proof - ); - store_wormhole_merkle_verified_message(self, proof.clone()).await?; + log::info!("Storing merkle proof for slot {:?}", proof.slot,); + store_wormhole_merkle_verified_message(self, proof.clone(), vaa_bytes) + .await?; proof.slot } } } Update::AccumulatorMessages(accumulator_messages) => { - // FIXME: Move this constant to a better place - let slot = accumulator_messages.slot; - log::info!( - "Storing accumulator messages for slot {:?}: {:?}", - slot, - accumulator_messages - ); + log::info!("Storing accumulator messages for slot {:?}.", slot,); let pending_acc = self .pending_accumulations @@ -154,10 +148,11 @@ impl Store { Err(_) => return Ok(()), }; - log::info!("State: {:?}", state); - let wormhole_merkle_message_states_proofs = construct_message_states_proofs(state.clone())?; + let current_time: UnixTimestamp = + SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as _; + let message_states = state .accumulator_messages .messages @@ -176,16 +171,19 @@ impl Store { .clone(), }, state.accumulator_messages.slot, + current_time, )) }) .collect::>>()?; - log::info!("Message states: {:?}", message_states); + log::info!("Message states len: {:?}", message_states.len()); self.storage.store_message_states(message_states)?; self.pending_accumulations.invalidate(&slot).await; + self.update_tx.send(()).await?; + Ok(()) } @@ -207,11 +205,20 @@ impl Store { let price_feeds = messages .iter() .map(|message_state| match message_state.message { - Message::PriceFeedMessage(price_feed) => Ok(price_feed), + Message::PriceFeedMessage(price_feed) => Ok(PriceFeedUpdate { + price_feed, + received_at: message_state.received_at, + slot: message_state.slot, + wormhole_merkle_update_data: construct_update_data(vec![message_state])? 
+ .into_iter() + .next() + .ok_or(anyhow!("Missing update data for message"))?, + }), _ => Err(anyhow!("Invalid message state type")), }) .collect::>>()?; - let update_data = construct_update_data(messages)?; + + let update_data = construct_update_data(messages.iter().collect())?; Ok(PriceFeedsWithUpdateData { price_feeds, diff --git a/hermes/src/store/proof/wormhole_merkle.rs b/hermes/src/store/proof/wormhole_merkle.rs index 2a2da954..c0ec3bd7 100644 --- a/hermes/src/store/proof/wormhole_merkle.rs +++ b/hermes/src/store/proof/wormhole_merkle.rs @@ -8,10 +8,6 @@ use { anyhow, Result, }, - byteorder::{ - BigEndian, - WriteBytesExt, - }, pythnet_sdk::{ accumulators::{ merkle::{ @@ -21,23 +17,16 @@ use { Accumulator, }, hashers::keccak256_160::Keccak160, - }, - std::io::{ - Cursor, - Write, + payload::v1::{ + AccumulatorUpdateData, + MerklePriceUpdate, + Proof, + WormholeMerkleRoot, + }, + ser::to_vec, }, }; -type Hash = [u8; 20]; - -#[derive(Clone, PartialEq, Debug)] -pub struct WormholeMerkleProof { - pub vaa: Vec, - pub slot: u64, - pub ring_size: u32, - pub root: Hash, -} - #[derive(Clone, PartialEq, Debug)] pub struct WormholeMerkleMessageProof { pub vaa: Vec, @@ -46,7 +35,8 @@ pub struct WormholeMerkleMessageProof { pub async fn store_wormhole_merkle_verified_message( store: &Store, - proof: WormholeMerkleProof, + proof: WormholeMerkleRoot, + vaa_bytes: Vec, ) -> Result<()> { let pending_acc = store .pending_accumulations @@ -56,7 +46,10 @@ pub async fn store_wormhole_merkle_verified_message( .into_value(); store .pending_accumulations - .insert(proof.slot, pending_acc.wormhole_merkle_proof(proof)) + .insert( + proof.slot, + pending_acc.wormhole_merkle_proof((proof, vaa_bytes)), + ) .await; Ok(()) } @@ -76,14 +69,11 @@ pub fn construct_message_states_proofs( None => return Ok(vec![]), // It only happens when the message set is empty }; - let proof = &state.wormhole_merkle_proof; + let (proof, vaa) = &state.wormhole_merkle_proof; - log::info!( - "Merkle root: {:?}, Verified root: {:?}", - merkle_acc.root, - proof.root - ); - log::info!("Valid: {}", merkle_acc.root == proof.root); + if merkle_acc.root != proof.root { + return Err(anyhow!("Invalid merkle root")); + } state .accumulator_messages @@ -91,7 +81,7 @@ pub fn construct_message_states_proofs( .iter() .map(|m| { Ok(WormholeMerkleMessageProof { - vaa: state.wormhole_merkle_proof.vaa.clone(), + vaa: vaa.clone(), proof: merkle_acc .prove(m.as_ref()) .ok_or(anyhow!("Failed to prove message"))?, @@ -100,7 +90,7 @@ pub fn construct_message_states_proofs( .collect::>>() } -pub fn construct_update_data(mut message_states: Vec) -> Result>> { +pub fn construct_update_data(mut message_states: Vec<&MessageState>) -> Result>> { message_states.sort_by_key( |m| m.proof_set.wormhole_merkle_proof.vaa.clone(), // FIXME: This is not efficient ); @@ -118,39 +108,18 @@ pub fn construct_update_data(mut message_states: Vec) -> Result(0x504e4155)?; // "PNAU" - cursor.write_u8(0x01)?; // Major version - cursor.write_u8(0x00)?; // Minor version - cursor.write_u8(0)?; // Trailing header size - - cursor.write_u8(0)?; // Update type of WormholeMerkle. 
FIXME: Make this out of enum - - // Writing VAA - cursor.write_u16::(vaa.len().try_into()?)?; - cursor.write_all(&vaa)?; - - // Writing number of messages - cursor.write_u8(messages.len().try_into()?)?; - - for message in messages { - // Writing message - cursor.write_u16::(message.raw_message.len().try_into()?)?; - cursor.write_all(&message.raw_message)?; - - // Writing proof - cursor.write_all( - &message - .proof_set - .wormhole_merkle_proof - .proof - .serialize() - .ok_or(anyhow!("Unable to serialize merkle proof path"))?, - )?; - } - - Ok(cursor.into_inner()) + Ok(to_vec::<_, byteorder::BE>(&AccumulatorUpdateData::new( + Proof::WormholeMerkle { + vaa: vaa.into(), + updates: messages + .iter() + .map(|message| MerklePriceUpdate { + message: message.raw_message.clone().into(), + proof: message.proof_set.wormhole_merkle_proof.proof.clone(), + }) + .collect(), + }, + ))?) }) .collect::>>>() } diff --git a/hermes/src/store/storage.rs b/hermes/src/store/storage.rs index 307174d5..4968cf4c 100644 --- a/hermes/src/store/storage.rs +++ b/hermes/src/store/storage.rs @@ -7,7 +7,6 @@ use { }, anyhow::Result, pyth_sdk::PriceIdentifier, - std::sync::Arc, }; pub mod local_storage; @@ -30,4 +29,4 @@ pub trait Storage: Send + Sync { fn keys(&self) -> Vec; } -pub type StorageInstance = Arc>; +pub type StorageInstance = Box; diff --git a/hermes/src/store/storage/local_storage.rs b/hermes/src/store/storage/local_storage.rs index 29a8ec46..66c36aed 100644 --- a/hermes/src/store/storage/local_storage.rs +++ b/hermes/src/store/storage/local_storage.rs @@ -28,10 +28,10 @@ pub struct LocalStorage { impl LocalStorage { pub fn new_instance(max_size_per_key: usize) -> StorageInstance { - Arc::new(Box::new(Self { + Box::new(Self { cache: Arc::new(DashMap::new()), max_size_per_key, - })) + }) } fn retrieve_message_state( diff --git a/hermes/src/store/types.rs b/hermes/src/store/types.rs index 9bc0d4e0..930344c6 100644 --- a/hermes/src/store/types.rs +++ b/hermes/src/store/types.rs @@ -60,7 +60,6 @@ pub struct WormholeMerkleState { #[derive(Clone, PartialEq, Eq, Debug, Hash)] pub struct MessageIdentifier { - // -> this is the real message id pub price_id: PriceIdentifier, pub type_: MessageType, } @@ -84,6 +83,7 @@ pub struct MessageState { pub message: Message, pub raw_message: RawMessage, pub proof_set: ProofSet, + pub received_at: UnixTimestamp, } impl MessageState { @@ -98,7 +98,13 @@ impl MessageState { self.id.clone() } - pub fn new(message: Message, raw_message: RawMessage, proof_set: ProofSet, slot: Slot) -> Self { + pub fn new( + message: Message, + raw_message: RawMessage, + proof_set: ProofSet, + slot: Slot, + received_at: UnixTimestamp, + ) -> Self { Self { publish_time: message.publish_time(), slot, @@ -106,6 +112,7 @@ impl MessageState { message, raw_message, proof_set, + received_at, } } } @@ -138,7 +145,17 @@ pub enum Update { AccumulatorMessages(AccumulatorMessages), } +pub struct PriceFeedUpdate { + pub price_feed: PriceFeedMessage, + pub slot: Slot, + pub received_at: UnixTimestamp, + /// Wormhole merkle update data for this single price feed update. + /// This field is available for backward compatibility and will be + /// removed in the future. 
+ pub wormhole_merkle_update_data: Vec, +} + pub struct PriceFeedsWithUpdateData { - pub price_feeds: Vec, + pub price_feeds: Vec, pub wormhole_merkle_update_data: Vec>, } diff --git a/hermes/src/store/wormhole.rs b/hermes/src/store/wormhole.rs index 9e3bfa1c..169678b9 100644 --- a/hermes/src/store/wormhole.rs +++ b/hermes/src/store/wormhole.rs @@ -1,8 +1,5 @@ use { - super::{ - proof::wormhole_merkle::WormholeMerkleProof, - Store, - }, + super::Store, anyhow::{ anyhow, Result, @@ -16,6 +13,10 @@ use { Secp256k1, }, serde_wormhole::RawMessage, + sha3::{ + Digest, + Keccak256, + }, wormhole_sdk::{ vaa::{ Body, @@ -26,43 +27,6 @@ use { }, }; -#[derive(Clone, Debug, PartialEq)] -pub enum WormholePayload { - Merkle(WormholeMerkleProof), -} - -impl WormholePayload { - pub fn try_from_bytes(bytes: &[u8], vaa_bytes: &[u8]) -> Result { - if bytes.len() != 37 { - return Err(anyhow!("Invalid message length")); - } - - // TODO: Use byte string literals for this check - let magic = u32::from_be_bytes(bytes[0..4].try_into()?); - if magic != 0x41555756u32 { - return Err(anyhow!("Invalid magic")); - } - - let message_type = u8::from_be_bytes(bytes[4..5].try_into()?); - - if message_type != 0 { - return Err(anyhow!("Invalid message type")); - } - - let slot = u64::from_be_bytes(bytes[5..13].try_into()?); - let ring_size = u32::from_be_bytes(bytes[13..17].try_into()?); - let root_digest = bytes[17..37].try_into()?; - - - Ok(Self::Merkle(WormholeMerkleProof { - root: root_digest, - slot, - ring_size, - vaa: vaa_bytes.to_vec(), - })) - } -} - /// Parses and verifies a VAA to ensure it is signed by the Wormhole guardian set. pub async fn parse_and_verify_vaa<'a>( store: &Store, @@ -72,33 +36,51 @@ pub async fn parse_and_verify_vaa<'a>( let (header, body): (Header, Body<&RawMessage>) = vaa.into(); let digest = body.digest()?; + let guardian_set = match store.guardian_set.read().await.as_ref() { + Some(guardian_set) => guardian_set.clone(), + None => { + return Err(anyhow!("Guardian set is not initialized")); + } + }; + let mut num_correct_signers = 0; for sig in header.signatures.iter() { let signer_id: usize = sig.index.into(); + let sig = sig.signature; + // Recover the public key from ecdsa signature from [u8; 65] that has (v, r, s) format let recid = RecoveryId::from_i32(sig[64].into())?; - // Recover the public key from ecdsa signature from [u8; 65] that has (v, r, s) format let secp = Secp256k1::new(); - let pubkey = &secp + + // To get the address we need to use the uncompressed public key + let pubkey: &[u8; 65] = &secp .recover_ecdsa( &Message::from_slice(&digest.secp256k_hash)?, &RecoverableSignature::from_compact(&sig[..64], recid)?, )? 
- .serialize(); + .serialize_uncompressed(); - let address = GuardianAddress(pubkey[pubkey.len() - 20..].try_into()?); + // The address is the last 20 bytes of the Keccak256 hash of the public key + let mut keccak = Keccak256::new(); + keccak.update(&pubkey[1..]); - if let Some(guardian_set) = store.guardian_set.read().await.as_ref() { - if guardian_set.get(signer_id) == Some(&address) { - num_correct_signers += 1; - } + let address: [u8; 32] = keccak.finalize().into(); + let address = GuardianAddress(address[address.len() - 20..].try_into()?); + + if guardian_set.get(signer_id) == Some(&address) { + num_correct_signers += 1; } } - if num_correct_signers < header.signatures.len() * 2 / 3 { - return Err(anyhow!("Not enough correct signatures")); + let quorum = (guardian_set.len() * 2 + 2) / 3; + if num_correct_signers < quorum { + return Err(anyhow!( + "Not enough correct signatures. Expected {:?}, received {:?}", + quorum, + num_correct_signers + )); } Ok(body) diff --git a/pythnet/pythnet_sdk/src/accumulators/merkle.rs b/pythnet/pythnet_sdk/src/accumulators/merkle.rs index eb45b638..93e2f8fd 100644 --- a/pythnet/pythnet_sdk/src/accumulators/merkle.rs +++ b/pythnet/pythnet_sdk/src/accumulators/merkle.rs @@ -51,25 +51,13 @@ fn hash_null() -> H::Hash { H::hashv(&[NULL_PREFIX]) } -#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize)] +#[derive(Clone, Default, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct MerklePath(Vec); impl MerklePath { pub fn new(path: Vec) -> Self { Self(path) } - - pub fn serialize(&self) -> Option> { - let mut serialized = vec![]; - let proof_size: u8 = self.0.len().try_into().ok()?; - serialized.extend_from_slice(&proof_size.to_be_bytes()); - - for node in self.0.iter() { - serialized.extend_from_slice(node.as_ref()); - } - - Some(serialized) - } } /// A MerkleAccumulator maintains a Merkle Tree. 
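// NOTE: a minimal sketch (not part of the patch itself) of how the new `ser`/`de` pair
// introduced below is expected to round-trip a value. The `Example` struct is
// hypothetical; `to_vec`/`from_slice` and their `ByteOrder` type parameters follow the
// definitions added in this diff.
//
//     use byteorder::BE;
//     use serde::{Deserialize, Serialize};
//
//     #[derive(Debug, PartialEq, Serialize, Deserialize)]
//     struct Example {
//         slot:      u64,
//         ring_size: u32,
//     }
//
//     let value = Example { slot: 7, ring_size: 10_000 };
//     // Fields are written exactly-sized (big-endian here) with no padding or framing.
//     let bytes = pythnet_sdk::ser::to_vec::<_, BE>(&value).unwrap();
//     let back = pythnet_sdk::de::from_slice::<BE, Example>(&bytes).unwrap();
//     assert_eq!(value, back);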
diff --git a/pythnet/pythnet_sdk/src/de.rs b/pythnet/pythnet_sdk/src/de.rs new file mode 100644 index 00000000..e1e5b739 --- /dev/null +++ b/pythnet/pythnet_sdk/src/de.rs @@ -0,0 +1,487 @@ +use { + byteorder::{ + ByteOrder, + ReadBytesExt, + }, + serde::{ + de::{ + EnumAccess, + IntoDeserializer, + MapAccess, + SeqAccess, + VariantAccess, + }, + Deserialize, + }, + std::io::{ + Cursor, + Seek, + SeekFrom, + }, + thiserror::Error, +}; + +pub fn from_slice<'de, B, T>(bytes: &'de [u8]) -> Result +where + T: Deserialize<'de>, + B: ByteOrder, +{ + let mut deserializer = Deserializer::::new(bytes); + T::deserialize(&mut deserializer) +} + +#[derive(Debug, Error)] +pub enum DeserializeError { + #[error("io error: {0}")] + Io(#[from] std::io::Error), + + #[error("invalid utf8: {0}")] + Utf8(#[from] std::str::Utf8Error), + + #[error("this type is not supported")] + Unsupported, + + #[error("sequence too large ({0} elements), max supported is 255")] + SequenceTooLarge(usize), + + #[error("message: {0}")] + Message(Box), + + #[error("eof")] + Eof, +} + +pub struct Deserializer<'de, B> +where + B: ByteOrder, +{ + cursor: Cursor<&'de [u8]>, + endian: std::marker::PhantomData, +} + +impl serde::de::Error for DeserializeError { + fn custom(msg: T) -> Self { + DeserializeError::Message(msg.to_string().into_boxed_str()) + } +} + +impl<'de, B> Deserializer<'de, B> +where + B: ByteOrder, +{ + pub fn new(buffer: &'de [u8]) -> Self { + Self { + cursor: Cursor::new(buffer), + endian: std::marker::PhantomData, + } + } +} + +impl<'de, B> serde::de::Deserializer<'de> for &'_ mut Deserializer<'de, B> +where + B: ByteOrder, +{ + type Error = DeserializeError; + + fn deserialize_any(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + Err(DeserializeError::Unsupported) + } + + fn deserialize_ignored_any(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + Err(DeserializeError::Unsupported) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self.cursor.read_u8().map_err(DeserializeError::from)?; + visitor.visit_bool(value != 0) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self.cursor.read_i8().map_err(DeserializeError::from)?; + visitor.visit_i8(value) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self + .cursor + .read_i16::() + .map_err(DeserializeError::from)?; + + visitor.visit_i16(value) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self + .cursor + .read_i32::() + .map_err(DeserializeError::from)?; + + visitor.visit_i32(value) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self + .cursor + .read_i64::() + .map_err(DeserializeError::from)?; + + visitor.visit_i64(value) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self.cursor.read_u8().map_err(DeserializeError::from)?; + visitor.visit_u8(value) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self + .cursor + .read_u16::() + .map_err(DeserializeError::from)?; + + visitor.visit_u16(value) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self + .cursor + .read_u32::() + .map_err(DeserializeError::from)?; + + 
visitor.visit_u32(value) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value = self + .cursor + .read_u64::() + .map_err(DeserializeError::from)?; + + visitor.visit_u64(value) + } + + fn deserialize_f32(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + Err(DeserializeError::Unsupported) + } + + fn deserialize_f64(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + Err(DeserializeError::Unsupported) + } + + fn deserialize_char(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + Err(DeserializeError::Unsupported) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let len = self.cursor.read_u8().map_err(DeserializeError::from)? as u64; + + // Because cursor read methods copy the data out of the internal buffer, + // where we want a pointer, we need to move the cursor forward first + // (because the ) get_ref() call triggers a borrow), and then read + // after. + self.cursor + .seek(SeekFrom::Current(len as i64)) + .map_err(DeserializeError::from)?; + + let buf = { + let buf = self.cursor.get_ref(); + buf[(self.cursor.position() - len) as usize..] + .get(..len as usize) + .ok_or(DeserializeError::Eof)? + }; + + visitor.visit_borrowed_str(std::str::from_utf8(buf).map_err(DeserializeError::from)?) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_str(visitor) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let len = self.cursor.read_u8().map_err(DeserializeError::from)? as u64; + + // See comment in deserialize_str for the reason for the subtraction. + self.cursor + .seek(SeekFrom::Current(len as i64)) + .map_err(DeserializeError::from)?; + + let buf = { + let buf = self.cursor.get_ref(); + buf[(self.cursor.position() - len) as usize..] + .get(..len as usize) + .ok_or(DeserializeError::Eof)? + }; + + visitor.visit_borrowed_bytes(buf) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_bytes(visitor) + } + + fn deserialize_option(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + Err(DeserializeError::Unsupported) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let len = self.cursor.read_u8().map_err(DeserializeError::from)? as usize; + visitor.visit_seq(SequenceIterator::new(self, len)) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_seq(SequenceIterator::new(self, len)) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_seq(SequenceIterator::new(self, len)) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let len = self.cursor.read_u8().map_err(DeserializeError::from)? 
as usize; + visitor.visit_map(SequenceIterator::new(self, len)) + } + + fn deserialize_struct( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_seq(SequenceIterator::new(self, fields.len())) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + let variant = self.cursor.read_u8().map_err(DeserializeError::from)?; + visitor.visit_enum(Enum { de: self, variant }) + } + + fn deserialize_identifier(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + Err(DeserializeError::Unsupported) + } +} + +impl<'de, 'a, B: ByteOrder> VariantAccess<'de> for &'a mut Deserializer<'de, B> { + type Error = DeserializeError; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(self) + } + + fn tuple_variant(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_seq(SequenceIterator::new(self, len)) + } + + fn struct_variant( + self, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_seq(SequenceIterator::new(self, fields.len())) + } +} + +struct SequenceIterator<'de, 'a, B: ByteOrder> { + de: &'a mut Deserializer<'de, B>, + len: usize, +} + +impl<'de, 'a, B: ByteOrder> SequenceIterator<'de, 'a, B> { + fn new(de: &'a mut Deserializer<'de, B>, len: usize) -> Self { + Self { de, len } + } +} + +impl<'de, 'a, B: ByteOrder> SeqAccess<'de> for SequenceIterator<'de, 'a, B> { + type Error = DeserializeError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: serde::de::DeserializeSeed<'de>, + { + if self.len == 0 { + return Ok(None); + } + + self.len -= 1; + seed.deserialize(&mut *self.de).map(Some) + } + + fn size_hint(&self) -> Option { + Some(self.len) + } +} + +impl<'de, 'a, B: ByteOrder> MapAccess<'de> for SequenceIterator<'de, 'a, B> { + type Error = DeserializeError; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: serde::de::DeserializeSeed<'de>, + { + if self.len == 0 { + return Ok(None); + } + + self.len -= 1; + seed.deserialize(&mut *self.de).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.de) + } + + fn size_hint(&self) -> Option { + Some(self.len) + } +} + +struct Enum<'de, 'a, B: ByteOrder> { + de: &'a mut Deserializer<'de, B>, + variant: u8, +} + +impl<'de, 'a, B: ByteOrder> EnumAccess<'de> for Enum<'de, 'a, B> { + type Error = DeserializeError; + type Variant = &'a mut Deserializer<'de, B>; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(self.variant.into_deserializer()) + .map(|v| (v, self.de)) + } +} diff --git a/pythnet/pythnet_sdk/src/error.rs b/pythnet/pythnet_sdk/src/error.rs new file mode 100644 index 00000000..b4152ad9 --- /dev/null +++ b/pythnet/pythnet_sdk/src/error.rs @@ -0,0 +1,19 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Invalid Magic")] + InvalidMagic, + + #[error("Invalid Version")] + InvalidVersion, +} + +#[macro_export] +macro_rules! 
+    ($cond:expr, $err:expr) => {
+        if !$cond {
+            return Err($err);
+        }
+    };
+}
diff --git a/pythnet/pythnet_sdk/src/hashers.rs b/pythnet/pythnet_sdk/src/hashers.rs
index 68eefaba..11c068d8 100644
--- a/pythnet/pythnet_sdk/src/hashers.rs
+++ b/pythnet/pythnet_sdk/src/hashers.rs
@@ -28,6 +28,7 @@ where
         + Debug
         + Default
         + Eq
+        + std::hash::Hash
        + PartialOrd
        + PartialEq
        + serde::Serialize
diff --git a/pythnet/pythnet_sdk/src/hashers/keccak256_160.rs b/pythnet/pythnet_sdk/src/hashers/keccak256_160.rs
index dc9b462c..3c79ad65 100644
--- a/pythnet/pythnet_sdk/src/hashers/keccak256_160.rs
+++ b/pythnet/pythnet_sdk/src/hashers/keccak256_160.rs
@@ -7,7 +7,7 @@ use {
     },
 };
 
-#[derive(Clone, Default, Debug, Eq, PartialEq, Serialize)]
+#[derive(Clone, Default, Debug, Eq, Hash, PartialEq, Serialize)]
 pub struct Keccak160 {}
 
 impl Hasher for Keccak160 {
diff --git a/pythnet/pythnet_sdk/src/lib.rs b/pythnet/pythnet_sdk/src/lib.rs
index f1817851..6c952c30 100644
--- a/pythnet/pythnet_sdk/src/lib.rs
+++ b/pythnet/pythnet_sdk/src/lib.rs
@@ -1,4 +1,5 @@
 pub mod accumulators;
+pub mod error;
 pub mod hashers;
 pub mod wire;
 pub mod wormhole;
diff --git a/pythnet/pythnet_sdk/src/payload.rs b/pythnet/pythnet_sdk/src/payload.rs
new file mode 100644
index 00000000..d993847d
--- /dev/null
+++ b/pythnet/pythnet_sdk/src/payload.rs
@@ -0,0 +1,109 @@
+//! Definition of the Accumulator Payload Formats.
+//!
+//! This module defines the data types that are injected into VAAs to be sent to other chains via
+//! Wormhole. The wire format for these types must be backwards compatible and so all types in this
+//! module are expected to be append-only (for minor changes) and versioned for breaking changes.
+
+use {
+    crate::{
+        error::Error,
+        require,
+        ser::PrefixedVec,
+    },
+    serde::{
+        Deserialize,
+        Serialize,
+    },
+};
+
+// Proof Format (V1)
+// --------------------------------------------------------------------------------
+// The definitions within each module can be updated with append-only data without requiring a new
+// module to be defined. So for example, it is possible for new fields to be added to the end of
+// the `AccumulatorAccount` without moving to a `v2`.
+pub mod v1 {
+    use {
+        super::*,
+        crate::{
+            accumulators::merkle::MerklePath,
+            de::from_slice,
+            hashers::keccak256_160::Keccak160,
+        },
+    };
+
+    // Transfer Format.
+    // --------------------------------------------------------------------------------
+    // This definition is what will be sent over the wire (i.e., pulled from PythNet and submitted
+    // to target chains).
+    #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
+    pub struct AccumulatorUpdateData {
+        magic: [u8; 4],
+        major_version: u8,
+        minor_version: u8,
+        trailing: Vec<u8>,
+        proof: Proof,
+    }
+
+    impl AccumulatorUpdateData {
+        pub fn new(proof: Proof) -> Self {
+            Self {
+                magic: *b"PNAU",
+                major_version: 1,
+                minor_version: 0,
+                trailing: vec![],
+                proof,
+            }
+        }
+
+        pub fn try_from_slice(bytes: &[u8]) -> Result<Self, Error> {
+            let message = from_slice::<byteorder::BE, Self>(bytes).unwrap();
+            require!(&message.magic[..] == b"PNAU", Error::InvalidMagic);
+            require!(message.major_version == 1, Error::InvalidVersion);
+            require!(message.minor_version == 0, Error::InvalidVersion);
+            Ok(message)
+        }
+    }
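+
+    // A minimal usage sketch (illustrative only, not part of the API): assuming a
+    // `proof` value and this crate's big-endian `ser::to_vec`, an update should
+    // round-trip through `try_from_slice`:
+    //
+    //     let update = AccumulatorUpdateData::new(proof);
+    //     let bytes = crate::ser::to_vec::<_, byteorder::BE>(&update).unwrap();
+    //     let parsed = AccumulatorUpdateData::try_from_slice(&bytes).unwrap();
+    //     assert_eq!(update, parsed);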
+
+    // A hash of some data.
+    pub type Hash = [u8; 20];
+
+    #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
+    pub enum Proof {
+        WormholeMerkle {
+            vaa: PrefixedVec<u16, u8>,
+            updates: Vec<MerklePriceUpdate>,
+        },
+    }
+
+    #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
+    pub struct MerklePriceUpdate {
+        pub message: PrefixedVec<u16, u8>,
+        pub proof: MerklePath<Keccak160>,
+    }
+
+    #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
+    pub struct WormholeMessage {
+        pub magic: [u8; 4],
+        pub payload: WormholePayload,
+    }
+
+    impl WormholeMessage {
+        pub fn try_from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, Error> {
+            let message = from_slice::<byteorder::BE, Self>(bytes.as_ref()).unwrap();
+            require!(&message.magic[..] == b"AUWV", Error::InvalidMagic);
+            Ok(message)
+        }
+    }
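+
+    // A hypothetical parse of a VAA payload (sketch; `vaa_payload` stands in for
+    // the raw bytes carried by a Wormhole VAA for this accumulator):
+    //
+    //     let message = WormholeMessage::try_from_bytes(vaa_payload).unwrap();
+    //     let WormholePayload::Merkle(root) = message.payload;
+    //     assert_eq!(root.root.len(), 20);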
state.serialize_field("d", &self.d)?; +//! state.serialize_field("e", &self.e)?; +//! state.end() +//! } +//! } +//! ``` +//! +//! Note that any parser can be passed in, which gives the serializer the ability to serialize to +//! any format we desire as long as there is a `Serializer` implementation for it. With aggressive +//! inlining, the compiler will be able to optimize away the intermediate state objects and calls +//! to `serialize_field` and `serialize_*_variant` and the final result of our parser will have +//! very close to equivelent performance to a hand written implementation. +//! +//! The Pyth Serialization Format +//! ================================================================================ +//! +//! Pyth has various data formats that are serialized in compact forms to make storing or passing +//! cross-chain cheaper. So far all these formats follow a similar pattern and so this serializer +//! is designed to be able to provide a canonical implementation for all of them. +//! +//! +//! Format Spec: +//! -------------------------------------------------------------------------------- +//! +//! Integers: +//! +//! - `{u,i}8` are serialized as a single byte +//! - `{u,i}16/32/64` are serialized as bytes specified by the parser endianess type param. +//! - `{u,i}128` is not supported. +//! +//! Floats: +//! +//! - `f32/64/128` are not supported due to different chains having different float formats. +//! +//! Strings: +//! +//! - `&str` is serialized as a u8 length followed by the bytes of the string. +//! - `String` is serialized as a u8 length followed by the bytes of the string. +//! +//! Sequences: +//! +//! - `Vec` is serialized as a u8 length followed by the serialized elements of the vector. +//! - `&[T]` is serialized as a u8 length followed by the serialized elements of the slice. +//! +//! Enums: +//! +//! - `enum` is serialized as a u8 variant index followed by the serialized variant data. +//! - `Option` is serialized as a u8 variant index followed by the serialized variant data. +//! +//! Structs: +//! +//! - `struct` is serialized as the serialized fields of the struct in order. +//! +//! Tuples: +//! +//! - `tuple` is serialized as the serialized elements of the tuple in order. +//! +//! Unit: +//! +//! - `()` is serialized as nothing. +//! +//! +//! Example Usage +//! -------------------------------------------------------------------------------- +//! +//! ```rust,ignore +//! fn example(data: &[u8]) { +//! let mut buf = Vec::new(); +//! let mut cur = Cursor::new(&mut buf); +//! let mut des = Deserializer::new(&mut cur); +//! let mut result = des.deserialize::(data).unwrap(); +//! ... +//! } +//! 
+
+use {
+    byteorder::{
+        ByteOrder,
+        WriteBytesExt,
+    },
+    serde::{
+        ser::{
+            SerializeMap,
+            SerializeSeq,
+            SerializeStruct,
+            SerializeStructVariant,
+            SerializeTuple,
+            SerializeTupleStruct,
+            SerializeTupleVariant,
+        },
+        Serialize,
+    },
+    std::{
+        fmt::Display,
+        io::Write,
+    },
+    thiserror::Error,
+};
+
+pub fn to_vec<T, B>(value: &T) -> Result<Vec<u8>, SerializeError>
+where
+    T: Serialize,
+    B: ByteOrder,
+{
+    let mut buf = Vec::new();
+    value.serialize(&mut Serializer::<_, B>::new(&mut buf))?;
+    Ok(buf)
+}
+
+#[derive(Debug, Error)]
+pub enum SerializeError {
+    #[error("io error: {0}")]
+    Io(#[from] std::io::Error),
+
+    #[error("this type is not supported")]
+    Unsupported,
+
+    #[error("sequence too large ({0} elements), max supported is 255")]
+    SequenceTooLarge(usize),
+
+    #[error("sequence length must be known before serializing")]
+    SequenceLengthUnknown,
+
+    #[error("enum variant {0}::{1} cannot be parsed as `u8`: {2}")]
+    InvalidEnumVariant(&'static str, u32, &'static str),
+
+    #[error("message: {0}")]
+    Message(Box<str>),
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, PartialOrd, serde::Deserialize)]
+pub struct PrefixedVec<L, T> {
+    data: PrefixlessVec<T>,
+    __phantom: std::marker::PhantomData<L>,
+}
+
+impl<L, T> From<Vec<T>> for PrefixedVec<L, T> {
+    fn from(data: Vec<T>) -> Self {
+        Self {
+            data: PrefixlessVec { data },
+            __phantom: std::marker::PhantomData,
+        }
+    }
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, PartialOrd, serde::Deserialize)]
+struct PrefixlessVec<T> {
+    data: Vec<T>,
+}
+
+impl<L, T> Serialize for PrefixedVec<L, T>
+where
+    T: Serialize,
+    L: Serialize,
+    L: TryFrom<usize>,
+    <L as TryFrom<usize>>::Error: std::fmt::Debug,
+{
+    #[inline]
+    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        let len: L = L::try_from(self.data.data.len()).unwrap();
+        let mut st = serializer.serialize_struct("SizedVec", 1)?;
+        st.serialize_field("len", &len)?;
+        st.serialize_field("data", &self.data)?;
+        st.end()
+    }
+}
+
+impl<T> Serialize for PrefixlessVec<T>
+where
+    T: Serialize,
+{
+    #[inline]
+    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        let mut seq = serializer.serialize_seq(None)?;
+        for item in &self.data {
+            seq.serialize_element(item)?;
+        }
+        seq.end()
+    }
+}
+
+/// A type for Pyth's common serialization format. Note that a ByteOrder type param is required as
+/// we serialize in both big and little endian depending on different use-cases.
+#[derive(Clone)]
+pub struct Serializer<W: Write, B: ByteOrder> {
+    writer: W,
+    _endian: std::marker::PhantomData<B>,
+}
+
+impl serde::ser::Error for SerializeError {
+    fn custom<T: Display>(msg: T) -> Self {
+        SerializeError::Message(msg.to_string().into_boxed_str())
+    }
+}
+
+impl<W: Write, B: ByteOrder> Serializer<W, B> {
+    pub fn new(writer: W) -> Self {
+        Self {
+            writer,
+            _endian: std::marker::PhantomData,
+        }
+    }
+}
+
+impl<'a, W: Write, B: ByteOrder> serde::Serializer for &'a mut Serializer<W, B> {
+    type Ok = ();
+    type Error = SerializeError;
+
+    // Serde uses different types for different parse targets to allow for different
+    // implementations. We only support one target, so we can set all these to `Self`
+    // and implement those traits on the same type.
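+    //
+    // Concretely: a call such as `serializer.serialize_struct(..)` just hands back this
+    // same `&mut Serializer`, and the `serialize_field`/`serialize_element` calls that
+    // follow forward each value straight through it, so compound types are written
+    // back-to-back with no framing beyond the length/variant prefixes described above.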
+ type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + #[inline] + fn serialize_bool(self, v: bool) -> Result { + self.writer + .write_all(&[v as u8]) + .map_err(SerializeError::from) + } + + #[inline] + fn serialize_i8(self, v: i8) -> Result { + self.writer + .write_all(&[v as u8]) + .map_err(SerializeError::from) + } + + #[inline] + fn serialize_i16(self, v: i16) -> Result { + self.writer.write_i16::(v).map_err(SerializeError::from) + } + + #[inline] + fn serialize_i32(self, v: i32) -> Result { + self.writer.write_i32::(v).map_err(SerializeError::from) + } + + #[inline] + fn serialize_i64(self, v: i64) -> Result { + self.writer.write_i64::(v).map_err(SerializeError::from) + } + + #[inline] + fn serialize_u8(self, v: u8) -> Result { + self.writer.write_all(&[v]).map_err(SerializeError::from) + } + + #[inline] + fn serialize_u16(self, v: u16) -> Result { + self.writer.write_u16::(v).map_err(SerializeError::from) + } + + #[inline] + fn serialize_u32(self, v: u32) -> Result { + self.writer.write_u32::(v).map_err(SerializeError::from) + } + + #[inline] + fn serialize_u64(self, v: u64) -> Result { + self.writer.write_u64::(v).map_err(SerializeError::from) + } + + #[inline] + fn serialize_f32(self, _: f32) -> Result { + Err(SerializeError::Unsupported) + } + + #[inline] + fn serialize_f64(self, _: f64) -> Result { + Err(SerializeError::Unsupported) + } + + #[inline] + fn serialize_char(self, _: char) -> Result { + Err(SerializeError::Unsupported) + } + + #[inline] + fn serialize_str(self, v: &str) -> Result { + let len = u8::try_from(v.len()).map_err(|_| SerializeError::SequenceTooLarge(v.len()))?; + self.writer.write_all(&[len])?; + self.writer + .write_all(v.as_bytes()) + .map_err(SerializeError::from) + } + + #[inline] + fn serialize_bytes(self, v: &[u8]) -> Result { + let len = u8::try_from(v.len()).map_err(|_| SerializeError::SequenceTooLarge(v.len()))?; + self.writer.write_all(&[len])?; + self.writer.write_all(v).map_err(SerializeError::from) + } + + #[inline] + fn serialize_none(self) -> Result { + Err(SerializeError::Unsupported) + } + + #[inline] + fn serialize_some(self, value: &T) -> Result { + value.serialize(self) + } + + #[inline] + fn serialize_unit(self) -> Result { + Ok(()) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Ok(()) + } + + #[inline] + fn serialize_unit_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + ) -> Result { + let variant: u8 = variant_index + .try_into() + .map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?; + + self.writer + .write_all(&[variant]) + .map_err(SerializeError::from) + } + + #[inline] + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result { + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result { + let variant: u8 = variant_index + .try_into() + .map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?; + + self.writer.write_all(&[variant])?; + value.serialize(self) + } + + #[inline] + fn serialize_seq(self, len: Option) -> Result { + if let Some(len) = len { + let len = u8::try_from(len).map_err(|_| SerializeError::SequenceTooLarge(len))?; + self.writer.write_all(&[len])?; + } + + Ok(self) + 
} + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result { + Ok(self) + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(self) + } + + #[inline] + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + let variant: u8 = variant_index + .try_into() + .map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?; + + self.writer.write_all(&[variant])?; + Ok(self) + } + + #[inline] + fn serialize_map(self, len: Option) -> Result { + let len = len + .ok_or(SerializeError::SequenceLengthUnknown) + .and_then(|len| u8::try_from(len).map_err(|_| SerializeError::SequenceTooLarge(len)))?; + + self.writer.write_all(&[len])?; + Ok(self) + } + + #[inline] + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(self) + } + + #[inline] + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + let variant: u8 = variant_index + .try_into() + .map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?; + + self.writer.write_all(&[variant])?; + Ok(self) + } + + fn is_human_readable(&self) -> bool { + false + } + + fn collect_str(self, value: &T) -> Result { + self.serialize_str(&value.to_string()) + } +} + +impl<'a, W: Write, B: ByteOrder> SerializeSeq for &'a mut Serializer { + type Ok = (); + type Error = SerializeError; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { + value.serialize(&mut **self) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, B: ByteOrder> SerializeTuple for &'a mut Serializer { + type Ok = (); + type Error = SerializeError; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> { + value.serialize(&mut **self) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, B: ByteOrder> SerializeTupleStruct for &'a mut Serializer { + type Ok = (); + type Error = SerializeError; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> { + value.serialize(&mut **self) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, B: ByteOrder> SerializeTupleVariant for &'a mut Serializer { + type Ok = (); + type Error = SerializeError; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> { + value.serialize(&mut **self) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, B: ByteOrder> SerializeMap for &'a mut Serializer { + type Ok = (); + type Error = SerializeError; + + #[inline] + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> { + key.serialize(&mut **self) + } + + #[inline] + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> { + value.serialize(&mut **self) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, B: ByteOrder> SerializeStruct for &'a mut Serializer { + type Ok = (); + type Error = SerializeError; + + #[inline] + fn serialize_field( + &mut self, + _key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + value.serialize(&mut **self) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, B: ByteOrder> SerializeStructVariant for &'a mut Serializer { + type Ok = (); + type Error = SerializeError; + + #[inline] + fn serialize_field( + &mut self, + _key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { 
+ value.serialize(&mut **self) + } + + fn end(self) -> Result { + Ok(()) + } +} + +// By default, serde does not know how to parse fixed length arrays of sizes +// that aren't common (I.E: 32) Here we provide a module that can be used to +// serialize arrays that relies on const generics. +// +// Usage: +// +// ```rust,ignore` +// #[derive(Serialize)] +// struct Example { +// #[serde(with = "array")] +// array: [u8; 55], +// } +// ``` +pub mod array { + use std::mem::MaybeUninit; + + /// Serialize an array of size N using a const generic parameter to drive serialize_seq. + pub fn serialize(array: &[T; N], serializer: S) -> Result + where + S: serde::Serializer, + T: serde::Serialize, + { + use serde::ser::SerializeTuple; + let mut seq = serializer.serialize_tuple(N)?; + array.iter().try_for_each(|e| seq.serialize_element(e))?; + seq.end() + } + + /// A Marer type that carries type-level information about the length of the + /// array we want to deserialize. + struct ArrayVisitor { + _marker: std::marker::PhantomData, + } + + /// Implement a Visitor over our ArrayVisitor that knows how many times to + /// call next_element using the generic. + impl<'de, T, const N: usize> serde::de::Visitor<'de> for ArrayVisitor + where + T: serde::de::Deserialize<'de>, + { + type Value = [T; N]; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "an array of length {}", N) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + // We use MaybeUninit to allocate the right amount of memory + // because we do not know if `T` has a constructor or a default. + // Without this we would have to allocate a Vec. + let mut array = MaybeUninit::<[T; N]>::uninit(); + let ptr = array.as_mut_ptr() as *mut T; + let mut pos = 0; + while pos < N { + let next = seq + .next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(pos, &self))?; + + unsafe { + std::ptr::write(ptr.add(pos), next); + } + + pos += 1; + } + + // We only succeed if we fully filled the array. This prevents + // accidentally returning garbage. + if pos == N { + return Ok(unsafe { array.assume_init() }); + } + + Err(serde::de::Error::invalid_length(pos, &self)) + } + } + + /// Deserialize an array with an ArrayVisitor aware of `N` during deserialize. + pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error> + where + D: serde::Deserializer<'de>, + T: serde::de::Deserialize<'de>, + { + deserializer.deserialize_tuple( + N, + ArrayVisitor { + _marker: std::marker::PhantomData, + }, + ) + } +} + +#[cfg(test)] +mod tests { + use super::{ + super::de::Deserializer, + array, + Serializer, + }; + + // Test the arbitrary fixed sized array serialization implementation. + #[test] + fn test_array_serde() { + // Serialize an array into a buffer. + let mut buffer = Vec::new(); + let mut cursor = std::io::Cursor::new(&mut buffer); + let mut serializer: Serializer<_, byteorder::LE> = Serializer::new(&mut cursor); + array::serialize(&[1u8; 37], &mut serializer).unwrap(); + + // The result should not have been prefixed with a length byte. + assert_eq!(buffer.len(), 37); + + // We should also be able to deserialize it back. + let mut deserializer = Deserializer::::new(&buffer); + let deserialized: [u8; 37] = array::deserialize(&mut deserializer).unwrap(); + + // The deserialized array should be the same as the original. + assert_eq!(deserialized, [1u8; 37]); + } + + // The array serializer should not interfere with other serializers. 
Here we + // check serde_json to make sure an array is written as expected. + #[test] + fn test_array_serde_json() { + // Serialize an array into a buffer. + let mut buffer = Vec::new(); + let mut cursor = std::io::Cursor::new(&mut buffer); + let mut serialized = serde_json::Serializer::new(&mut cursor); + array::serialize(&[1u8; 7], &mut serialized).unwrap(); + let result = String::from_utf8(buffer).unwrap(); + assert_eq!(result, "[1,1,1,1,1,1,1]"); + + // Deserializing should also work. + let mut deserializer = serde_json::Deserializer::from_str(&result); + let deserialized: [u8; 7] = array::deserialize(&mut deserializer).unwrap(); + assert_eq!(deserialized, [1u8; 7]); + } + + // Golden Structure Test + // + // This test serializes a struct containing all the expected types we should + // be able to handle and checks the output is as expected. The reason I + // opted to serialize all in one struct instead of with separate tests is to + // ensure that the positioning of elements when in relation to others is + // also as expected. Especially when it comes to things such as nesting and + // length prefixing. + #[test] + fn test_pyth_serde() { + use serde::Serialize; + + // Setup Serializer. + let mut buffer = Vec::new(); + let mut cursor = std::io::Cursor::new(&mut buffer); + let mut serializer: Serializer<_, byteorder::LE> = Serializer::new(&mut cursor); + + // Golden Test Value. As binary data can be fickle to understand in + // tests this should be kept commented with detail. + #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)] + struct GoldenStruct<'a> { + // Test `unit` is not serialized to anything. + unit: (), + + // Test `bool` is serialized to a single byte. + t_bool: bool, + + // Test integer serializations. + t_u8: u8, + t_u16: u16, + t_u32: u32, + t_u64: u64, + + // Test `str` is serialized to a variable length array. + t_string: String, + t_str: &'a str, + + // Test `Vec` is serialized to a variable length array. + t_vec: Vec, + t_vec_empty: Vec, + t_vec_nested: Vec>, + t_vec_nested_empty: Vec>, + t_slice: &'a [u8], + t_slice_empty: &'a [u8], + + // Test tuples serialize as expected. + t_tuple: (u8, u16, u32, u64, String, Vec, &'a [u8]), + t_tuple_nested: ((u8, u16), (u32, u64)), + + // Test enum serializations. + t_enum_unit: GoldenEnum, + t_enum_newtype: GoldenEnum, + t_enum_tuple: GoldenEnum, + t_enum_struct: GoldenEnum, + } + + #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)] + enum GoldenEnum { + UnitVariant, + NewtypeVariant(u8), + TupleVariant(u8, u16), + StructVariant { a: u8, b: u16 }, + } + + // Serialize the golden test value. + let golden_struct = GoldenStruct { + unit: (), + t_bool: true, + t_u8: 1, + t_u16: 2, + t_u32: 3, + t_u64: 4, + t_string: "9".to_string(), + t_str: "10", + t_vec: vec![11, 12, 13], + t_vec_empty: vec![], + t_vec_nested: vec![vec![14, 15, 16], vec![17, 18, 19]], + t_vec_nested_empty: vec![vec![], vec![]], + t_slice: &[20, 21, 22], + t_slice_empty: &[], + t_tuple: ( + 29, + 30, + 31, + 32, + "10".to_string(), + vec![35, 36, 37], + &[38, 39, 40], + ), + t_tuple_nested: ((41, 42), (43, 44)), + t_enum_unit: GoldenEnum::UnitVariant, + t_enum_newtype: GoldenEnum::NewtypeVariant(45), + t_enum_tuple: GoldenEnum::TupleVariant(46, 47), + t_enum_struct: GoldenEnum::StructVariant { a: 48, b: 49 }, + }; + + golden_struct.serialize(&mut serializer).unwrap(); + + // The serialized output should be as expected. 
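+        // Reading guide for the bytes below: every str/String/Vec/slice is prefixed
+        // with a u8 length, every enum value with its u8 variant index, and multi-byte
+        // integers are little-endian because the serializer above was constructed
+        // with `byteorder::LE`.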
+        assert_eq!(
+            &buffer,
+            &[
+                1, // t_bool
+                1, // t_u8
+                2, 0, // t_u16
+                3, 0, 0, 0, // t_u32
+                4, 0, 0, 0, 0, 0, 0, 0, // t_u64
+                1, 57, // t_string
+                2, 49, 48, // t_str
+                3, 11, 12, 13, // t_vec
+                0, // t_vec_empty
+                2, 3, 14, 15, 16, 3, 17, 18, 19, // t_vec_nested
+                2, 0, 0, // t_vec_nested_empty
+                3, 20, 21, 22, // t_slice
+                0, // t_slice_empty
+                29, // t_tuple: u8
+                30, 0, // u16
+                31, 0, 0, 0, // u32
+                32, 0, 0, 0, 0, 0, 0, 0, // u64
+                2, 49, 48, // "10"
+                3, 35, 36, 37, // [35, 36, 37]
+                3, 38, 39, 40, // [38, 39, 40]
+                41, 42, 0, 43, 0, 0, 0, 44, 0, 0, 0, 0, 0, 0, 0, // t_tuple_nested
+                0, // t_enum_unit
+                1, 45, // t_enum_newtype
+                2, 46, 47, 0, // t_enum_tuple
+                3, 48, 49, 0, // t_enum_struct
+            ]
+        );
+    }
+}
diff --git a/pythnet/pythnet_sdk/src/wormhole.rs b/pythnet/pythnet_sdk/src/wormhole.rs
index fbb0e403..48012035 100644
--- a/pythnet/pythnet_sdk/src/wormhole.rs
+++ b/pythnet/pythnet_sdk/src/wormhole.rs
@@ -1,3 +1,8 @@
+//! This module provides Wormhole primitives.
+//!
+//! Wormhole does not provide an SDK for working with Solana versions of Wormhole related types, so
+//! we clone the definitions from the Solana contracts here and adapt them to Pyth purposes. This
+//! allows us to emit and parse messages through Wormhole.
 use {
     crate::Pubkey,
     borsh::{