From e26c9d1a3064e27eed950ef58be6adf3b5d2062e Mon Sep 17 00:00:00 2001 From: Pavel Strakhov Date: Tue, 7 May 2024 14:20:59 +0100 Subject: [PATCH] refactor(target_chains/starknet): split pyth module (#1554) --- .../starknet/contracts/src/pyth.cairo | 356 +++--------------- .../starknet/contracts/src/pyth/errors.cairo | 46 +++ .../contracts/src/pyth/interface.cairo | 25 ++ .../contracts/src/pyth/price_update.cairo | 141 +++++++ 4 files changed, 272 insertions(+), 296 deletions(-) create mode 100644 target_chains/starknet/contracts/src/pyth/errors.cairo create mode 100644 target_chains/starknet/contracts/src/pyth/interface.cairo create mode 100644 target_chains/starknet/contracts/src/pyth/price_update.cairo diff --git a/target_chains/starknet/contracts/src/pyth.cairo b/target_chains/starknet/contracts/src/pyth.cairo index 5d4e59ca..4e04df8a 100644 --- a/target_chains/starknet/contracts/src/pyth.cairo +++ b/target_chains/starknet/contracts/src/pyth.cairo @@ -1,115 +1,27 @@ -use core::array::ArrayTrait; -use core::fmt::{Debug, Formatter}; -use super::byte_array::ByteArray; -use super::util::UnwrapWithFelt252; +mod errors; +mod interface; +mod price_update; pub use pyth::{Event, PriceFeedUpdateEvent}; - -#[starknet::interface] -pub trait IPyth { - fn get_price_unsafe(self: @T, price_id: u256) -> Result; - fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result; - fn set_data_sources(ref self: T, sources: Array); - fn set_fee(ref self: T, single_update_fee: u256); - fn update_price_feeds(ref self: T, data: ByteArray); -} - -#[derive(Copy, Drop, Debug, Serde, PartialEq)] -pub enum GetPriceUnsafeError { - PriceFeedNotFound, -} - -impl GetPriceUnsafeErrorIntoFelt252 of Into { - fn into(self: GetPriceUnsafeError) -> felt252 { - match self { - GetPriceUnsafeError::PriceFeedNotFound => 'price feed not found', - } - } -} - -#[derive(Copy, Drop, Debug, Serde, PartialEq)] -pub enum GovernanceActionError { - AccessDenied, -} - -impl GovernanceActionErrorIntoFelt252 of Into { - fn into(self: GovernanceActionError) -> felt252 { - match self { - GovernanceActionError::AccessDenied => 'access denied', - } - } -} - -#[derive(Copy, Drop, Debug, Serde, PartialEq)] -pub enum UpdatePriceFeedsError { - Reader: super::reader::Error, - Wormhole: super::wormhole::ParseAndVerifyVmError, - InvalidUpdateData, - InvalidUpdateDataSource, - InsufficientFeeAllowance, -} - -impl UpdatePriceFeedsErrorIntoFelt252 of Into { - fn into(self: UpdatePriceFeedsError) -> felt252 { - match self { - UpdatePriceFeedsError::Reader(err) => err.into(), - UpdatePriceFeedsError::Wormhole(err) => err.into(), - UpdatePriceFeedsError::InvalidUpdateData => 'invalid update data', - UpdatePriceFeedsError::InvalidUpdateDataSource => 'invalid update data source', - UpdatePriceFeedsError::InsufficientFeeAllowance => 'insufficient fee allowance', - } - } -} - -#[derive(Drop, Debug, Clone, Copy, Hash, Default, Serde, starknet::Store)] -pub struct DataSource { - pub emitter_chain_id: u16, - pub emitter_address: u256, -} - -#[derive(Drop, Clone, Serde, starknet::Store)] -struct PriceInfo { - pub price: i64, - pub conf: u64, - pub expo: i32, - pub publish_time: u64, - pub ema_price: i64, - pub ema_conf: u64, -} - -#[derive(Drop, Clone, Serde)] -struct Price { - pub price: i64, - pub conf: u64, - pub expo: i32, - pub publish_time: u64, -} +pub use errors::{GetPriceUnsafeError, GovernanceActionError, UpdatePriceFeedsError}; +pub use interface::{IPyth, IPythDispatcher, IPythDispatcherTrait, DataSource, Price}; #[starknet::contract] mod pyth { - use pyth::reader::ReaderTrait; + use super::price_update::{ + PriceInfo, PriceFeedMessage, read_and_verify_message, read_header_and_wormhole_proof, + parse_wormhole_proof + }; use pyth::reader::{Reader, ReaderImpl}; use pyth::byte_array::{ByteArray, ByteArrayImpl}; use core::panic_with_felt252; use core::starknet::{ContractAddress, get_caller_address, get_execution_info}; use pyth::wormhole::{IWormholeDispatcher, IWormholeDispatcherTrait}; use super::{ - DataSource, UpdatePriceFeedsError, PriceInfo, GovernanceActionError, Price, - GetPriceUnsafeError + DataSource, UpdatePriceFeedsError, GovernanceActionError, Price, GetPriceUnsafeError }; - use pyth::merkle_tree::{read_and_verify_proof, MerkleVerificationError}; - use pyth::hash::{Hasher, HasherImpl}; - use core::fmt::{Debug, Formatter}; - use pyth::util::{u64_as_i64, u32_as_i32}; use openzeppelin::token::erc20::interface::{IERC20CamelDispatcherTrait, IERC20CamelDispatcher}; - // Stands for PNAU (Pyth Network Accumulator Update) - const ACCUMULATOR_MAGIC: u32 = 0x504e4155; - // Stands for AUWV (Accumulator Update Wormhole Verficiation) - const ACCUMULATOR_WORMHOLE_MAGIC: u32 = 0x41555756; - const MAJOR_VERSION: u8 = 1; - const MINIMUM_ALLOWED_MINOR_VERSION: u8 = 0; - #[event] #[derive(Drop, PartialEq, starknet::Event)] pub enum Event { @@ -125,74 +37,6 @@ mod pyth { pub conf: u64, } - #[generate_trait] - impl ResultReaderToUpdatePriceFeeds of ResultReaderToUpdatePriceFeedsTrait { - fn map_err(self: Result) -> Result { - match self { - Result::Ok(v) => Result::Ok(v), - Result::Err(err) => Result::Err(UpdatePriceFeedsError::Reader(err)), - } - } - } - - #[generate_trait] - impl ResultWormholeToUpdatePriceFeeds of ResultWormholeToUpdatePriceFeedsTrait { - fn map_err( - self: Result - ) -> Result { - match self { - Result::Ok(v) => Result::Ok(v), - Result::Err(err) => Result::Err(UpdatePriceFeedsError::Wormhole(err)), - } - } - } - - #[generate_trait] - impl ResultMerkleToUpdatePriceFeeds of ResultMerkleToUpdatePriceFeedsTrait { - fn map_err(self: Result) -> Result { - match self { - Result::Ok(v) => Result::Ok(v), - Result::Err(err) => { - let err = match err { - MerkleVerificationError::Reader(err) => UpdatePriceFeedsError::Reader(err), - MerkleVerificationError::DigestMismatch => UpdatePriceFeedsError::InvalidUpdateData, - }; - Result::Err(err) - }, - } - } - } - - #[derive(Drop)] - enum UpdateType { - WormholeMerkle - } - - impl U8TryIntoUpdateType of TryInto { - fn try_into(self: u8) -> Option { - if self == 0 { - Option::Some(UpdateType::WormholeMerkle) - } else { - Option::None - } - } - } - - #[derive(Drop)] - enum MessageType { - PriceFeed - } - - impl U8TryIntoMessageType of TryInto { - fn try_into(self: u8) -> Option { - if self == 0 { - Option::Some(MessageType::PriceFeed) - } else { - Option::None - } - } - } - #[storage] struct Storage { wormhole_address: ContractAddress, @@ -235,39 +79,7 @@ mod pyth { self.wormhole_address.write(wormhole_address); self.fee_contract_address.write(fee_contract_address); self.single_update_fee.write(single_update_fee); - write_data_sources(ref self, data_sources); - } - - fn write_data_sources(ref self: ContractState, data_sources: Array) { - let num_old = self.num_data_sources.read(); - let mut i = 0; - while i < num_old { - let old_source = self.data_sources.read(i); - self.is_valid_data_source.write(old_source, false); - self.data_sources.write(i, Default::default()); - i += 1; - }; - - self.num_data_sources.write(data_sources.len()); - i = 0; - while i < data_sources.len() { - let source = data_sources.at(i); - self.is_valid_data_source.write(*source, true); - self.data_sources.write(i, *source); - i += 1; - }; - } - - #[derive(Drop)] - struct PriceFeedMessage { - price_id: u256, - price: i64, - conf: u64, - expo: i32, - publish_time: u64, - prev_publish_time: u64, - ema_price: i64, - ema_conf: u64, + self.write_data_sources(data_sources); } #[abi(embed_v0)] @@ -308,7 +120,7 @@ mod pyth { if self.owner.read() != get_caller_address() { panic_with_felt252(GovernanceActionError::AccessDenied.into()); } - write_data_sources(ref self, sources); + self.write_data_sources(sources); } fn set_fee(ref self: ContractState, single_update_fee: u256) { @@ -320,33 +132,9 @@ mod pyth { fn update_price_feeds(ref self: ContractState, data: ByteArray) { let mut reader = ReaderImpl::new(data); - let x = reader.read_u32(); - if x != ACCUMULATOR_MAGIC { - panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); - } - if reader.read_u8() != MAJOR_VERSION { - panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); - } - if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION { - panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); - } - - let trailing_header_size = reader.read_u8(); - reader.skip(trailing_header_size); - - let update_type: UpdateType = reader - .read_u8() - .try_into() - .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); - - match update_type { - UpdateType::WormholeMerkle => {} - } - - let wh_proof_size = reader.read_u16(); - let wh_proof = reader.read_byte_array(wh_proof_size.into()); + let wormhole_proof = read_header_and_wormhole_proof(ref reader); let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() }; - let vm = wormhole.parse_and_verify_vm(wh_proof); + let vm = wormhole.parse_and_verify_vm(wormhole_proof); let source = DataSource { emitter_chain_id: vm.emitter_chain_id, emitter_address: vm.emitter_address @@ -355,28 +143,10 @@ mod pyth { panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateDataSource.into()); } - let mut payload_reader = ReaderImpl::new(vm.payload); - let x = payload_reader.read_u32(); - if x != ACCUMULATOR_WORMHOLE_MAGIC { - panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); - } - - let update_type: UpdateType = payload_reader - .read_u8() - .try_into() - .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); - - match update_type { - UpdateType::WormholeMerkle => {} - } - - let _slot = payload_reader.read_u64(); - let _ring_size = payload_reader.read_u32(); - let root_digest = payload_reader.read_u160(); + let root_digest = parse_wormhole_proof(vm.payload); let num_updates = reader.read_u8(); - - let total_fee = get_total_fee(ref self, num_updates); + let total_fee = self.get_total_fee(num_updates); let fee_contract = IERC20CamelDispatcher { contract_address: self.fee_contract_address.read() }; @@ -393,7 +163,7 @@ mod pyth { let mut i = 0; while i < num_updates { let message = read_and_verify_message(ref reader, root_digest); - update_latest_price_if_necessary(ref self, message); + self.update_latest_price_if_necessary(message); i += 1; }; @@ -403,59 +173,53 @@ mod pyth { } } - fn read_and_verify_message(ref reader: Reader, root_digest: u256) -> PriceFeedMessage { - let message_size = reader.read_u16(); - let message = reader.read_byte_array(message_size.into()); - read_and_verify_proof(root_digest, @message, ref reader); - - let mut message_reader = ReaderImpl::new(message); - let message_type: MessageType = message_reader - .read_u8() - .try_into() - .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); - - match message_type { - MessageType::PriceFeed => {} - } - - let price_id = message_reader.read_u256(); - let price = u64_as_i64(message_reader.read_u64()); - let conf = message_reader.read_u64(); - let expo = u32_as_i32(message_reader.read_u32()); - let publish_time = message_reader.read_u64(); - let prev_publish_time = message_reader.read_u64(); - let ema_price = u64_as_i64(message_reader.read_u64()); - let ema_conf = message_reader.read_u64(); - - PriceFeedMessage { - price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf, - } - } - - fn update_latest_price_if_necessary(ref self: ContractState, message: PriceFeedMessage) { - let latest_publish_time = self.latest_price_info.read(message.price_id).publish_time; - if message.publish_time > latest_publish_time { - let info = PriceInfo { - price: message.price, - conf: message.conf, - expo: message.expo, - publish_time: message.publish_time, - ema_price: message.ema_price, - ema_conf: message.ema_conf, + #[generate_trait] + impl PrivateImpl of PrivateTrait { + fn write_data_sources(ref self: ContractState, data_sources: Array) { + let num_old = self.num_data_sources.read(); + let mut i = 0; + while i < num_old { + let old_source = self.data_sources.read(i); + self.is_valid_data_source.write(old_source, false); + self.data_sources.write(i, Default::default()); + i += 1; }; - self.latest_price_info.write(message.price_id, info); - let event = PriceFeedUpdateEvent { - price_id: message.price_id, - publish_time: message.publish_time, - price: message.price, - conf: message.conf, + self.num_data_sources.write(data_sources.len()); + i = 0; + while i < data_sources.len() { + let source = data_sources.at(i); + self.is_valid_data_source.write(*source, true); + self.data_sources.write(i, *source); + i += 1; }; - self.emit(event); } - } - fn get_total_fee(ref self: ContractState, num_updates: u8) -> u256 { - self.single_update_fee.read() * num_updates.into() + fn update_latest_price_if_necessary(ref self: ContractState, message: PriceFeedMessage) { + let latest_publish_time = self.latest_price_info.read(message.price_id).publish_time; + if message.publish_time > latest_publish_time { + let info = PriceInfo { + price: message.price, + conf: message.conf, + expo: message.expo, + publish_time: message.publish_time, + ema_price: message.ema_price, + ema_conf: message.ema_conf, + }; + self.latest_price_info.write(message.price_id, info); + + let event = PriceFeedUpdateEvent { + price_id: message.price_id, + publish_time: message.publish_time, + price: message.price, + conf: message.conf, + }; + self.emit(event); + } + } + + fn get_total_fee(ref self: ContractState, num_updates: u8) -> u256 { + self.single_update_fee.read() * num_updates.into() + } } } diff --git a/target_chains/starknet/contracts/src/pyth/errors.cairo b/target_chains/starknet/contracts/src/pyth/errors.cairo new file mode 100644 index 00000000..89922062 --- /dev/null +++ b/target_chains/starknet/contracts/src/pyth/errors.cairo @@ -0,0 +1,46 @@ +#[derive(Copy, Drop, Debug, Serde, PartialEq)] +pub enum GetPriceUnsafeError { + PriceFeedNotFound, +} + +impl GetPriceUnsafeErrorIntoFelt252 of Into { + fn into(self: GetPriceUnsafeError) -> felt252 { + match self { + GetPriceUnsafeError::PriceFeedNotFound => 'price feed not found', + } + } +} + +#[derive(Copy, Drop, Debug, Serde, PartialEq)] +pub enum GovernanceActionError { + AccessDenied, +} + +impl GovernanceActionErrorIntoFelt252 of Into { + fn into(self: GovernanceActionError) -> felt252 { + match self { + GovernanceActionError::AccessDenied => 'access denied', + } + } +} + +#[derive(Copy, Drop, Debug, Serde, PartialEq)] +pub enum UpdatePriceFeedsError { + Reader: pyth::reader::Error, + Wormhole: pyth::wormhole::ParseAndVerifyVmError, + InvalidUpdateData, + InvalidUpdateDataSource, + InsufficientFeeAllowance, +} + +impl UpdatePriceFeedsErrorIntoFelt252 of Into { + fn into(self: UpdatePriceFeedsError) -> felt252 { + match self { + UpdatePriceFeedsError::Reader(err) => err.into(), + UpdatePriceFeedsError::Wormhole(err) => err.into(), + UpdatePriceFeedsError::InvalidUpdateData => 'invalid update data', + UpdatePriceFeedsError::InvalidUpdateDataSource => 'invalid update data source', + UpdatePriceFeedsError::InsufficientFeeAllowance => 'insufficient fee allowance', + } + } +} diff --git a/target_chains/starknet/contracts/src/pyth/interface.cairo b/target_chains/starknet/contracts/src/pyth/interface.cairo new file mode 100644 index 00000000..02eca0ac --- /dev/null +++ b/target_chains/starknet/contracts/src/pyth/interface.cairo @@ -0,0 +1,25 @@ +use super::GetPriceUnsafeError; +use pyth::byte_array::ByteArray; + +#[starknet::interface] +pub trait IPyth { + fn get_price_unsafe(self: @T, price_id: u256) -> Result; + fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result; + fn set_data_sources(ref self: T, sources: Array); + fn set_fee(ref self: T, single_update_fee: u256); + fn update_price_feeds(ref self: T, data: ByteArray); +} + +#[derive(Drop, Debug, Clone, Copy, Hash, Default, Serde, starknet::Store)] +pub struct DataSource { + pub emitter_chain_id: u16, + pub emitter_address: u256, +} + +#[derive(Drop, Clone, Serde)] +pub struct Price { + pub price: i64, + pub conf: u64, + pub expo: i32, + pub publish_time: u64, +} diff --git a/target_chains/starknet/contracts/src/pyth/price_update.cairo b/target_chains/starknet/contracts/src/pyth/price_update.cairo new file mode 100644 index 00000000..9d08de34 --- /dev/null +++ b/target_chains/starknet/contracts/src/pyth/price_update.cairo @@ -0,0 +1,141 @@ +use pyth::reader::{Reader, ReaderImpl}; +use pyth::pyth::UpdatePriceFeedsError; +use core::panic_with_felt252; +use pyth::byte_array::ByteArray; +use pyth::merkle_tree::read_and_verify_proof; +use pyth::util::{u32_as_i32, u64_as_i64}; + +// Stands for PNAU (Pyth Network Accumulator Update) +const ACCUMULATOR_MAGIC: u32 = 0x504e4155; +// Stands for AUWV (Accumulator Update Wormhole Verficiation) +const ACCUMULATOR_WORMHOLE_MAGIC: u32 = 0x41555756; +const MAJOR_VERSION: u8 = 1; +const MINIMUM_ALLOWED_MINOR_VERSION: u8 = 0; + +#[derive(Drop, Clone, Serde, starknet::Store)] +pub struct PriceInfo { + pub price: i64, + pub conf: u64, + pub expo: i32, + pub publish_time: u64, + pub ema_price: i64, + pub ema_conf: u64, +} + +#[derive(Drop)] +enum UpdateType { + WormholeMerkle +} + +impl U8TryIntoUpdateType of TryInto { + fn try_into(self: u8) -> Option { + if self == 0 { + Option::Some(UpdateType::WormholeMerkle) + } else { + Option::None + } + } +} + +#[derive(Drop)] +enum MessageType { + PriceFeed +} + +impl U8TryIntoMessageType of TryInto { + fn try_into(self: u8) -> Option { + if self == 0 { + Option::Some(MessageType::PriceFeed) + } else { + Option::None + } + } +} + +#[derive(Drop)] +pub struct PriceFeedMessage { + pub price_id: u256, + pub price: i64, + pub conf: u64, + pub expo: i32, + pub publish_time: u64, + pub prev_publish_time: u64, + pub ema_price: i64, + pub ema_conf: u64, +} + +pub fn read_header_and_wormhole_proof(ref reader: Reader) -> ByteArray { + if reader.read_u32() != ACCUMULATOR_MAGIC { + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); + } + if reader.read_u8() != MAJOR_VERSION { + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); + } + if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION { + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); + } + + let trailing_header_size = reader.read_u8(); + reader.skip(trailing_header_size); + + let update_type: UpdateType = reader + .read_u8() + .try_into() + .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); + + match update_type { + UpdateType::WormholeMerkle => {} + } + + let wormhole_proof_size = reader.read_u16(); + reader.read_byte_array(wormhole_proof_size.into()) +} + +pub fn parse_wormhole_proof(payload: ByteArray) -> u256 { + let mut reader = ReaderImpl::new(payload); + if reader.read_u32() != ACCUMULATOR_WORMHOLE_MAGIC { + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); + } + + let update_type: UpdateType = reader + .read_u8() + .try_into() + .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); + + match update_type { + UpdateType::WormholeMerkle => {} + } + + let _slot = reader.read_u64(); + let _ring_size = reader.read_u32(); + reader.read_u160() +} + +pub fn read_and_verify_message(ref reader: Reader, root_digest: u256) -> PriceFeedMessage { + let message_size = reader.read_u16(); + let message = reader.read_byte_array(message_size.into()); + read_and_verify_proof(root_digest, @message, ref reader); + + let mut message_reader = ReaderImpl::new(message); + let message_type: MessageType = message_reader + .read_u8() + .try_into() + .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); + + match message_type { + MessageType::PriceFeed => {} + } + + let price_id = message_reader.read_u256(); + let price = u64_as_i64(message_reader.read_u64()); + let conf = message_reader.read_u64(); + let expo = u32_as_i32(message_reader.read_u32()); + let publish_time = message_reader.read_u64(); + let prev_publish_time = message_reader.read_u64(); + let ema_price = u64_as_i64(message_reader.read_u64()); + let ema_conf = message_reader.read_u64(); + + PriceFeedMessage { + price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf, + } +}