refactor(target_chains/starknet): remove Result from merkle_tree and pyth setters (#1548)

* refactor(target_chains/starknet): remove Result from merkle_tree

* refactor(target_chains/starknet): remove Result from pyth contract setters
This commit is contained in:
Pavel Strakhov 2024-05-06 16:21:36 +01:00 committed by GitHub
parent 55cbe62997
commit 42b64ac09f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 59 additions and 75 deletions

View File

@ -3,6 +3,7 @@ use super::reader::{Reader, ReaderImpl};
use super::byte_array::ByteArray;
use super::util::ONE_SHIFT_96;
use core::cmp::{min, max};
use core::panic_with_felt252;
const MERKLE_LEAF_PREFIX: u8 = 0;
const MERKLE_NODE_PREFIX: u8 = 1;
@ -14,6 +15,15 @@ pub enum MerkleVerificationError {
DigestMismatch,
}
impl MerkleVerificationErrorIntoFelt252 of Into<MerkleVerificationError, felt252> {
fn into(self: MerkleVerificationError) -> felt252 {
match self {
MerkleVerificationError::Reader(err) => err.into(),
MerkleVerificationError::DigestMismatch => 'digest mismatch',
}
}
}
#[generate_trait]
impl ResultReaderToMerkleVerification<T> of ResultReaderToMerkleVerificationTrait<T> {
fn map_err(self: Result<T, pyth::reader::Error>) -> Result<T, MerkleVerificationError> {
@ -24,12 +34,11 @@ impl ResultReaderToMerkleVerification<T> of ResultReaderToMerkleVerificationTrai
}
}
fn leaf_hash(mut reader: Reader) -> Result<u256, super::reader::Error> {
fn leaf_hash(mut reader: Reader) -> u256 {
let mut hasher = HasherImpl::new();
hasher.push_u8(MERKLE_LEAF_PREFIX);
hasher.push_reader(ref reader);
let hash = hasher.finalize() / ONE_SHIFT_96;
Result::Ok(hash)
hasher.finalize() / ONE_SHIFT_96
}
fn node_hash(a: u256, b: u256) -> u256 {
@ -40,25 +49,20 @@ fn node_hash(a: u256, b: u256) -> u256 {
hasher.finalize() / ONE_SHIFT_96
}
pub fn read_and_verify_proof(
root_digest: u256, message: @ByteArray, ref reader: Reader
) -> Result<(), MerkleVerificationError> {
pub fn read_and_verify_proof(root_digest: u256, message: @ByteArray, ref reader: Reader) {
let mut message_reader = ReaderImpl::new(message.clone());
let mut current_hash = leaf_hash(message_reader.clone()).map_err()?;
let mut current_hash = leaf_hash(message_reader.clone());
let proof_size = reader.read_u8();
let mut i = 0;
let mut result = Result::Ok(());
while i < proof_size {
let sibling_digest = reader.read_u160();
current_hash = node_hash(current_hash, sibling_digest);
i += 1;
};
result?;
if root_digest != current_hash {
return Result::Err(MerkleVerificationError::DigestMismatch);
panic_with_felt252(MerkleVerificationError::DigestMismatch.into());
}
Result::Ok(())
}

View File

@ -9,11 +9,9 @@ pub use pyth::{Event, PriceFeedUpdateEvent};
pub trait IPyth<T> {
fn get_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
fn set_data_sources(
ref self: T, sources: Array<DataSource>
) -> Result<(), GovernanceActionError>;
fn set_fee(ref self: T, single_update_fee: u256) -> Result<(), GovernanceActionError>;
fn update_price_feeds(ref self: T, data: ByteArray) -> Result<(), UpdatePriceFeedsError>;
fn set_data_sources(ref self: T, sources: Array<DataSource>);
fn set_fee(ref self: T, single_update_fee: u256);
fn update_price_feeds(ref self: T, data: ByteArray);
}
#[derive(Copy, Drop, Debug, Serde, PartialEq)]
@ -333,51 +331,44 @@ mod pyth {
Result::Ok(price)
}
fn set_data_sources(
ref self: ContractState, sources: Array<DataSource>
) -> Result<(), GovernanceActionError> {
fn set_data_sources(ref self: ContractState, sources: Array<DataSource>) {
if self.owner.read() != get_caller_address() {
return Result::Err(GovernanceActionError::AccessDenied);
panic_with_felt252(GovernanceActionError::AccessDenied.into());
}
write_data_sources(ref self, sources);
Result::Ok(())
}
fn set_fee(
ref self: ContractState, single_update_fee: u256
) -> Result<(), GovernanceActionError> {
fn set_fee(ref self: ContractState, single_update_fee: u256) {
if self.owner.read() != get_caller_address() {
return Result::Err(GovernanceActionError::AccessDenied);
panic_with_felt252(GovernanceActionError::AccessDenied.into());
}
self.single_update_fee.write(single_update_fee);
Result::Ok(())
}
fn update_price_feeds(
ref self: ContractState, data: ByteArray
) -> Result<(), UpdatePriceFeedsError> {
fn update_price_feeds(ref self: ContractState, data: ByteArray) {
let mut reader = ReaderImpl::new(data);
let x = reader.read_u32();
if x != ACCUMULATOR_MAGIC {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}
if reader.read_u8() != MAJOR_VERSION {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}
if reader.read_u8() < MINIMUM_ALLOWED_MINOR_VERSION {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}
let trailing_header_size = reader.read_u8();
reader.skip(trailing_header_size);
let update_type: Option<UpdateType> = reader.read_u8().try_into();
let update_type: UpdateType = reader
.read_u8()
.try_into()
.expect(UpdatePriceFeedsError::InvalidUpdateData.into());
match update_type {
Option::Some(v) => match v {
UpdateType::WormholeMerkle => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};
UpdateType::WormholeMerkle => {}
}
let wh_proof_size = reader.read_u16();
let wh_proof = reader.read_byte_array(wh_proof_size.into());
@ -388,22 +379,23 @@ mod pyth {
emitter_chain_id: vm.emitter_chain_id, emitter_address: vm.emitter_address
};
if !self.is_valid_data_source.read(source) {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateDataSource);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateDataSource.into());
}
let mut payload_reader = ReaderImpl::new(vm.payload);
let x = payload_reader.read_u32();
if x != ACCUMULATOR_WORMHOLE_MAGIC {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}
let update_type: Option<UpdateType> = payload_reader.read_u8().try_into();
let update_type: UpdateType = payload_reader
.read_u8()
.try_into()
.expect(UpdatePriceFeedsError::InvalidUpdateData.into());
match update_type {
Option::Some(v) => match v {
UpdateType::WormholeMerkle => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};
UpdateType::WormholeMerkle => {}
}
let _slot = payload_reader.read_u64();
let _ring_size = payload_reader.read_u32();
@ -419,50 +411,39 @@ mod pyth {
let caller = execution_info.caller_address;
let contract = execution_info.contract_address;
if fee_contract.allowance(caller, contract) < total_fee {
return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance);
panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into());
}
if !fee_contract.transferFrom(caller, contract, total_fee) {
return Result::Err(UpdatePriceFeedsError::InsufficientFeeAllowance);
panic_with_felt252(UpdatePriceFeedsError::InsufficientFeeAllowance.into());
}
let mut i = 0;
let mut result = Result::Ok(());
while i < num_updates {
let r = read_and_verify_message(ref reader, root_digest);
match r {
Result::Ok(message) => { update_latest_price_if_necessary(ref self, message); },
Result::Err(err) => {
result = Result::Err(err);
break;
}
}
let message = read_and_verify_message(ref reader, root_digest);
update_latest_price_if_necessary(ref self, message);
i += 1;
};
result?;
if reader.len() != 0 {
return Result::Err(UpdatePriceFeedsError::InvalidUpdateData);
panic_with_felt252(UpdatePriceFeedsError::InvalidUpdateData.into());
}
Result::Ok(())
}
}
fn read_and_verify_message(
ref reader: Reader, root_digest: u256
) -> Result<PriceFeedMessage, UpdatePriceFeedsError> {
fn read_and_verify_message(ref reader: Reader, root_digest: u256) -> PriceFeedMessage {
let message_size = reader.read_u16();
let message = reader.read_byte_array(message_size.into());
read_and_verify_proof(root_digest, @message, ref reader).map_err()?;
read_and_verify_proof(root_digest, @message, ref reader);
let mut message_reader = ReaderImpl::new(message);
let message_type: Option<MessageType> = message_reader.read_u8().try_into();
let message_type: MessageType = message_reader
.read_u8()
.try_into()
.expect(UpdatePriceFeedsError::InvalidUpdateData.into());
match message_type {
Option::Some(v) => match v {
MessageType::PriceFeed => {}
},
Option::None => { return Result::Err(UpdatePriceFeedsError::InvalidUpdateData); }
};
MessageType::PriceFeed => {}
}
let price_id = message_reader.read_u256();
let price = u64_as_i64(message_reader.read_u64());
@ -473,10 +454,9 @@ mod pyth {
let ema_price = u64_as_i64(message_reader.read_u64());
let ema_conf = message_reader.read_u64();
let message = PriceFeedMessage {
PriceFeedMessage {
price_id, price, conf, expo, publish_time, prev_publish_time, ema_price, ema_conf,
};
Result::Ok(message)
}
}
fn update_latest_price_if_necessary(ref self: ContractState, message: PriceFeedMessage) {

View File

@ -55,7 +55,7 @@ fn update_price_feeds_works() {
let mut spy = spy_events(SpyOn::One(pyth.contract_address));
start_prank(CheatTarget::One(pyth.contract_address), user.try_into().unwrap());
pyth.update_price_feeds(good_update1()).unwrap_with_felt252();
pyth.update_price_feeds(good_update1());
stop_prank(CheatTarget::One(pyth.contract_address));
spy.fetch_events();