From 178ad4cb0edff38f43d8e26f23d1d9e83448093c Mon Sep 17 00:00:00 2001 From: guibescos <59208140+guibescos@users.noreply.github.com> Date: Fri, 15 Dec 2023 10:47:22 +0700 Subject: [PATCH] [pythnet-sdk] Bump borsh add borsh to MerklePriceUpdate (#1186) * Add borsh * Bump borsh * Cleanup * Try * Cleanup * Do it * Add a test --- pythnet/pythnet_sdk/Cargo.toml | 2 +- pythnet/pythnet_sdk/src/wire.rs | 8 +++- pythnet/pythnet_sdk/src/wire/prefixed_vec.rs | 34 ++++++++++++---- pythnet/pythnet_sdk/src/wormhole.rs | 43 +++++++++++++++----- 4 files changed, 66 insertions(+), 21 deletions(-) diff --git a/pythnet/pythnet_sdk/Cargo.toml b/pythnet/pythnet_sdk/Cargo.toml index c2e233dc..cd47cc93 100644 --- a/pythnet/pythnet_sdk/Cargo.toml +++ b/pythnet/pythnet_sdk/Cargo.toml @@ -12,7 +12,7 @@ name = "pythnet_sdk" [dependencies] bincode = "1.3.1" -borsh = "0.9.1" +borsh = "0.10.3" bytemuck = { version = "1.11.0", features = ["derive"] } byteorder = "1.4.3" fast-math = "0.1" diff --git a/pythnet/pythnet_sdk/src/wire.rs b/pythnet/pythnet_sdk/src/wire.rs index 6da2316d..922e82d5 100644 --- a/pythnet/pythnet_sdk/src/wire.rs +++ b/pythnet/pythnet_sdk/src/wire.rs @@ -40,6 +40,10 @@ pub mod v1 { hashers::keccak256_160::Keccak160, require, }, + borsh::{ + BorshDeserialize, + BorshSerialize, + }, serde::{ Deserialize, Serialize, @@ -99,7 +103,9 @@ pub mod v1 { }, } - #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] + #[derive( + Clone, Debug, Hash, PartialEq, Serialize, Deserialize, BorshDeserialize, BorshSerialize, + )] pub struct MerklePriceUpdate { pub message: PrefixedVec, pub proof: MerklePath, diff --git a/pythnet/pythnet_sdk/src/wire/prefixed_vec.rs b/pythnet/pythnet_sdk/src/wire/prefixed_vec.rs index d8922362..9459a137 100644 --- a/pythnet/pythnet_sdk/src/wire/prefixed_vec.rs +++ b/pythnet/pythnet_sdk/src/wire/prefixed_vec.rs @@ -1,15 +1,21 @@ -use serde::{ - de::DeserializeSeed, - ser::{ - SerializeSeq, - SerializeStruct, +use { + borsh::{ + BorshDeserialize, + BorshSerialize, + }, + serde::{ + de::DeserializeSeed, + ser::{ + SerializeSeq, + SerializeStruct, + }, + Deserialize, + Serialize, }, - Deserialize, - Serialize, }; /// PrefixlessVec overrides the serialization to _not_ write a length prefix. -#[derive(Clone, Debug, Hash, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Hash, PartialEq, PartialOrd, BorshDeserialize, BorshSerialize)] struct PrefixlessVec { inner: Vec, } @@ -99,7 +105,7 @@ where /// /// For non-Pyth formats this results in a struct which is the correct way to interpret our /// data on chain anyway. -#[derive(Clone, Debug, Hash, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Hash, PartialEq, PartialOrd, BorshDeserialize, BorshSerialize)] pub struct PrefixedVec { __phantom: std::marker::PhantomData, data: PrefixlessVec, @@ -227,3 +233,13 @@ where ) } } + +#[test] +fn test_borsh_roundtrip() { + let prefixed_vec = PrefixedVec::::from(vec![1, 2, 3, 4, 5]); + let encoded = borsh::to_vec(&prefixed_vec).unwrap(); + assert_eq!(encoded, vec![5, 0, 0, 0, 1, 2, 3, 4, 5]); + + let decoded_prefixed_vec = PrefixedVec::::try_from_slice(encoded.as_slice()).unwrap(); + assert_eq!(decoded_prefixed_vec, prefixed_vec); +} diff --git a/pythnet/pythnet_sdk/src/wormhole.rs b/pythnet/pythnet_sdk/src/wormhole.rs index 48012035..e7418cfe 100644 --- a/pythnet/pythnet_sdk/src/wormhole.rs +++ b/pythnet/pythnet_sdk/src/wormhole.rs @@ -27,12 +27,14 @@ use { }; #[repr(transparent)] -#[derive(Default)] +#[derive(Default, PartialEq, Debug)] pub struct PostedMessageUnreliableData { pub message: MessageData, } -#[derive(Debug, Default, BorshSerialize, BorshDeserialize, Clone, Serialize, Deserialize)] +#[derive( + Debug, Default, BorshSerialize, BorshDeserialize, Clone, Serialize, Deserialize, PartialEq, +)] pub struct MessageData { pub vaa_version: u8, pub consistency_level: u8, @@ -54,22 +56,19 @@ impl BorshSerialize for PostedMessageUnreliableData { } impl BorshDeserialize for PostedMessageUnreliableData { - fn deserialize(buf: &mut &[u8]) -> std::io::Result { - if buf.len() < 3 { - return Err(Error::new(InvalidData, "Not enough bytes")); - } + fn deserialize_reader(reader: &mut R) -> std::io::Result { + let mut magic = [0u8; 3]; + reader.read_exact(&mut magic)?; let expected = b"msu"; - let magic: &[u8] = &buf[0..3]; - if magic != expected { + if &magic != expected { return Err(Error::new( InvalidData, format!("Magic mismatch. Expected {expected:?} but got {magic:?}"), )); }; - *buf = &buf[3..]; Ok(PostedMessageUnreliableData { - message: ::deserialize(buf)?, + message: ::deserialize_reader(reader)?, }) } } @@ -99,3 +98,27 @@ impl Clone for PostedMessageUnreliableData { pub struct AccumulatorSequenceTracker { pub sequence: u64, } + +#[test] +fn test_borsh_roundtrip() { + let post_message_unreliable_data = PostedMessageUnreliableData { + message: MessageData { + vaa_version: 1, + consistency_level: 2, + vaa_time: 3, + vaa_signature_account: [4u8; 32], + submission_time: 5, + nonce: 6, + sequence: 7, + emitter_chain: 8, + emitter_address: [9u8; 32], + payload: vec![10u8; 32], + }, + }; + + + let encoded = borsh::to_vec(&post_message_unreliable_data).unwrap(); + + let decoded = PostedMessageUnreliableData::try_from_slice(encoded.as_slice()).unwrap(); + assert_eq!(decoded, post_message_unreliable_data); +}