From 42b64ac09f4c702c0c74cdd5a27a8909e652c671 Mon Sep 17 00:00:00 2001 From: Pavel Strakhov Date: Mon, 6 May 2024 16:21:36 +0100 Subject: [PATCH] refactor(target_chains/starknet): remove Result from merkle_tree and pyth setters (#1548) * refactor(target_chains/starknet): remove Result from merkle_tree * refactor(target_chains/starknet): remove Result from pyth contract setters --- .../starknet/contracts/src/merkle_tree.cairo | 26 +++-- .../starknet/contracts/src/pyth.cairo | 106 +++++++----------- .../starknet/contracts/tests/pyth.cairo | 2 +- 3 files changed, 59 insertions(+), 75 deletions(-) diff --git a/target_chains/starknet/contracts/src/merkle_tree.cairo b/target_chains/starknet/contracts/src/merkle_tree.cairo index e715698b..b880ccf2 100644 --- a/target_chains/starknet/contracts/src/merkle_tree.cairo +++ b/target_chains/starknet/contracts/src/merkle_tree.cairo @@ -3,6 +3,7 @@ use super::reader::{Reader, ReaderImpl}; use super::byte_array::ByteArray; use super::util::ONE_SHIFT_96; use core::cmp::{min, max}; +use core::panic_with_felt252; const MERKLE_LEAF_PREFIX: u8 = 0; const MERKLE_NODE_PREFIX: u8 = 1; @@ -14,6 +15,15 @@ pub enum MerkleVerificationError { DigestMismatch, } +impl MerkleVerificationErrorIntoFelt252 of Into { + fn into(self: MerkleVerificationError) -> felt252 { + match self { + MerkleVerificationError::Reader(err) => err.into(), + MerkleVerificationError::DigestMismatch => 'digest mismatch', + } + } +} + #[generate_trait] impl ResultReaderToMerkleVerification of ResultReaderToMerkleVerificationTrait { fn map_err(self: Result) -> Result { @@ -24,12 +34,11 @@ impl ResultReaderToMerkleVerification of ResultReaderToMerkleVerificationTrai } } -fn leaf_hash(mut reader: Reader) -> Result { +fn leaf_hash(mut reader: Reader) -> u256 { let mut hasher = HasherImpl::new(); hasher.push_u8(MERKLE_LEAF_PREFIX); hasher.push_reader(ref reader); - let hash = hasher.finalize() / ONE_SHIFT_96; - Result::Ok(hash) + hasher.finalize() / ONE_SHIFT_96 } fn node_hash(a: u256, b: u256) -> u256 { @@ -40,25 +49,20 @@ fn node_hash(a: u256, b: u256) -> u256 { hasher.finalize() / ONE_SHIFT_96 } -pub fn read_and_verify_proof( - root_digest: u256, message: @ByteArray, ref reader: Reader -) -> Result<(), MerkleVerificationError> { +pub fn read_and_verify_proof(root_digest: u256, message: @ByteArray, ref reader: Reader) { let mut message_reader = ReaderImpl::new(message.clone()); - let mut current_hash = leaf_hash(message_reader.clone()).map_err()?; + let mut current_hash = leaf_hash(message_reader.clone()); let proof_size = reader.read_u8(); let mut i = 0; - let mut result = Result::Ok(()); while i < proof_size { let sibling_digest = reader.read_u160(); current_hash = node_hash(current_hash, sibling_digest); i += 1; }; - result?; if root_digest != current_hash { - return Result::Err(MerkleVerificationError::DigestMismatch); + panic_with_felt252(MerkleVerificationError::DigestMismatch.into()); } - Result::Ok(()) } diff --git a/target_chains/starknet/contracts/src/pyth.cairo b/target_chains/starknet/contracts/src/pyth.cairo index 3dd319cc..bf25ca90 100644 --- a/target_chains/starknet/contracts/src/pyth.cairo +++ b/target_chains/starknet/contracts/src/pyth.cairo @@ -9,11 +9,9 @@ pub use pyth::{Event, PriceFeedUpdateEvent}; pub trait IPyth { fn get_price_unsafe(self: @T, price_id: u256) -> Result; fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result; - fn set_data_sources( - ref self: T, sources: Array - ) -> Result<(), GovernanceActionError>; - fn set_fee(ref self: T, single_update_fee: u256) -> Result<(), GovernanceActionError>; - fn update_price_feeds(ref self: T, data: ByteArray) -> Result<(), UpdatePriceFeedsError>; + fn set_data_sources(ref self: T, sources: Array); + fn set_fee(ref self: T, single_update_fee: u256); + fn update_price_feeds(ref self: T, data: ByteArray); } #[derive(Copy, Drop, Debug, Serde, PartialEq)] @@ -333,51 +331,44 @@ mod pyth { Result::Ok(price) } - fn set_data_sources( - ref self: ContractState, sources: Array - ) -> Result<(), GovernanceActionError> { + fn set_data_sources(ref self: ContractState, sources: Array) { if self.owner.read() != get_caller_address() { - return Result::Err(GovernanceActionError::AccessDenied); + panic_with_felt252(GovernanceActionError::AccessDenied.into()); } write_data_sources(ref self, sources); - Result::Ok(()) } - fn set_fee( - ref self: ContractState, single_update_fee: u256 - ) -> Result<(), GovernanceActionError> { + fn set_fee(ref self: ContractState, single_update_fee: u256) { if self.owner.read() != get_caller_address() { - return Result::Err(GovernanceActionError::AccessDenied); + panic_with_felt252(GovernanceActionError::AccessDenied.into()); } self.single_update_fee.write(single_update_fee); - Result::Ok(()) } - fn update_price_feeds( - ref self: ContractState, data: ByteArray - ) -> Result<(), UpdatePriceFeedsError> { + fn update_price_feeds(ref self: ContractState, data: ByteArray) { let mut reader = ReaderImpl::new(data); let x = reader.read_u32(); if x != ACCUMULATOR_MAGIC { - return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); } if reader.read_u8() != MAJOR_VERSION { - return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); } if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION { - return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); } let trailing_header_size = reader.read_u8(); reader.skip(trailing_header_size); - let update_type: Option = reader.read_u8().try_into(); + let update_type: UpdateType = reader + .read_u8() + .try_into() + .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); + match update_type { - Option::Some(v) => match v { - UpdateType::WormholeMerkle => {} - }, - Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); } - }; + UpdateType::WormholeMerkle => {} + } let wh_proof_size = reader.read_u16(); let wh_proof = reader.read_byte_array(wh_proof_size.into()); @@ -388,22 +379,23 @@ mod pyth { emitter_chain_id: vm.emitter_chain_id, emitter_address: vm.emitter_address }; if !self.is_valid_data_source.read(source) { - return Result::Err(UpdatePriceFeedsError::InvalidUpdateDataSource); + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateDataSource.into()); } let mut payload_reader = ReaderImpl::new(vm.payload); let x = payload_reader.read_u32(); if x != ACCUMULATOR_WORMHOLE_MAGIC { - return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); } - let update_type: Option = payload_reader.read_u8().try_into(); + let update_type: UpdateType = payload_reader + .read_u8() + .try_into() + .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); + match update_type { - Option::Some(v) => match v { - UpdateType::WormholeMerkle => {} - }, - Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); } - }; + UpdateType::WormholeMerkle => {} + } let _slot = payload_reader.read_u64(); let _ring_size = payload_reader.read_u32(); @@ -419,50 +411,39 @@ mod pyth { let caller = execution_info.caller_address; let contract = execution_info.contract_address; if fee_contract.allowance(caller, contract) < total_fee { - return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance); + panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into()); } if !fee_contract.transferFrom(caller, contract, total_fee) { - return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance); + panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into()); } let mut i = 0; - let mut result = Result::Ok(()); while i < num_updates { - let r = read_and_verify_message(ref reader, root_digest); - match r { - Result::Ok(message) => { update_latest_price_if_necessary(ref self, message); }, - Result::Err(err) => { - result = Result::Err(err); - break; - } - } + let message = read_and_verify_message(ref reader, root_digest); + update_latest_price_if_necessary(ref self, message); i += 1; }; - result?; if reader.len() != 0 { - return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); + panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into()); } - - Result::Ok(()) } } - fn read_and_verify_message( - ref reader: Reader, root_digest: u256 - ) -> Result { + fn read_and_verify_message(ref reader: Reader, root_digest: u256) -> PriceFeedMessage { let message_size = reader.read_u16(); let message = reader.read_byte_array(message_size.into()); - read_and_verify_proof(root_digest, @message, ref reader).map_err()?; + read_and_verify_proof(root_digest, @message, ref reader); let mut message_reader = ReaderImpl::new(message); - let message_type: Option = message_reader.read_u8().try_into(); + let message_type: MessageType = message_reader + .read_u8() + .try_into() + .expect(UpdatePriceFeedsError::InvalidUpdateData.into()); + match message_type { - Option::Some(v) => match v { - MessageType::PriceFeed => {} - }, - Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); } - }; + MessageType::PriceFeed => {} + } let price_id = message_reader.read_u256(); let price = u64_as_i64(message_reader.read_u64()); @@ -473,10 +454,9 @@ mod pyth { let ema_price = u64_as_i64(message_reader.read_u64()); let ema_conf = message_reader.read_u64(); - let message = PriceFeedMessage { + PriceFeedMessage { price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf, - }; - Result::Ok(message) + } } fn update_latest_price_if_necessary(ref self: ContractState, message: PriceFeedMessage) { diff --git a/target_chains/starknet/contracts/tests/pyth.cairo b/target_chains/starknet/contracts/tests/pyth.cairo index d56bee99..35a49239 100644 --- a/target_chains/starknet/contracts/tests/pyth.cairo +++ b/target_chains/starknet/contracts/tests/pyth.cairo @@ -55,7 +55,7 @@ fn update_price_feeds_works() { let mut spy = spy_events(SpyOn::One(pyth.contract_address)); start_prank(CheatTarget::One(pyth.contract_address), user.try_into().unwrap()); - pyth.update_price_feeds(good_update1()).unwrap_with_felt252(); + pyth.update_price_feeds(good_update1()); stop_prank(CheatTarget::One(pyth.contract_address)); spy.fetch_events();