feat: add serializers for pyth formats

feat: use pythnet serialization in hermes

Fix vaa validation

Clippy

Update config names

Wrap Store with Arc

Store works perfectly without Arc as all it's elements are behind an Arc
or something similar to that, however a developer might make
mistake to add a field and missing it.

Improve error handling

Update metadata struct

Add metadata

Update Eth listener

Pin wormhole to a version

Fix ws dispatcher

fix: blocking in go recv corrupts tokio runtime

Make network <> store message passing non-blocking

Update logs and revert debug changes
This commit is contained in:
Reisen 2023-05-25 17:07:48 +02:00 committed by Reisen
parent 8aeef6e6bd
commit a1dff0f5ac
30 changed files with 2058 additions and 366 deletions

2
hermes/.gitignore vendored
View File

@ -7,4 +7,4 @@ src/network/p2p.proto
tools/
# Ignore Wormhole cloned repo
wormhole/
wormhole*/

4
hermes/Cargo.lock generated
View File

@ -2010,6 +2010,7 @@ dependencies = [
"ethabi",
"futures",
"hex",
"humantime",
"lazy_static",
"libc",
"libp2p",
@ -2030,6 +2031,7 @@ dependencies = [
"serde_qs",
"serde_wormhole",
"sha256",
"sha3 0.10.8",
"solana-account-decoder",
"solana-client",
"solana-sdk",
@ -4216,6 +4218,7 @@ dependencies = [
"bincode",
"borsh",
"bytemuck",
"byteorder",
"fast-math",
"hex",
"rustc_version 0.4.0",
@ -4223,6 +4226,7 @@ dependencies = [
"serde_wormhole",
"sha3 0.10.8",
"slow_primes",
"thiserror",
"wormhole-sdk",
]

View File

@ -69,6 +69,8 @@ pyth-oracle = { git = "https://github.com/pyth-network/pyth-client", rev = "7d59
strum = { version = "0.24", features = ["derive"] }
ethabi = { version = "18.0.0", features = ["serde"] }
sha3 = "0.10.4"
humantime = "2.1.0"
[patch.crates-io]
serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhole", tag = "v2.17.1" }

View File

@ -1,63 +1,79 @@
use std::{
env,
path::PathBuf,
process::Command,
process::{
Command,
Stdio,
},
};
fn main() {
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let out_var = env::var("OUT_DIR").unwrap();
// Clone the Wormhole repository, which we need to access the protobuf definitions for Wormhole
// P2P message types.
// Download the Wormhole repository at a certain tag, which we need to access the protobuf definitions
// for Wormhole P2P message types.
//
// TODO: This is ugly and costly, and requires git. Instead of this we should have our own tool
// TODO: This is ugly. Instead of this we should have our own tool
// build process that can generate protobuf definitions for this and other user cases. For now
// this is easy and works and matches upstream Wormhole's `Makefile`.
let _ = Command::new("git")
const WORMHOLE_VERSION: &str = "2.18.1";
let wh_curl = Command::new("curl")
.args([
"clone",
"https://github.com/wormhole-foundation/wormhole",
"wormhole",
"-s",
"-L",
format!("https://github.com/wormhole-foundation/wormhole/archive/refs/tags/v{WORMHOLE_VERSION}.tar.gz").as_str(),
])
.stdout(Stdio::piped())
.spawn()
.expect("failed to download wormhole archive");
let _ = Command::new("tar")
.args(["xvz"])
.stdin(Stdio::from(wh_curl.stdout.unwrap()))
.output()
.expect("failed to execute process");
.expect("failed to extract wormhole archive");
// Move the tools directory to the root of the repo because that's where the build script
// expects it to be, paths get hardcoded into the binaries.
let _ = Command::new("mv")
.args(["wormhole/tools", "tools"])
.args([
format!("wormhole-{WORMHOLE_VERSION}/tools").as_str(),
"tools",
])
.output()
.expect("failed to execute process");
.expect("failed to move wormhole tools directory");
// Move the protobuf definitions to the src/network directory, we don't have to do this
// but it is more intuitive when debugging.
let _ = Command::new("mv")
.args([
"wormhole/proto/gossip/v1/gossip.proto",
format!("wormhole-{WORMHOLE_VERSION}/proto/gossip/v1/gossip.proto").as_str(),
"src/network/p2p.proto",
])
.output()
.expect("failed to execute process");
.expect("failed to move wormhole protobuf definitions");
// Build the protobuf compiler.
let _ = Command::new("./build.sh")
.current_dir("tools")
.output()
.expect("failed to execute process");
.expect("failed to run protobuf compiler build script");
// Make the protobuf compiler executable.
let _ = Command::new("chmod")
.args(["+x", "tools/bin/*"])
.output()
.expect("failed to execute process");
.expect("failed to make protofuf compiler executable");
// Generate the protobuf definitions. See buf.gen.yaml to see how we rename the module for our
// particular use case.
let _ = Command::new("./tools/bin/buf")
.args(["generate", "--path", "src"])
.output()
.expect("failed to execute process");
.expect("failed to generate protobuf definitions");
// Build the Go library.
let mut cmd = Command::new("go");

View File

@ -7,7 +7,7 @@ with pkgs; mkShell {
clang
llvmPackages.libclang
nettle
openssl
openssl_1_1
pkgconfig
rustup
systemd

View File

@ -7,6 +7,10 @@ use {
Router,
},
std::sync::Arc,
tokio::{
signal,
sync::mpsc::Receiver,
},
};
mod rest;
@ -15,12 +19,12 @@ mod ws;
#[derive(Clone)]
pub struct State {
pub store: Store,
pub store: Arc<Store>,
pub ws: Arc<ws::WsState>,
}
impl State {
pub fn new(store: Store) -> Self {
pub fn new(store: Arc<Store>) -> Self {
Self {
store,
ws: Arc::new(ws::WsState::new()),
@ -32,7 +36,7 @@ impl State {
///
/// Currently this is based on Axum due to the simplicity and strong ecosystem support for the
/// packages they are based on (tokio & hyper).
pub async fn spawn(store: Store, rpc_addr: String) -> Result<()> {
pub async fn run(store: Arc<Store>, mut update_rx: Receiver<()>, rpc_addr: String) -> Result<()> {
let state = State::new(store);
// Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
@ -50,28 +54,31 @@ pub async fn spawn(store: Store, rpc_addr: String) -> Result<()> {
.with_state(state.clone());
// Binds the axum's server to the configured address and port. This is a blocking call and will
// not return until the server is shutdown.
tokio::spawn(async move {
// FIXME handle errors properly
axum::Server::bind(&rpc_addr.parse().unwrap())
.serve(app.into_make_service())
.await
.unwrap();
});
// Call dispatch updates to websocket every 1 seconds
// FIXME use a channel to get updates from the store
tokio::spawn(async move {
loop {
dispatch_updates(
state.store.get_price_feed_ids().into_iter().collect(),
state.clone(),
)
.await;
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
// Panics if the update channel is closed, which should never happen.
// If it happens we have no way to recover, so we just panic.
update_rx
.recv()
.await
.expect("state update channel is closed");
dispatch_updates(state.clone()).await;
}
});
// Binds the axum's server to the configured address and port. This is a blocking call and will
// not return until the server is shutdown.
axum::Server::try_bind(&rpc_addr.parse()?)?
.serve(app.into_make_service())
.with_graceful_shutdown(async {
signal::ctrl_c()
.await
.expect("Ctrl-c signal handler failed.");
})
.await?;
Ok(())
}

View File

@ -36,6 +36,7 @@ use {
pub enum RestError {
UpdateDataNotFound,
CcipUpdateDataNotFound,
InvalidCCIPInput,
}
impl IntoResponse for RestError {
@ -54,6 +55,9 @@ impl IntoResponse for RestError {
(StatusCode::BAD_GATEWAY, "CCIP update data not found").into_response()
}
RestError::InvalidCCIPInput => {
(StatusCode::BAD_REQUEST, "Invalid CCIP input").into_response()
}
}
}
}
@ -113,7 +117,7 @@ pub async fn latest_price_feeds(
.price_feeds
.into_iter()
.map(|price_feed| {
RpcPriceFeed::from_price_feed_message(price_feed, params.verbose, params.binary)
RpcPriceFeed::from_price_feed_update(price_feed, params.verbose, params.binary)
})
.collect(),
))
@ -156,7 +160,8 @@ pub async fn get_vaa(
.price_feeds
.get(0)
.ok_or(RestError::UpdateDataNotFound)?
.publish_time; // TODO: This should never happen.
.price_feed
.publish_time;
Ok(Json(GetVaaResponse { vaa, publish_time }))
}
@ -179,8 +184,16 @@ pub async fn get_vaa_ccip(
State(state): State<super::State>,
QsQuery(params): QsQuery<GetVaaCcipQueryParams>,
) -> Result<Json<GetVaaCcipResponse>, RestError> {
let price_id: PriceIdentifier = PriceIdentifier::new(params.data[0..32].try_into().unwrap());
let publish_time = UnixTimestamp::from_be_bytes(params.data[32..40].try_into().unwrap());
let price_id: PriceIdentifier = PriceIdentifier::new(
params.data[0..32]
.try_into()
.map_err(|_| RestError::InvalidCCIPInput)?,
);
let publish_time = UnixTimestamp::from_be_bytes(
params.data[32..40]
.try_into()
.map_err(|_| RestError::InvalidCCIPInput)?,
);
let price_feeds_with_update_data = state
.store

View File

@ -1,17 +1,25 @@
use {
crate::{
impl_deserialize_for_hex_string_wrapper,
store::types::UnixTimestamp,
store::types::{
PriceFeedUpdate,
Slot,
UnixTimestamp,
},
},
base64::{
engine::general_purpose::STANDARD as base64_standard_engine,
Engine as _,
},
derive_more::{
Deref,
DerefMut,
},
pyth_oracle::PriceFeedMessage,
pyth_sdk::{
Price,
PriceIdentifier,
},
wormhole_sdk::Chain,
};
@ -34,8 +42,8 @@ type Base64String = String;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RpcPriceFeedMetadata {
pub slot: Slot,
pub emitter_chain: u16,
pub sequence_number: u64,
pub price_service_receive_time: UnixTimestamp,
}
@ -54,11 +62,13 @@ pub struct RpcPriceFeed {
impl RpcPriceFeed {
// TODO: Use a Encoding type to have None, Base64, and Hex variants instead of binary flag.
// TODO: Use a Verbosity type to define None, or Full instead of verbose flag.
pub fn from_price_feed_message(
price_feed_message: PriceFeedMessage,
_verbose: bool,
_binary: bool,
pub fn from_price_feed_update(
price_feed_update: PriceFeedUpdate,
verbose: bool,
binary: bool,
) -> Self {
let price_feed_message = price_feed_update.price_feed;
Self {
id: PriceIdentifier::new(price_feed_message.id),
price: Price {
@ -73,16 +83,14 @@ impl RpcPriceFeed {
expo: price_feed_message.exponent,
publish_time: price_feed_message.publish_time,
},
// FIXME: Handle verbose flag properly.
// metadata: verbose.then_some(RpcPriceFeedMetadata {
// emitter_chain: price_feed_message.emitter_chain,
// sequence_number: price_feed_message.sequence_number,
// price_service_receive_time: price_feed_message.receive_time,
// }),
metadata: None,
// FIXME: The vaa is wrong, fix it
// vaa: binary.then_some(base64_standard_engine.encode(message_state.proof_set.wormhole_merkle_proof.vaa)),
vaa: None,
metadata: verbose.then_some(RpcPriceFeedMetadata {
emitter_chain: Chain::Pythnet.into(),
price_service_receive_time: price_feed_update.received_at,
slot: price_feed_update.slot,
}),
vaa: binary.then_some(
base64_standard_engine.encode(price_feed_update.wormhole_merkle_update_data),
),
}
}
}

View File

@ -7,7 +7,10 @@ use {
types::RequestTime,
Store,
},
anyhow::Result,
anyhow::{
anyhow,
Result,
},
axum::{
extract::{
ws::{
@ -23,6 +26,7 @@ use {
futures::{
future::join_all,
stream::{
iter,
SplitSink,
SplitStream,
},
@ -37,9 +41,12 @@ use {
std::{
collections::HashMap,
pin::Pin,
sync::atomic::{
AtomicUsize,
Ordering,
sync::{
atomic::{
AtomicUsize,
Ordering,
},
Arc,
},
time::Duration,
},
@ -63,7 +70,7 @@ async fn websocket_handler(stream: WebSocket, state: super::State) {
// TODO: Use a configured value for the buffer size or make it const static
// TODO: Use redis stream to source the updates instead of a channel
let (tx, rx) = mpsc::channel::<Vec<PriceIdentifier>>(1000);
let (tx, rx) = mpsc::channel::<()>(1000);
ws_state.subscribers.insert(id, tx);
@ -81,8 +88,8 @@ pub type SubscriberId = usize;
pub struct Subscriber {
id: SubscriberId,
closed: bool,
store: Store,
update_rx: mpsc::Receiver<Vec<PriceIdentifier>>,
store: Arc<Store>,
update_rx: mpsc::Receiver<()>,
receiver: SplitStream<WebSocket>,
sender: SplitSink<WebSocket, Message>,
price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
@ -93,8 +100,8 @@ pub struct Subscriber {
impl Subscriber {
pub fn new(
id: SubscriberId,
store: Store,
update_rx: mpsc::Receiver<Vec<PriceIdentifier>>,
store: Arc<Store>,
update_rx: mpsc::Receiver<()>,
receiver: SplitStream<WebSocket>,
sender: SplitSink<WebSocket, Message>,
) -> Self {
@ -114,7 +121,7 @@ impl Subscriber {
pub async fn run(&mut self) {
while !self.closed {
if let Err(e) = self.handle_next().await {
log::error!("Subscriber {}: Error handling next message: {}", self.id, e);
log::warn!("Subscriber {}: Error handling next message: {}", self.id, e);
break;
}
}
@ -122,11 +129,11 @@ impl Subscriber {
async fn handle_next(&mut self) -> Result<()> {
tokio::select! {
maybe_update_feed_ids = self.update_rx.recv() => {
let update_feed_ids = maybe_update_feed_ids.ok_or_else(|| {
anyhow::anyhow!("Update channel closed.")
})?;
self.handle_price_feeds_update(update_feed_ids).await?;
maybe_update_feeds = self.update_rx.recv() => {
if maybe_update_feeds.is_none() {
return Err(anyhow!("Update channel closed. This should never happen. Closing connection."));
};
self.handle_price_feeds_update().await?;
},
maybe_message_or_err = self.receiver.next() => {
match maybe_message_or_err {
@ -140,7 +147,7 @@ impl Subscriber {
},
_ = &mut self.ping_interval_future => {
if !self.responded_to_ping {
log::debug!("Subscriber {} did not respond to ping, closing connection.", self.id);
log::debug!("Subscriber {} did not respond to ping. Closing connection.", self.id);
self.closed = true;
return Ok(());
}
@ -153,37 +160,32 @@ impl Subscriber {
Ok(())
}
async fn handle_price_feeds_update(
&mut self,
price_feed_ids: Vec<PriceIdentifier>,
) -> Result<()> {
for price_feed_id in price_feed_ids {
if let Some(config) = self.price_feeds_with_config.get(&price_feed_id) {
async fn handle_price_feeds_update(&mut self) -> Result<()> {
let messages = self
.price_feeds_with_config
.iter()
.map(|(price_feed_id, config)| {
let price_feeds_with_update_data = self
.store
.get_price_feeds_with_update_data(vec![price_feed_id], RequestTime::Latest)?;
let price_feed = *price_feeds_with_update_data
.get_price_feeds_with_update_data(vec![*price_feed_id], RequestTime::Latest)?;
let price_feed = price_feeds_with_update_data
.price_feeds
.iter()
.find(|price_feed| price_feed.id == price_feed_id.to_bytes())
.into_iter()
.next()
.ok_or_else(|| {
anyhow::anyhow!("Price feed {} not found.", price_feed_id.to_string())
})?;
let price_feed = RpcPriceFeed::from_price_feed_message(
price_feed,
config.verbose,
config.binary,
);
// Feed does not flush the message and will allow us
// to send multiple messages in a single flush.
self.sender
.feed(Message::Text(serde_json::to_string(
&ServerMessage::PriceUpdate { price_feed },
)?))
.await?;
}
}
self.sender.flush().await?;
let price_feed =
RpcPriceFeed::from_price_feed_update(price_feed, config.verbose, config.binary);
Ok(Message::Text(serde_json::to_string(
&ServerMessage::PriceUpdate { price_feed },
)?))
})
.collect::<Result<Vec<Message>>>()?;
self.sender
.send_all(&mut iter(messages.into_iter().map(Ok)))
.await?;
Ok(())
}
@ -253,16 +255,15 @@ impl Subscriber {
}
}
pub async fn dispatch_updates(update_feed_ids: Vec<PriceIdentifier>, state: super::State) {
pub async fn dispatch_updates(state: super::State) {
let ws_state = state.ws.clone();
let update_feed_ids_ref = &update_feed_ids;
let closed_subscribers: Vec<Option<SubscriberId>> = join_all(
ws_state
.subscribers
.iter_mut()
.map(|subscriber| async move {
match subscriber.send(update_feed_ids_ref.clone()).await {
match subscriber.send(()).await {
Ok(_) => None,
Err(e) => {
log::debug!("Error sending update to subscriber: {}", e);
@ -289,7 +290,7 @@ pub struct PriceFeedClientConfig {
pub struct WsState {
pub subscriber_counter: AtomicUsize,
pub subscribers: DashMap<SubscriberId, mpsc::Sender<Vec<PriceIdentifier>>>,
pub subscribers: DashMap<SubscriberId, mpsc::Sender<()>>,
}
impl WsState {

View File

@ -44,5 +44,21 @@ pub enum Options {
/// The address to bind the API server to.
#[structopt(long, default_value = "127.0.0.1:33999")]
api_addr: SocketAddr,
/// Ethereum RPC endpoint
#[structopt(long, env = "ETH_RPC_ENDPOINT")]
eth_rpc_endpoint: String,
/// Wormhole contract address on Ethereum
#[structopt(
long,
env = "WORMHOLE_CONTRACT_ADDRESS",
default_value = "0x98f3c9e6E3fAce36bAAd05FE09d375Ef1464288B"
)]
wormhole_eth_contract_address: String,
/// Ethereum RPC polling duration
#[structopt(long, env = "ETH_POLLING_INTERVAL", default_value = "10s")]
eth_polling_interval: humantime::Duration,
},
}

View File

@ -26,12 +26,15 @@ async fn init() -> Result<()> {
wh_bootstrap_addrs,
wh_listen_addrs,
api_addr,
eth_rpc_endpoint,
wormhole_eth_contract_address,
eth_polling_interval,
} => {
log::info!("Running Hermes...");
let store = Store::new_with_local_cache(1000);
// A channel to emit state updates to api
let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000);
// FIXME: Instead of spawing threads separately, we should handle all their
// errors properly.
log::info!("Running Hermes...");
let store = Store::new_with_local_cache(update_tx, 1000);
// Spawn the P2P layer.
log::info!("Starting P2P server on {:?}", wh_listen_addrs);
@ -44,26 +47,25 @@ async fn init() -> Result<()> {
.await?;
// Spawn the Ethereum guardian set watcher
log::info!(
"Starting Ethereum guardian set watcher using {}",
eth_rpc_endpoint
);
network::ethereum::spawn(
store.clone(),
"https://rpc.ankr.com/eth".to_owned(),
"0x98f3c9e6E3fAce36bAAd05FE09d375Ef1464288B".to_owned(),
eth_rpc_endpoint,
wormhole_eth_contract_address,
eth_polling_interval.into(),
)
.await;
// Spawn the RPC server.
log::info!("Starting RPC server on {}", api_addr);
// TODO: Add max size to the config
api::spawn(store.clone(), api_addr.to_string()).await?;
.await?;
// Spawn the Pythnet listener
// TODO: Exit the thread when it gets spawned
log::info!("Starting Pythnet listener using {}", pythnet_ws_endpoint);
network::pythnet::spawn(store.clone(), pythnet_ws_endpoint).await?;
// Wait on Ctrl+C similar to main.
tokio::signal::ctrl_c().await?;
// Run the RPC server and wait for it to shutdown gracefully.
log::info!("Starting RPC server on {}", api_addr);
api::run(store.clone(), update_rx, api_addr.to_string()).await?;
}
}

View File

@ -17,6 +17,11 @@ use {
},
ethabi::Function,
reqwest::Client,
std::{
sync::Arc,
time::Duration,
},
tokio::time::Instant,
wormhole_sdk::GuardianAddress,
};
@ -49,17 +54,25 @@ async fn query(
let res = res
.get("result")
.ok_or(anyhow!("Invalid RPC Response, 'result' not found"))?
.ok_or(anyhow!(
"Invalid RPC Response, 'result' not found. {:?}",
res
))?
.as_str()
.ok_or(anyhow!("Invalid result"))?;
.ok_or(anyhow!("Invalid result. {:?}", res))?;
let res = hex::decode(&res[2..]).unwrap();
let res = hex::decode(&res[2..])?;
let res = method.decode_output(&res)?;
Ok(res)
}
async fn run(store: Store, rpc_endpoint: String, wormhole_contract: String) -> Result<()> {
async fn run(
store: Arc<Store>,
rpc_endpoint: String,
wormhole_contract: String,
polling_interval: Duration,
) -> Result<!> {
loop {
let get_current_index_method = serde_json::from_str::<Function>(
r#"{"inputs":[],"name":"getCurrentGuardianSetIndex","outputs":[{"internalType":"uint32","name":"","type":"uint32"}],
@ -126,32 +139,49 @@ async fn run(store: Store, rpc_endpoint: String, wormhole_contract: String) -> R
log::info!("Guardian set: {:?}", guardian_set);
store
.update_guardian_set(
guardian_set
.into_iter()
.map(|address| GuardianAddress(address.0))
.collect(),
)
.await;
let store = store.clone();
tokio::spawn(async move {
store
.update_guardian_set(
guardian_set
.into_iter()
.map(|address| GuardianAddress(address.0))
.collect(),
)
.await;
});
tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
tokio::time::sleep(polling_interval).await;
}
}
pub async fn spawn(store: Store, rpc_endpoint: String, wormhole_contract: String) {
pub async fn spawn(
store: Arc<Store>,
rpc_endpoint: String,
wormhole_contract: String,
polling_interval: Duration,
) -> Result<()> {
tokio::spawn(async move {
loop {
if let Err(e) = run(
let current_time = Instant::now();
if let Err(ref e) = run(
store.clone(),
rpc_endpoint.clone(),
wormhole_contract.clone(),
polling_interval,
)
.await
{
log::error!("Error in ethereum network: {}", e);
// TODO: Add a backoff here.
log::error!("Error in Ethereum network listener: {:?}", e);
}
if current_time.elapsed() < Duration::from_secs(30) {
log::error!("Ethereum network listener is restarting too quickly. Sleeping for 1s");
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
});
Ok(())
}

View File

@ -28,7 +28,9 @@ import "C"
import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/crypto"
@ -44,6 +46,7 @@ import (
pubsub "github.com/libp2p/go-libp2p-pubsub"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic"
libp2pquicreuse "github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
)
//export RegisterObservationCallback
@ -52,11 +55,24 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, li
bootstrapAddrs := strings.Split(C.GoString(bootstrap_addrs), ",")
listenAddrs := strings.Split(C.GoString(listen_addrs), ",")
go func() {
var startTime int64
var recoverRerun func()
routine := func() {
defer recoverRerun()
// Record the current time
startTime = time.Now().UnixNano()
ctx := context.Background()
// Setup base network configuration.
priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1)
if err != nil {
err := fmt.Errorf("Failed to generate key pair: %w", err)
fmt.Println(err)
return
}
// Setup libp2p Connection Manager.
mgr, err := connmgr.NewConnManager(
@ -76,6 +92,28 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, li
libp2p.Identity(priv),
libp2p.ListenAddrStrings(listenAddrs...),
libp2p.Security(libp2ptls.ID, libp2ptls.New),
// Disable Reuse because upon panic, the Close() call on the p2p reactor does not properly clean up the
// open ports (they are kept around for re-use, this seems to be a libp2p bug in the reuse `gc()` call
// which can be found here:
//
// https://github.com/libp2p/go-libp2p/blob/master/p2p/transport/quicreuse/reuse.go#L97
//
// By disabling this we get correct Close() behaviour.
//
// IMPORTANT: Normally re-use allows libp2p to dial on the same port that is used to listen for traffic
// and by disabling this dialing uses a random high port (32768-60999) which causes the nodes that we
// connect to by dialing (instead of them connecting to us) will respond on the high range port instead
// of the specified Dial port. This requires firewalls to be configured to allow (UDP 32768-60999) which
// should be specified in our documentation.
//
// The best way to securely enable this range is via the conntrack module, which can statefully allow
// UDP packets only when a sent UDP packet is present in the conntrack table. This rule looks roughly
// like this:
//
// iptables -A INPUT -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT
//
// Which is a standard rule in many firewall configurations (RELATED is the key flag).
libp2p.QUICReuse(libp2pquicreuse.NewConnManager, libp2pquicreuse.DisableReuseport()),
libp2p.Transport(libp2pquic.NewTransport),
libp2p.ConnectionManager(mgr),
libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
@ -107,6 +145,8 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, li
return
}
defer h.Close()
topic := fmt.Sprintf("%s/%s", networkID, "broadcast")
ps, err := pubsub.NewGossipSub(ctx, h)
if err != nil {
@ -122,6 +162,8 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, li
return
}
defer th.Close()
sub, err := th.Subscribe()
if err != nil {
err := fmt.Errorf("Failed to subscribe topic: %w", err)
@ -129,6 +171,8 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, li
return
}
defer sub.Cancel()
for {
for {
select {
@ -160,7 +204,25 @@ func RegisterObservationCallback(f C.callback_t, network_id, bootstrap_addrs, li
}
}
}
}()
}
recoverRerun = func() {
// Print the error if any and recall routine
if err := recover(); err != nil {
fmt.Fprintf(os.Stderr, "p2p.go error: %v\n", err)
}
// Sleep for 1 second if the time elapsed is less than 30 seconds
// to avoid spamming the network with requests.
elapsed := time.Duration(time.Now().UnixNano() - startTime)
if elapsed < 30*time.Second {
time.Sleep(1 * time.Second)
}
go routine()
}
go routine()
}
func main() {

View File

@ -26,6 +26,7 @@ use {
Receiver,
Sender,
},
Arc,
Mutex,
},
},
@ -70,8 +71,14 @@ lazy_static::lazy_static! {
extern "C" fn proxy(o: ObservationC) {
// Create a fixed slice from the pointer and length.
let vaa = unsafe { std::slice::from_raw_parts(o.vaa, o.vaa_len) }.to_owned();
// FIXME: Remove unwrap
if let Err(e) = OBSERVATIONS.0.lock().unwrap().send(vaa) {
// The chances of the mutex getting poisioned is very low and if it happens
// there is no way for us to recover from it.
if let Err(e) = OBSERVATIONS
.0
.lock()
.expect("Cannot acquire p2p channel lock")
.send(vaa)
{
log::error!("Failed to send observation: {}", e);
}
}
@ -116,40 +123,46 @@ pub fn bootstrap(
// Spawn's the P2P layer as a separate thread via Go.
pub async fn spawn(
store: Store,
store: Arc<Store>,
network_id: String,
wh_bootstrap_addrs: Vec<Multiaddr>,
wh_listen_addrs: Vec<Multiaddr>,
) -> Result<()> {
bootstrap(network_id, wh_bootstrap_addrs, wh_listen_addrs)?;
std::thread::spawn(|| bootstrap(network_id, wh_bootstrap_addrs, wh_listen_addrs).unwrap());
// Listen in the background for new VAA's from the p2p layer
// and update the state accordingly.
tokio::spawn(async move {
// Listen in the background for new VAA's from the p2p layer
// and update the state accordingly.
loop {
let vaa_bytes = {
let vaa_bytes = tokio::task::spawn_blocking(|| {
let observation = OBSERVATIONS.1.lock();
let observation = match observation {
Ok(observation) => observation,
Err(e) => {
log::error!("Failed to lock observation channel: {}", e);
return;
// This should never happen, but if it does, we want to panic and crash
// as it is not recoverable.
panic!("Failed to lock p2p observation channel: {e}");
}
};
match observation.recv() {
Ok(vaa_bytes) => vaa_bytes,
Err(e) => {
log::error!("Failed to receive observation: {}", e);
return;
// This should never happen, but if it does, we want to panic and crash
// as it is not recoverable.
panic!("Failed to receive p2p observation: {e}");
}
}
};
})
.await
.unwrap();
if let Err(e) = store.store_update(Update::Vaa(vaa_bytes)).await {
log::error!("Failed to process VAA: {:?}", e);
}
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = store.store_update(Update::Vaa(vaa_bytes)).await {
log::error!("Failed to process VAA: {:?}", e);
}
});
}
});

View File

@ -10,7 +10,10 @@ use {
},
Store,
},
anyhow::Result,
anyhow::{
anyhow,
Result,
},
borsh::BorshDeserialize,
futures::stream::StreamExt,
solana_account_decoder::UiAccountEncoding,
@ -32,9 +35,14 @@ use {
pubkey::Pubkey,
system_program,
},
std::{
sync::Arc,
time::Duration,
},
tokio::time::Instant,
};
pub async fn spawn(store: Store, pythnet_ws_endpoint: String) -> Result<()> {
pub async fn run(store: Arc<Store>, pythnet_ws_endpoint: String) -> Result<!> {
let client = PubsubClient::new(pythnet_ws_endpoint.as_ref()).await?;
let config = RpcProgramAccountsConfig {
@ -56,40 +64,73 @@ pub async fn spawn(store: Store, pythnet_ws_endpoint: String) -> Result<()> {
.await?;
loop {
let update = notif.next().await;
log::debug!("Received Pythnet update: {:?}", update);
if let Some(update) = update {
let account: Account = update.value.account.decode().unwrap();
log::debug!("Received Accumulator update: {:?}", account);
let accumulator_messages = AccumulatorMessages::try_from_slice(&account.data);
match accumulator_messages {
Ok(accumulator_messages) => {
let (candidate, _) = Pubkey::find_program_address(
&[
b"AccumulatorState",
&accumulator_messages.ring_index().to_be_bytes(),
],
&system_program::id(),
);
if candidate.to_string() == update.value.pubkey {
store
.store_update(Update::AccumulatorMessages(accumulator_messages))
.await?;
} else {
log::error!(
"Failed to verify the messages public key: {:?} != {:?}",
candidate,
update.value.pubkey
);
match notif.next().await {
Some(update) => {
let account: Account = match update.value.account.decode() {
Some(account) => account,
None => {
log::error!("Failed to decode account from update: {:?}", update);
continue;
}
}
Err(err) => {
log::error!("Failed to parse AccumulatorMessages: {:?}", err);
}
};
};
let accumulator_messages = AccumulatorMessages::try_from_slice(&account.data);
match accumulator_messages {
Ok(accumulator_messages) => {
let (candidate, _) = Pubkey::find_program_address(
&[
b"AccumulatorState",
&accumulator_messages.ring_index().to_be_bytes(),
],
&system_program::id(),
);
if candidate.to_string() == update.value.pubkey {
let store = store.clone();
tokio::spawn(async move {
if let Err(err) = store
.store_update(Update::AccumulatorMessages(accumulator_messages))
.await
{
log::error!("Failed to store accumulator messages: {:?}", err);
}
});
} else {
log::error!(
"Failed to verify the messages public key: {:?} != {:?}",
candidate,
update.value.pubkey
);
}
}
Err(err) => {
log::error!("Failed to parse AccumulatorMessages: {:?}", err);
}
};
}
None => {
return Err(anyhow!("Pythnet network listener terminated"));
}
}
}
}
pub async fn spawn(store: Arc<Store>, pythnet_ws_endpoint: String) -> Result<()> {
tokio::spawn(async move {
loop {
let current_time = Instant::now();
if let Err(ref e) = run(store.clone(), pythnet_ws_endpoint.clone()).await {
log::error!("Error in Pythnet network listener: {:?}", e);
}
if current_time.elapsed() < Duration::from_secs(30) {
log::error!("Pythnet network listener is restarting too quickly. Sleeping for 1s");
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
});
Ok(())
}

View File

@ -1,13 +1,11 @@
use {
self::{
proof::wormhole_merkle::{
construct_update_data,
WormholeMerkleProof,
},
proof::wormhole_merkle::construct_update_data,
storage::StorageInstance,
types::{
AccumulatorMessages,
MessageType,
PriceFeedUpdate,
PriceFeedsWithUpdateData,
RequestTime,
Slot,
@ -22,11 +20,9 @@ use {
types::{
MessageState,
ProofSet,
UnixTimestamp,
},
wormhole::{
parse_and_verify_vaa,
WormholePayload,
},
wormhole::parse_and_verify_vaa,
},
anyhow::{
anyhow,
@ -36,12 +32,24 @@ use {
moka::future::Cache,
pyth_oracle::Message,
pyth_sdk::PriceIdentifier,
pythnet_sdk::payload::v1::{
WormholeMerkleRoot,
WormholeMessage,
WormholePayload,
},
std::{
collections::HashSet,
sync::Arc,
time::Duration,
time::{
Duration,
SystemTime,
UNIX_EPOCH,
},
},
tokio::sync::{
mpsc::Sender,
RwLock,
},
tokio::sync::RwLock,
wormhole_sdk::{
Address,
Chain,
@ -58,30 +66,27 @@ pub mod wormhole;
#[builder(derive(Debug), pattern = "immutable")]
pub struct AccumulatorState {
pub accumulator_messages: AccumulatorMessages,
pub wormhole_merkle_proof: WormholeMerkleProof,
pub wormhole_merkle_proof: (WormholeMerkleRoot, Vec<u8>),
}
#[derive(Clone)]
pub struct Store {
pub storage: StorageInstance,
pub pending_accumulations: Cache<Slot, AccumulatorStateBuilder>,
pub guardian_set: Arc<RwLock<Option<Vec<GuardianAddress>>>>,
pub guardian_set: RwLock<Option<Vec<GuardianAddress>>>,
pub update_tx: Sender<()>,
}
impl Store {
pub fn new_with_local_cache(max_size_per_key: usize) -> Self {
// TODO: Should we return an Arc<Self>? Although we are currently safe to be cloned without
// an Arc but it is easily to miss and cause a bug.
Self {
storage: storage::local_storage::LocalStorage::new_instance(
max_size_per_key,
),
pub fn new_with_local_cache(update_tx: Sender<()>, max_size_per_key: usize) -> Arc<Self> {
Arc::new(Self {
storage: storage::local_storage::LocalStorage::new_instance(max_size_per_key),
pending_accumulations: Cache::builder()
.max_capacity(10_000)
.time_to_live(Duration::from_secs(60 * 5))
.build(), // FIXME: Make this configurable
guardian_set: Arc::new(RwLock::new(None)),
}
guardian_set: RwLock::new(None),
update_tx,
})
}
/// Stores the update data in the store
@ -103,30 +108,19 @@ impl Store {
return Ok(()); // Ignore VAA from other emitters
}
let payload = WormholePayload::try_from_bytes(body.payload, &vaa_bytes)?;
match payload {
match WormholeMessage::try_from_bytes(body.payload)?.payload {
WormholePayload::Merkle(proof) => {
log::info!(
"Storing merkle proof for slot {:?}: {:?}",
proof.slot,
proof
);
store_wormhole_merkle_verified_message(self, proof.clone()).await?;
log::info!("Storing merkle proof for slot {:?}", proof.slot,);
store_wormhole_merkle_verified_message(self, proof.clone(), vaa_bytes)
.await?;
proof.slot
}
}
}
Update::AccumulatorMessages(accumulator_messages) => {
// FIXME: Move this constant to a better place
let slot = accumulator_messages.slot;
log::info!(
"Storing accumulator messages for slot {:?}: {:?}",
slot,
accumulator_messages
);
log::info!("Storing accumulator messages for slot {:?}.", slot,);
let pending_acc = self
.pending_accumulations
@ -154,10 +148,11 @@ impl Store {
Err(_) => return Ok(()),
};
log::info!("State: {:?}", state);
let wormhole_merkle_message_states_proofs = construct_message_states_proofs(state.clone())?;
let current_time: UnixTimestamp =
SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as _;
let message_states = state
.accumulator_messages
.messages
@ -176,16 +171,19 @@ impl Store {
.clone(),
},
state.accumulator_messages.slot,
current_time,
))
})
.collect::<Result<Vec<_>>>()?;
log::info!("Message states: {:?}", message_states);
log::info!("Message states len: {:?}", message_states.len());
self.storage.store_message_states(message_states)?;
self.pending_accumulations.invalidate(&slot).await;
self.update_tx.send(()).await?;
Ok(())
}
@ -207,11 +205,20 @@ impl Store {
let price_feeds = messages
.iter()
.map(|message_state| match message_state.message {
Message::PriceFeedMessage(price_feed) => Ok(price_feed),
Message::PriceFeedMessage(price_feed) => Ok(PriceFeedUpdate {
price_feed,
received_at: message_state.received_at,
slot: message_state.slot,
wormhole_merkle_update_data: construct_update_data(vec![message_state])?
.into_iter()
.next()
.ok_or(anyhow!("Missing update data for message"))?,
}),
_ => Err(anyhow!("Invalid message state type")),
})
.collect::<Result<Vec<_>>>()?;
let update_data = construct_update_data(messages)?;
let update_data = construct_update_data(messages.iter().collect())?;
Ok(PriceFeedsWithUpdateData {
price_feeds,

View File

@ -8,10 +8,6 @@ use {
anyhow,
Result,
},
byteorder::{
BigEndian,
WriteBytesExt,
},
pythnet_sdk::{
accumulators::{
merkle::{
@ -21,23 +17,16 @@ use {
Accumulator,
},
hashers::keccak256_160::Keccak160,
},
std::io::{
Cursor,
Write,
payload::v1::{
AccumulatorUpdateData,
MerklePriceUpdate,
Proof,
WormholeMerkleRoot,
},
ser::to_vec,
},
};
type Hash = [u8; 20];
#[derive(Clone, PartialEq, Debug)]
pub struct WormholeMerkleProof {
pub vaa: Vec<u8>,
pub slot: u64,
pub ring_size: u32,
pub root: Hash,
}
#[derive(Clone, PartialEq, Debug)]
pub struct WormholeMerkleMessageProof {
pub vaa: Vec<u8>,
@ -46,7 +35,8 @@ pub struct WormholeMerkleMessageProof {
pub async fn store_wormhole_merkle_verified_message(
store: &Store,
proof: WormholeMerkleProof,
proof: WormholeMerkleRoot,
vaa_bytes: Vec<u8>,
) -> Result<()> {
let pending_acc = store
.pending_accumulations
@ -56,7 +46,10 @@ pub async fn store_wormhole_merkle_verified_message(
.into_value();
store
.pending_accumulations
.insert(proof.slot, pending_acc.wormhole_merkle_proof(proof))
.insert(
proof.slot,
pending_acc.wormhole_merkle_proof((proof, vaa_bytes)),
)
.await;
Ok(())
}
@ -76,14 +69,11 @@ pub fn construct_message_states_proofs(
None => return Ok(vec![]), // It only happens when the message set is empty
};
let proof = &state.wormhole_merkle_proof;
let (proof, vaa) = &state.wormhole_merkle_proof;
log::info!(
"Merkle root: {:?}, Verified root: {:?}",
merkle_acc.root,
proof.root
);
log::info!("Valid: {}", merkle_acc.root == proof.root);
if merkle_acc.root != proof.root {
return Err(anyhow!("Invalid merkle root"));
}
state
.accumulator_messages
@ -91,7 +81,7 @@ pub fn construct_message_states_proofs(
.iter()
.map(|m| {
Ok(WormholeMerkleMessageProof {
vaa: state.wormhole_merkle_proof.vaa.clone(),
vaa: vaa.clone(),
proof: merkle_acc
.prove(m.as_ref())
.ok_or(anyhow!("Failed to prove message"))?,
@ -100,7 +90,7 @@ pub fn construct_message_states_proofs(
.collect::<Result<Vec<WormholeMerkleMessageProof>>>()
}
pub fn construct_update_data(mut message_states: Vec<MessageState>) -> Result<Vec<Vec<u8>>> {
pub fn construct_update_data(mut message_states: Vec<&MessageState>) -> Result<Vec<Vec<u8>>> {
message_states.sort_by_key(
|m| m.proof_set.wormhole_merkle_proof.vaa.clone(), // FIXME: This is not efficient
);
@ -118,39 +108,18 @@ pub fn construct_update_data(mut message_states: Vec<MessageState>) -> Result<Ve
.vaa
.clone();
let mut cursor = Cursor::new(Vec::new());
cursor.write_u32::<BigEndian>(0x504e4155)?; // "PNAU"
cursor.write_u8(0x01)?; // Major version
cursor.write_u8(0x00)?; // Minor version
cursor.write_u8(0)?; // Trailing header size
cursor.write_u8(0)?; // Update type of WormholeMerkle. FIXME: Make this out of enum
// Writing VAA
cursor.write_u16::<BigEndian>(vaa.len().try_into()?)?;
cursor.write_all(&vaa)?;
// Writing number of messages
cursor.write_u8(messages.len().try_into()?)?;
for message in messages {
// Writing message
cursor.write_u16::<BigEndian>(message.raw_message.len().try_into()?)?;
cursor.write_all(&message.raw_message)?;
// Writing proof
cursor.write_all(
&message
.proof_set
.wormhole_merkle_proof
.proof
.serialize()
.ok_or(anyhow!("Unable to serialize merkle proof path"))?,
)?;
}
Ok(cursor.into_inner())
Ok(to_vec::<_, byteorder::BE>(&AccumulatorUpdateData::new(
Proof::WormholeMerkle {
vaa: vaa.into(),
updates: messages
.iter()
.map(|message| MerklePriceUpdate {
message: message.raw_message.clone().into(),
proof: message.proof_set.wormhole_merkle_proof.proof.clone(),
})
.collect(),
},
))?)
})
.collect::<Result<Vec<Vec<u8>>>>()
}

View File

@ -7,7 +7,6 @@ use {
},
anyhow::Result,
pyth_sdk::PriceIdentifier,
std::sync::Arc,
};
pub mod local_storage;
@ -30,4 +29,4 @@ pub trait Storage: Send + Sync {
fn keys(&self) -> Vec<MessageIdentifier>;
}
pub type StorageInstance = Arc<Box<dyn Storage>>;
pub type StorageInstance = Box<dyn Storage>;

View File

@ -28,10 +28,10 @@ pub struct LocalStorage {
impl LocalStorage {
pub fn new_instance(max_size_per_key: usize) -> StorageInstance {
Arc::new(Box::new(Self {
Box::new(Self {
cache: Arc::new(DashMap::new()),
max_size_per_key,
}))
})
}
fn retrieve_message_state(

View File

@ -60,7 +60,6 @@ pub struct WormholeMerkleState {
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct MessageIdentifier {
// -> this is the real message id
pub price_id: PriceIdentifier,
pub type_: MessageType,
}
@ -84,6 +83,7 @@ pub struct MessageState {
pub message: Message,
pub raw_message: RawMessage,
pub proof_set: ProofSet,
pub received_at: UnixTimestamp,
}
impl MessageState {
@ -98,7 +98,13 @@ impl MessageState {
self.id.clone()
}
pub fn new(message: Message, raw_message: RawMessage, proof_set: ProofSet, slot: Slot) -> Self {
pub fn new(
message: Message,
raw_message: RawMessage,
proof_set: ProofSet,
slot: Slot,
received_at: UnixTimestamp,
) -> Self {
Self {
publish_time: message.publish_time(),
slot,
@ -106,6 +112,7 @@ impl MessageState {
message,
raw_message,
proof_set,
received_at,
}
}
}
@ -138,7 +145,17 @@ pub enum Update {
AccumulatorMessages(AccumulatorMessages),
}
pub struct PriceFeedUpdate {
pub price_feed: PriceFeedMessage,
pub slot: Slot,
pub received_at: UnixTimestamp,
/// Wormhole merkle update data for this single price feed update.
/// This field is available for backward compatibility and will be
/// removed in the future.
pub wormhole_merkle_update_data: Vec<u8>,
}
pub struct PriceFeedsWithUpdateData {
pub price_feeds: Vec<PriceFeedMessage>,
pub price_feeds: Vec<PriceFeedUpdate>,
pub wormhole_merkle_update_data: Vec<Vec<u8>>,
}

View File

@ -1,8 +1,5 @@
use {
super::{
proof::wormhole_merkle::WormholeMerkleProof,
Store,
},
super::Store,
anyhow::{
anyhow,
Result,
@ -16,6 +13,10 @@ use {
Secp256k1,
},
serde_wormhole::RawMessage,
sha3::{
Digest,
Keccak256,
},
wormhole_sdk::{
vaa::{
Body,
@ -26,43 +27,6 @@ use {
},
};
#[derive(Clone, Debug, PartialEq)]
pub enum WormholePayload {
Merkle(WormholeMerkleProof),
}
impl WormholePayload {
pub fn try_from_bytes(bytes: &[u8], vaa_bytes: &[u8]) -> Result<Self> {
if bytes.len() != 37 {
return Err(anyhow!("Invalid message length"));
}
// TODO: Use byte string literals for this check
let magic = u32::from_be_bytes(bytes[0..4].try_into()?);
if magic != 0x41555756u32 {
return Err(anyhow!("Invalid magic"));
}
let message_type = u8::from_be_bytes(bytes[4..5].try_into()?);
if message_type != 0 {
return Err(anyhow!("Invalid message type"));
}
let slot = u64::from_be_bytes(bytes[5..13].try_into()?);
let ring_size = u32::from_be_bytes(bytes[13..17].try_into()?);
let root_digest = bytes[17..37].try_into()?;
Ok(Self::Merkle(WormholeMerkleProof {
root: root_digest,
slot,
ring_size,
vaa: vaa_bytes.to_vec(),
}))
}
}
/// Parses and verifies a VAA to ensure it is signed by the Wormhole guardian set.
pub async fn parse_and_verify_vaa<'a>(
store: &Store,
@ -72,33 +36,51 @@ pub async fn parse_and_verify_vaa<'a>(
let (header, body): (Header, Body<&RawMessage>) = vaa.into();
let digest = body.digest()?;
let guardian_set = match store.guardian_set.read().await.as_ref() {
Some(guardian_set) => guardian_set.clone(),
None => {
return Err(anyhow!("Guardian set is not initialized"));
}
};
let mut num_correct_signers = 0;
for sig in header.signatures.iter() {
let signer_id: usize = sig.index.into();
let sig = sig.signature;
// Recover the public key from ecdsa signature from [u8; 65] that has (v, r, s) format
let recid = RecoveryId::from_i32(sig[64].into())?;
// Recover the public key from ecdsa signature from [u8; 65] that has (v, r, s) format
let secp = Secp256k1::new();
let pubkey = &secp
// To get the address we need to use the uncompressed public key
let pubkey: &[u8; 65] = &secp
.recover_ecdsa(
&Message::from_slice(&digest.secp256k_hash)?,
&RecoverableSignature::from_compact(&sig[..64], recid)?,
)?
.serialize();
.serialize_uncompressed();
let address = GuardianAddress(pubkey[pubkey.len() - 20..].try_into()?);
// The address is the last 20 bytes of the Keccak256 hash of the public key
let mut keccak = Keccak256::new();
keccak.update(&pubkey[1..]);
if let Some(guardian_set) = store.guardian_set.read().await.as_ref() {
if guardian_set.get(signer_id) == Some(&address) {
num_correct_signers += 1;
}
let address: [u8; 32] = keccak.finalize().into();
let address = GuardianAddress(address[address.len() - 20..].try_into()?);
if guardian_set.get(signer_id) == Some(&address) {
num_correct_signers += 1;
}
}
if num_correct_signers < header.signatures.len() * 2 / 3 {
return Err(anyhow!("Not enough correct signatures"));
let quorum = (guardian_set.len() * 2 + 2) / 3;
if num_correct_signers < quorum {
return Err(anyhow!(
"Not enough correct signatures. Expected {:?}, received {:?}",
quorum,
num_correct_signers
));
}
Ok(body)

View File

@ -51,25 +51,13 @@ fn hash_null<H: Hasher>() -> H::Hash {
H::hashv(&[NULL_PREFIX])
}
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize)]
#[derive(Clone, Default, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct MerklePath<H: Hasher>(Vec<H::Hash>);
impl<H: Hasher> MerklePath<H> {
pub fn new(path: Vec<H::Hash>) -> Self {
Self(path)
}
pub fn serialize(&self) -> Option<Vec<u8>> {
let mut serialized = vec![];
let proof_size: u8 = self.0.len().try_into().ok()?;
serialized.extend_from_slice(&proof_size.to_be_bytes());
for node in self.0.iter() {
serialized.extend_from_slice(node.as_ref());
}
Some(serialized)
}
}
/// A MerkleAccumulator maintains a Merkle Tree.

View File

@ -0,0 +1,487 @@
use {
byteorder::{
ByteOrder,
ReadBytesExt,
},
serde::{
de::{
EnumAccess,
IntoDeserializer,
MapAccess,
SeqAccess,
VariantAccess,
},
Deserialize,
},
std::io::{
Cursor,
Seek,
SeekFrom,
},
thiserror::Error,
};
pub fn from_slice<'de, B, T>(bytes: &'de [u8]) -> Result<T, DeserializeError>
where
T: Deserialize<'de>,
B: ByteOrder,
{
let mut deserializer = Deserializer::<B>::new(bytes);
T::deserialize(&mut deserializer)
}
#[derive(Debug, Error)]
pub enum DeserializeError {
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("invalid utf8: {0}")]
Utf8(#[from] std::str::Utf8Error),
#[error("this type is not supported")]
Unsupported,
#[error("sequence too large ({0} elements), max supported is 255")]
SequenceTooLarge(usize),
#[error("message: {0}")]
Message(Box<str>),
#[error("eof")]
Eof,
}
pub struct Deserializer<'de, B>
where
B: ByteOrder,
{
cursor: Cursor<&'de [u8]>,
endian: std::marker::PhantomData<B>,
}
impl serde::de::Error for DeserializeError {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
DeserializeError::Message(msg.to_string().into_boxed_str())
}
}
impl<'de, B> Deserializer<'de, B>
where
B: ByteOrder,
{
pub fn new(buffer: &'de [u8]) -> Self {
Self {
cursor: Cursor::new(buffer),
endian: std::marker::PhantomData,
}
}
}
impl<'de, B> serde::de::Deserializer<'de> for &'_ mut Deserializer<'de, B>
where
B: ByteOrder,
{
type Error = DeserializeError;
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(DeserializeError::Unsupported)
}
fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(DeserializeError::Unsupported)
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self.cursor.read_u8().map_err(DeserializeError::from)?;
visitor.visit_bool(value != 0)
}
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self.cursor.read_i8().map_err(DeserializeError::from)?;
visitor.visit_i8(value)
}
fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self
.cursor
.read_i16::<B>()
.map_err(DeserializeError::from)?;
visitor.visit_i16(value)
}
fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self
.cursor
.read_i32::<B>()
.map_err(DeserializeError::from)?;
visitor.visit_i32(value)
}
fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self
.cursor
.read_i64::<B>()
.map_err(DeserializeError::from)?;
visitor.visit_i64(value)
}
fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self.cursor.read_u8().map_err(DeserializeError::from)?;
visitor.visit_u8(value)
}
fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self
.cursor
.read_u16::<B>()
.map_err(DeserializeError::from)?;
visitor.visit_u16(value)
}
fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self
.cursor
.read_u32::<B>()
.map_err(DeserializeError::from)?;
visitor.visit_u32(value)
}
fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let value = self
.cursor
.read_u64::<B>()
.map_err(DeserializeError::from)?;
visitor.visit_u64(value)
}
fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(DeserializeError::Unsupported)
}
fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(DeserializeError::Unsupported)
}
fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(DeserializeError::Unsupported)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let len = self.cursor.read_u8().map_err(DeserializeError::from)? as u64;
// Because cursor read methods copy the data out of the internal buffer,
// where we want a pointer, we need to move the cursor forward first
// (because the ) get_ref() call triggers a borrow), and then read
// after.
self.cursor
.seek(SeekFrom::Current(len as i64))
.map_err(DeserializeError::from)?;
let buf = {
let buf = self.cursor.get_ref();
buf[(self.cursor.position() - len) as usize..]
.get(..len as usize)
.ok_or(DeserializeError::Eof)?
};
visitor.visit_borrowed_str(std::str::from_utf8(buf).map_err(DeserializeError::from)?)
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_str(visitor)
}
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let len = self.cursor.read_u8().map_err(DeserializeError::from)? as u64;
// See comment in deserialize_str for the reason for the subtraction.
self.cursor
.seek(SeekFrom::Current(len as i64))
.map_err(DeserializeError::from)?;
let buf = {
let buf = self.cursor.get_ref();
buf[(self.cursor.position() - len) as usize..]
.get(..len as usize)
.ok_or(DeserializeError::Eof)?
};
visitor.visit_borrowed_bytes(buf)
}
fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_bytes(visitor)
}
fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(DeserializeError::Unsupported)
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_unit_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_unit()
}
fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let len = self.cursor.read_u8().map_err(DeserializeError::from)? as usize;
visitor.visit_seq(SequenceIterator::new(self, len))
}
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(SequenceIterator::new(self, len))
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(SequenceIterator::new(self, len))
}
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let len = self.cursor.read_u8().map_err(DeserializeError::from)? as usize;
visitor.visit_map(SequenceIterator::new(self, len))
}
fn deserialize_struct<V>(
self,
_name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(SequenceIterator::new(self, fields.len()))
}
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let variant = self.cursor.read_u8().map_err(DeserializeError::from)?;
visitor.visit_enum(Enum { de: self, variant })
}
fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
Err(DeserializeError::Unsupported)
}
}
impl<'de, 'a, B: ByteOrder> VariantAccess<'de> for &'a mut Deserializer<'de, B> {
type Error = DeserializeError;
fn unit_variant(self) -> Result<(), Self::Error> {
Ok(())
}
fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
where
T: serde::de::DeserializeSeed<'de>,
{
seed.deserialize(self)
}
fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(SequenceIterator::new(self, len))
}
fn struct_variant<V>(
self,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_seq(SequenceIterator::new(self, fields.len()))
}
}
struct SequenceIterator<'de, 'a, B: ByteOrder> {
de: &'a mut Deserializer<'de, B>,
len: usize,
}
impl<'de, 'a, B: ByteOrder> SequenceIterator<'de, 'a, B> {
fn new(de: &'a mut Deserializer<'de, B>, len: usize) -> Self {
Self { de, len }
}
}
impl<'de, 'a, B: ByteOrder> SeqAccess<'de> for SequenceIterator<'de, 'a, B> {
type Error = DeserializeError;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: serde::de::DeserializeSeed<'de>,
{
if self.len == 0 {
return Ok(None);
}
self.len -= 1;
seed.deserialize(&mut *self.de).map(Some)
}
fn size_hint(&self) -> Option<usize> {
Some(self.len)
}
}
impl<'de, 'a, B: ByteOrder> MapAccess<'de> for SequenceIterator<'de, 'a, B> {
type Error = DeserializeError;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where
K: serde::de::DeserializeSeed<'de>,
{
if self.len == 0 {
return Ok(None);
}
self.len -= 1;
seed.deserialize(&mut *self.de).map(Some)
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where
V: serde::de::DeserializeSeed<'de>,
{
seed.deserialize(&mut *self.de)
}
fn size_hint(&self) -> Option<usize> {
Some(self.len)
}
}
struct Enum<'de, 'a, B: ByteOrder> {
de: &'a mut Deserializer<'de, B>,
variant: u8,
}
impl<'de, 'a, B: ByteOrder> EnumAccess<'de> for Enum<'de, 'a, B> {
type Error = DeserializeError;
type Variant = &'a mut Deserializer<'de, B>;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
where
V: serde::de::DeserializeSeed<'de>,
{
seed.deserialize(self.variant.into_deserializer())
.map(|v| (v, self.de))
}
}

View File

@ -0,0 +1,19 @@
use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("Invalid Magic")]
InvalidMagic,
#[error("Invalid Version")]
InvalidVersion,
}
#[macro_export]
macro_rules! require {
($cond:expr, $err:expr) => {
if !$cond {
return Err($err);
}
};
}

View File

@ -28,6 +28,7 @@ where
+ Debug
+ Default
+ Eq
+ std::hash::Hash
+ PartialOrd
+ PartialEq
+ serde::Serialize

View File

@ -7,7 +7,7 @@ use {
},
};
#[derive(Clone, Default, Debug, Eq, PartialEq, Serialize)]
#[derive(Clone, Default, Debug, Eq, Hash, PartialEq, Serialize)]
pub struct Keccak160 {}
impl Hasher for Keccak160 {

View File

@ -1,4 +1,5 @@
pub mod accumulators;
pub mod error;
pub mod hashers;
pub mod wire;
pub mod wormhole;

View File

@ -0,0 +1,109 @@
//! Definition of the Accumulator Payload Formats.
//!
//! This module defines the data types that are injected into VAA's to be sent to other chains via
//! Wormhole. The wire format for these types must be backwards compatible and so all tyeps in this
//! module are expected to be append-only (for minor changes) and versioned for breaking changes.
use {
crate::{
error::Error,
require,
ser::PrefixedVec,
},
serde::{
Deserialize,
Serialize,
},
};
// Proof Format (V1)
// --------------------------------------------------------------------------------
// The definitions within each module can be updated with append-only data without requiring a new
// module to be defined. So for example, it is possible to add new fields can be added to the end
// of the `AccumulatorAccount` without moving to a `v1`.
pub mod v1 {
use {
super::*,
crate::{
accumulators::merkle::MerklePath,
de::from_slice,
hashers::keccak256_160::Keccak160,
},
};
// Transfer Format.
// --------------------------------------------------------------------------------
// This definition is what will be sent over the wire (I.E, pulled from PythNet and submitted
// to target chains).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct AccumulatorUpdateData {
magic: [u8; 4],
major_version: u8,
minor_version: u8,
trailing: Vec<u8>,
proof: Proof,
}
impl AccumulatorUpdateData {
pub fn new(proof: Proof) -> Self {
Self {
magic: *b"PNAU",
major_version: 1,
minor_version: 0,
trailing: vec![],
proof,
}
}
pub fn try_from_slice(bytes: &[u8]) -> Result<Self, Error> {
let message = from_slice::<byteorder::BE, Self>(bytes).unwrap();
require!(&message.magic[..] != b"PNAU", Error::InvalidMagic);
require!(message.major_version == 1, Error::InvalidVersion);
require!(message.minor_version == 0, Error::InvalidVersion);
Ok(message)
}
}
// A hash of some data.
pub type Hash = [u8; 20];
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum Proof {
WormholeMerkle {
vaa: PrefixedVec<u16, u8>,
updates: Vec<MerklePriceUpdate>,
},
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MerklePriceUpdate {
pub message: PrefixedVec<u16, u8>,
pub proof: MerklePath<Keccak160>,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct WormholeMessage {
pub magic: [u8; 4],
pub payload: WormholePayload,
}
impl WormholeMessage {
pub fn try_from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, Error> {
let message = from_slice::<byteorder::BE, Self>(bytes.as_ref()).unwrap();
require!(&message.magic[..] == b"AUWV", Error::InvalidMagic);
Ok(message)
}
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum WormholePayload {
Merkle(WormholeMerkleRoot),
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct WormholeMerkleRoot {
pub slot: u64,
pub ring_size: u32,
pub root: Hash,
}
}

View File

@ -0,0 +1,893 @@
//! A module defining serde serialize for a simple Rust struct-like message format. The format will
//! read Rust types exactly the size they are, and reads sequences by reading a u8 length followed
//! by a count of the elements of the vector.
//!
//! TL;DR: How to Use
//! ================================================================================
//!
//! ```rust,ignore
//! #[derive(Serialize)]
//! struct ExampleStruct {
//! a: (),
//! b: bool,
//! c: u8,
//! ...,
//! }
//!
//! let mut buf = Vec::new();
//! let mut cur = Cursor::new(&mut buf);
//! pythnet_sdk::ser::to_writer(&mut cur, &ExampleStruct { ... }).unwrap();
//! let result = pythnet_sdk::ser::to_vec(&ExampleStruct { ... }).unwrap();
//! ```
//!
//! A quick primer on `serde::Serialize`:
//! ================================================================================
//!
//! Given some type `T`, the `serde::Serialize` derives an implementation with a `serialize` method
//! that calls all the relevant `serialize_` calls defined in this file, so for example, given the
//! following types:
//!
//! ```rust,ignore
//! #[derive(Serialize)]
//! enum ExampleEnum {
//! A,
//! B(u8),
//! C(u8, u8),
//! D { a: u8, b: u8 },
//! }
//!
//! #[derive(Serialize)]
//! struct ExampleStruct {
//! a: (),
//! b: bool,
//! c: u8,
//! d: &str,
//! e: ExampleEnum
//! }
//! ```
//!
//! The macro will expand into (a more complicated but equivelent) version of:
//!
//! ```rust,ignore
//! impl serde::Serialize for ExampleEnum {
//! fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
//! match self {
//! ExampleEnum::A => serializer.serialize_unit_variant("ExampleEnum", 0, "A"),
//! ExampleEnum::B(v) => serializer.serialize_newtype_variant("ExampleEnum", 1, "B", v),
//! ExampleEnum::C(v0, v1) => serializer.serialize_tuple_variant("ExampleEnum", 2, "C", (v0, v1)),
//! ExampleEnum::D { a, b } => serializer.serialize_struct_variant("ExampleEnum", 3, "D", 2, "a", a, "b", b),
//! }
//! }
//! }
//!
//! impl serde::Serialize for ExampleStruct {
//! fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
//! let mut state = serializer.serialize_struct("ExampleStruct", 5)?;
//! state.serialize_field("a", &self.a)?;
//! state.serialize_field("b", &self.b)?;
//! state.serialize_field("c", &self.c)?;
//! state.serialize_field("d", &self.d)?;
//! state.serialize_field("e", &self.e)?;
//! state.end()
//! }
//! }
//! ```
//!
//! Note that any parser can be passed in, which gives the serializer the ability to serialize to
//! any format we desire as long as there is a `Serializer` implementation for it. With aggressive
//! inlining, the compiler will be able to optimize away the intermediate state objects and calls
//! to `serialize_field` and `serialize_*_variant` and the final result of our parser will have
//! very close to equivelent performance to a hand written implementation.
//!
//! The Pyth Serialization Format
//! ================================================================================
//!
//! Pyth has various data formats that are serialized in compact forms to make storing or passing
//! cross-chain cheaper. So far all these formats follow a similar pattern and so this serializer
//! is designed to be able to provide a canonical implementation for all of them.
//!
//!
//! Format Spec:
//! --------------------------------------------------------------------------------
//!
//! Integers:
//!
//! - `{u,i}8` are serialized as a single byte
//! - `{u,i}16/32/64` are serialized as bytes specified by the parser endianess type param.
//! - `{u,i}128` is not supported.
//!
//! Floats:
//!
//! - `f32/64/128` are not supported due to different chains having different float formats.
//!
//! Strings:
//!
//! - `&str` is serialized as a u8 length followed by the bytes of the string.
//! - `String` is serialized as a u8 length followed by the bytes of the string.
//!
//! Sequences:
//!
//! - `Vec<T>` is serialized as a u8 length followed by the serialized elements of the vector.
//! - `&[T]` is serialized as a u8 length followed by the serialized elements of the slice.
//!
//! Enums:
//!
//! - `enum` is serialized as a u8 variant index followed by the serialized variant data.
//! - `Option<T>` is serialized as a u8 variant index followed by the serialized variant data.
//!
//! Structs:
//!
//! - `struct` is serialized as the serialized fields of the struct in order.
//!
//! Tuples:
//!
//! - `tuple` is serialized as the serialized elements of the tuple in order.
//!
//! Unit:
//!
//! - `()` is serialized as nothing.
//!
//!
//! Example Usage
//! --------------------------------------------------------------------------------
//!
//! ```rust,ignore
//! fn example(data: &[u8]) {
//! let mut buf = Vec::new();
//! let mut cur = Cursor::new(&mut buf);
//! let mut des = Deserializer::new(&mut cur);
//! let mut result = des.deserialize::<ExampleStruct>(data).unwrap();
//! ...
//! }
//! ```
use {
byteorder::{
ByteOrder,
WriteBytesExt,
},
serde::{
ser::{
SerializeMap,
SerializeSeq,
SerializeStruct,
SerializeStructVariant,
SerializeTuple,
SerializeTupleStruct,
SerializeTupleVariant,
},
Serialize,
},
std::{
fmt::Display,
io::Write,
},
thiserror::Error,
};
pub fn to_vec<T, B>(value: &T) -> Result<Vec<u8>, SerializeError>
where
T: Serialize,
B: ByteOrder,
{
let mut buf = Vec::new();
value.serialize(&mut Serializer::<_, B>::new(&mut buf))?;
Ok(buf)
}
#[derive(Debug, Error)]
pub enum SerializeError {
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("this type is not supported")]
Unsupported,
#[error("sequence too large ({0} elements), max supported is 255")]
SequenceTooLarge(usize),
#[error("sequence length must be known before serializing")]
SequenceLengthUnknown,
#[error("enum variant {0}::{1} cannot be parsed as `u8`: {2}")]
InvalidEnumVariant(&'static str, u32, &'static str),
#[error("message: {0}")]
Message(Box<str>),
}
#[derive(Clone, Debug, Hash, PartialEq, PartialOrd, Error, serde::Deserialize)]
pub struct PrefixedVec<L, T> {
data: PrefixlessVec<T>,
__phantom: std::marker::PhantomData<L>,
}
impl<T, L> From<Vec<T>> for PrefixedVec<L, T> {
fn from(data: Vec<T>) -> Self {
Self {
data: PrefixlessVec { data },
__phantom: std::marker::PhantomData,
}
}
}
#[derive(Clone, Debug, Hash, PartialEq, PartialOrd, Error, serde::Deserialize)]
struct PrefixlessVec<T> {
data: Vec<T>,
}
impl<L, T> Serialize for PrefixedVec<L, T>
where
T: Serialize,
L: Serialize,
L: TryFrom<usize>,
<L as TryFrom<usize>>::Error: std::fmt::Debug,
{
#[inline]
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let len: L = L::try_from(self.data.data.len()).unwrap();
let mut st = serializer.serialize_struct("SizedVec", 1)?;
st.serialize_field("len", &len)?;
st.serialize_field("data", &self.data)?;
st.end()
}
}
impl<T> Serialize for PrefixlessVec<T>
where
T: Serialize,
{
#[inline]
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut seq = serializer.serialize_seq(None)?;
for item in &self.data {
seq.serialize_element(item)?;
}
seq.end()
}
}
/// A type for Pyth's common serialization format. Note that a ByteOrder type param is required as
/// we serialize in both big and little endian depending on different use-cases.
#[derive(Clone)]
pub struct Serializer<W: Write, B: ByteOrder> {
writer: W,
_endian: std::marker::PhantomData<B>,
}
impl serde::ser::Error for SerializeError {
fn custom<T: Display>(msg: T) -> Self {
SerializeError::Message(msg.to_string().into_boxed_str())
}
}
impl<W: Write, B: ByteOrder> Serializer<W, B> {
pub fn new(writer: W) -> Self {
Self {
writer,
_endian: std::marker::PhantomData,
}
}
}
impl<'a, W: Write, B: ByteOrder> serde::Serializer for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
// Serde uses different types for different parse targets to allow for different
// implementations. We only support one target, so we can set all these to `Self`
// and implement those traits on the same type.
type SerializeSeq = Self;
type SerializeTuple = Self;
type SerializeTupleStruct = Self;
type SerializeTupleVariant = Self;
type SerializeMap = Self;
type SerializeStruct = Self;
type SerializeStructVariant = Self;
#[inline]
fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
self.writer
.write_all(&[v as u8])
.map_err(SerializeError::from)
}
#[inline]
fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
self.writer
.write_all(&[v as u8])
.map_err(SerializeError::from)
}
#[inline]
fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
self.writer.write_i16::<B>(v).map_err(SerializeError::from)
}
#[inline]
fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
self.writer.write_i32::<B>(v).map_err(SerializeError::from)
}
#[inline]
fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
self.writer.write_i64::<B>(v).map_err(SerializeError::from)
}
#[inline]
fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
self.writer.write_all(&[v]).map_err(SerializeError::from)
}
#[inline]
fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
self.writer.write_u16::<B>(v).map_err(SerializeError::from)
}
#[inline]
fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
self.writer.write_u32::<B>(v).map_err(SerializeError::from)
}
#[inline]
fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
self.writer.write_u64::<B>(v).map_err(SerializeError::from)
}
#[inline]
fn serialize_f32(self, _: f32) -> Result<Self::Ok, Self::Error> {
Err(SerializeError::Unsupported)
}
#[inline]
fn serialize_f64(self, _: f64) -> Result<Self::Ok, Self::Error> {
Err(SerializeError::Unsupported)
}
#[inline]
fn serialize_char(self, _: char) -> Result<Self::Ok, Self::Error> {
Err(SerializeError::Unsupported)
}
#[inline]
fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
let len = u8::try_from(v.len()).map_err(|_| SerializeError::SequenceTooLarge(v.len()))?;
self.writer.write_all(&[len])?;
self.writer
.write_all(v.as_bytes())
.map_err(SerializeError::from)
}
#[inline]
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
let len = u8::try_from(v.len()).map_err(|_| SerializeError::SequenceTooLarge(v.len()))?;
self.writer.write_all(&[len])?;
self.writer.write_all(v).map_err(SerializeError::from)
}
#[inline]
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
Err(SerializeError::Unsupported)
}
#[inline]
fn serialize_some<T: ?Sized + Serialize>(self, value: &T) -> Result<Self::Ok, Self::Error> {
value.serialize(self)
}
#[inline]
fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
#[inline]
fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
Ok(())
}
#[inline]
fn serialize_unit_variant(
self,
name: &'static str,
variant_index: u32,
variant: &'static str,
) -> Result<Self::Ok, Self::Error> {
let variant: u8 = variant_index
.try_into()
.map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?;
self.writer
.write_all(&[variant])
.map_err(SerializeError::from)
}
#[inline]
fn serialize_newtype_struct<T: ?Sized + Serialize>(
self,
_name: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error> {
value.serialize(self)
}
#[inline]
fn serialize_newtype_variant<T: ?Sized + Serialize>(
self,
name: &'static str,
variant_index: u32,
variant: &'static str,
value: &T,
) -> Result<Self::Ok, Self::Error> {
let variant: u8 = variant_index
.try_into()
.map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?;
self.writer.write_all(&[variant])?;
value.serialize(self)
}
#[inline]
fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
if let Some(len) = len {
let len = u8::try_from(len).map_err(|_| SerializeError::SequenceTooLarge(len))?;
self.writer.write_all(&[len])?;
}
Ok(self)
}
#[inline]
fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
Ok(self)
}
#[inline]
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleStruct, Self::Error> {
Ok(self)
}
#[inline]
fn serialize_tuple_variant(
self,
name: &'static str,
variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant, Self::Error> {
let variant: u8 = variant_index
.try_into()
.map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?;
self.writer.write_all(&[variant])?;
Ok(self)
}
#[inline]
fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
let len = len
.ok_or(SerializeError::SequenceLengthUnknown)
.and_then(|len| u8::try_from(len).map_err(|_| SerializeError::SequenceTooLarge(len)))?;
self.writer.write_all(&[len])?;
Ok(self)
}
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Self::Error> {
Ok(self)
}
#[inline]
fn serialize_struct_variant(
self,
name: &'static str,
variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant, Self::Error> {
let variant: u8 = variant_index
.try_into()
.map_err(|_| SerializeError::InvalidEnumVariant(name, variant_index, variant))?;
self.writer.write_all(&[variant])?;
Ok(self)
}
fn is_human_readable(&self) -> bool {
false
}
fn collect_str<T: ?Sized + Display>(self, value: &T) -> Result<Self::Ok, Self::Error> {
self.serialize_str(&value.to_string())
}
}
impl<'a, W: Write, B: ByteOrder> SerializeSeq for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
#[inline]
fn serialize_element<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Self::Error> {
value.serialize(&mut **self)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
impl<'a, W: Write, B: ByteOrder> SerializeTuple for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
#[inline]
fn serialize_element<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Self::Error> {
value.serialize(&mut **self)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
impl<'a, W: Write, B: ByteOrder> SerializeTupleStruct for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
#[inline]
fn serialize_field<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Self::Error> {
value.serialize(&mut **self)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
impl<'a, W: Write, B: ByteOrder> SerializeTupleVariant for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
#[inline]
fn serialize_field<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Self::Error> {
value.serialize(&mut **self)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
impl<'a, W: Write, B: ByteOrder> SerializeMap for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
#[inline]
fn serialize_key<T: ?Sized + Serialize>(&mut self, key: &T) -> Result<(), Self::Error> {
key.serialize(&mut **self)
}
#[inline]
fn serialize_value<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Self::Error> {
value.serialize(&mut **self)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
impl<'a, W: Write, B: ByteOrder> SerializeStruct for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
#[inline]
fn serialize_field<T: ?Sized + Serialize>(
&mut self,
_key: &'static str,
value: &T,
) -> Result<(), Self::Error> {
value.serialize(&mut **self)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
impl<'a, W: Write, B: ByteOrder> SerializeStructVariant for &'a mut Serializer<W, B> {
type Ok = ();
type Error = SerializeError;
#[inline]
fn serialize_field<T: ?Sized + Serialize>(
&mut self,
_key: &'static str,
value: &T,
) -> Result<(), Self::Error> {
value.serialize(&mut **self)
}
fn end(self) -> Result<Self::Ok, Self::Error> {
Ok(())
}
}
// By default, serde does not know how to parse fixed length arrays of sizes
// that aren't common (I.E: 32) Here we provide a module that can be used to
// serialize arrays that relies on const generics.
//
// Usage:
//
// ```rust,ignore`
// #[derive(Serialize)]
// struct Example {
// #[serde(with = "array")]
// array: [u8; 55],
// }
// ```
pub mod array {
use std::mem::MaybeUninit;
/// Serialize an array of size N using a const generic parameter to drive serialize_seq.
pub fn serialize<S, T, const N: usize>(array: &[T; N], serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
T: serde::Serialize,
{
use serde::ser::SerializeTuple;
let mut seq = serializer.serialize_tuple(N)?;
array.iter().try_for_each(|e| seq.serialize_element(e))?;
seq.end()
}
/// A Marer type that carries type-level information about the length of the
/// array we want to deserialize.
struct ArrayVisitor<T, const N: usize> {
_marker: std::marker::PhantomData<T>,
}
/// Implement a Visitor over our ArrayVisitor that knows how many times to
/// call next_element using the generic.
impl<'de, T, const N: usize> serde::de::Visitor<'de> for ArrayVisitor<T, N>
where
T: serde::de::Deserialize<'de>,
{
type Value = [T; N];
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "an array of length {}", N)
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
// We use MaybeUninit to allocate the right amount of memory
// because we do not know if `T` has a constructor or a default.
// Without this we would have to allocate a Vec.
let mut array = MaybeUninit::<[T; N]>::uninit();
let ptr = array.as_mut_ptr() as *mut T;
let mut pos = 0;
while pos < N {
let next = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(pos, &self))?;
unsafe {
std::ptr::write(ptr.add(pos), next);
}
pos += 1;
}
// We only succeed if we fully filled the array. This prevents
// accidentally returning garbage.
if pos == N {
return Ok(unsafe { array.assume_init() });
}
Err(serde::de::Error::invalid_length(pos, &self))
}
}
/// Deserialize an array with an ArrayVisitor aware of `N` during deserialize.
pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
where
D: serde::Deserializer<'de>,
T: serde::de::Deserialize<'de>,
{
deserializer.deserialize_tuple(
N,
ArrayVisitor {
_marker: std::marker::PhantomData,
},
)
}
}
#[cfg(test)]
mod tests {
use super::{
super::de::Deserializer,
array,
Serializer,
};
// Test the arbitrary fixed sized array serialization implementation.
#[test]
fn test_array_serde() {
// Serialize an array into a buffer.
let mut buffer = Vec::new();
let mut cursor = std::io::Cursor::new(&mut buffer);
let mut serializer: Serializer<_, byteorder::LE> = Serializer::new(&mut cursor);
array::serialize(&[1u8; 37], &mut serializer).unwrap();
// The result should not have been prefixed with a length byte.
assert_eq!(buffer.len(), 37);
// We should also be able to deserialize it back.
let mut deserializer = Deserializer::<byteorder::LE>::new(&buffer);
let deserialized: [u8; 37] = array::deserialize(&mut deserializer).unwrap();
// The deserialized array should be the same as the original.
assert_eq!(deserialized, [1u8; 37]);
}
// The array serializer should not interfere with other serializers. Here we
// check serde_json to make sure an array is written as expected.
#[test]
fn test_array_serde_json() {
// Serialize an array into a buffer.
let mut buffer = Vec::new();
let mut cursor = std::io::Cursor::new(&mut buffer);
let mut serialized = serde_json::Serializer::new(&mut cursor);
array::serialize(&[1u8; 7], &mut serialized).unwrap();
let result = String::from_utf8(buffer).unwrap();
assert_eq!(result, "[1,1,1,1,1,1,1]");
// Deserializing should also work.
let mut deserializer = serde_json::Deserializer::from_str(&result);
let deserialized: [u8; 7] = array::deserialize(&mut deserializer).unwrap();
assert_eq!(deserialized, [1u8; 7]);
}
// Golden Structure Test
//
// This test serializes a struct containing all the expected types we should
// be able to handle and checks the output is as expected. The reason I
// opted to serialize all in one struct instead of with separate tests is to
// ensure that the positioning of elements when in relation to others is
// also as expected. Especially when it comes to things such as nesting and
// length prefixing.
#[test]
fn test_pyth_serde() {
use serde::Serialize;
// Setup Serializer.
let mut buffer = Vec::new();
let mut cursor = std::io::Cursor::new(&mut buffer);
let mut serializer: Serializer<_, byteorder::LE> = Serializer::new(&mut cursor);
// Golden Test Value. As binary data can be fickle to understand in
// tests this should be kept commented with detail.
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
struct GoldenStruct<'a> {
// Test `unit` is not serialized to anything.
unit: (),
// Test `bool` is serialized to a single byte.
t_bool: bool,
// Test integer serializations.
t_u8: u8,
t_u16: u16,
t_u32: u32,
t_u64: u64,
// Test `str` is serialized to a variable length array.
t_string: String,
t_str: &'a str,
// Test `Vec` is serialized to a variable length array.
t_vec: Vec<u8>,
t_vec_empty: Vec<u8>,
t_vec_nested: Vec<Vec<u8>>,
t_vec_nested_empty: Vec<Vec<u8>>,
t_slice: &'a [u8],
t_slice_empty: &'a [u8],
// Test tuples serialize as expected.
t_tuple: (u8, u16, u32, u64, String, Vec<u8>, &'a [u8]),
t_tuple_nested: ((u8, u16), (u32, u64)),
// Test enum serializations.
t_enum_unit: GoldenEnum,
t_enum_newtype: GoldenEnum,
t_enum_tuple: GoldenEnum,
t_enum_struct: GoldenEnum,
}
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
enum GoldenEnum {
UnitVariant,
NewtypeVariant(u8),
TupleVariant(u8, u16),
StructVariant { a: u8, b: u16 },
}
// Serialize the golden test value.
let golden_struct = GoldenStruct {
unit: (),
t_bool: true,
t_u8: 1,
t_u16: 2,
t_u32: 3,
t_u64: 4,
t_string: "9".to_string(),
t_str: "10",
t_vec: vec![11, 12, 13],
t_vec_empty: vec![],
t_vec_nested: vec![vec![14, 15, 16], vec![17, 18, 19]],
t_vec_nested_empty: vec![vec![], vec![]],
t_slice: &[20, 21, 22],
t_slice_empty: &[],
t_tuple: (
29,
30,
31,
32,
"10".to_string(),
vec![35, 36, 37],
&[38, 39, 40],
),
t_tuple_nested: ((41, 42), (43, 44)),
t_enum_unit: GoldenEnum::UnitVariant,
t_enum_newtype: GoldenEnum::NewtypeVariant(45),
t_enum_tuple: GoldenEnum::TupleVariant(46, 47),
t_enum_struct: GoldenEnum::StructVariant { a: 48, b: 49 },
};
golden_struct.serialize(&mut serializer).unwrap();
// The serialized output should be as expected.
assert_eq!(
&buffer,
&[
1, // t_bool
1, // t_u8
2, 0, // t_u16
3, 0, 0, 0, // t_u32
4, 0, 0, 0, 0, 0, 0, 0, // t_u64
1, 57, // t_string
2, 49, 48, // t_str
3, 11, 12, 13, // t_vec
0, // t_vec_empty
2, 3, 14, 15, 16, 3, 17, 18, 19, // t_vec_nested
2, 0, 0, // t_vec_nested_empty
3, 20, 21, 22, // t_slice
0, // t_slice_empty
29, // t_tuple
30, 0, // u8
31, 0, 0, 0, // u16
32, 0, 0, 0, 0, 0, 0, 0, // u32
2, 49, 48, // "10"
3, 35, 36, 37, // [35, 36, 37]
3, 38, 39, 40, // [38, 39, 40]
41, 42, 0, 43, 0, 0, 0, 44, 0, 0, 0, 0, 0, 0, 0, // t_tuple_nested
0, // t_enum_unit
1, 45, // t_enum_newtype
2, 46, 47, 0, // t_enum_tuple
3, 48, 49, 0, // t_enum_struct
]
);
}
}

View File

@ -1,3 +1,8 @@
//! This module provides Wormhole primitives.
//!
//! Wormhole does not provide an SDK for working with Solana versions of Wormhole related types, so
//! we clone the definitions from the Solana contracts here and adapt them to Pyth purposes. This
//! allows us to emit and parse messages through Wormhole.
use {
crate::Pubkey,
borsh::{