refactor(target_chains/starknet): remove Result from wormhole (#1541)

This commit is contained in:
Pavel Strakhov 2024-05-06 11:27:28 +01:00 committed by GitHub
parent ff6b11023c
commit 94b36c4961
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 49 additions and 112 deletions

View File

@ -55,8 +55,7 @@ pub impl HasherImpl of HasherTrait {
/// Reads all remaining data from the reader and pushes it to /// Reads all remaining data from the reader and pushes it to
/// the hashing buffer. /// the hashing buffer.
fn push_reader(ref self: Hasher, ref reader: Reader) -> Result<(), pyth::reader::Error> { fn push_reader(ref self: Hasher, ref reader: Reader) {
let mut result = Result::Ok(());
while reader.len() > 0 { while reader.len() > 0 {
let mut chunk_len = 8 - self.num_last_bytes; let mut chunk_len = 8 - self.num_last_bytes;
if reader.len() < chunk_len.into() { if reader.len() < chunk_len.into() {
@ -66,8 +65,7 @@ pub impl HasherImpl of HasherTrait {
let value = reader.read_num_bytes(chunk_len); let value = reader.read_num_bytes(chunk_len);
// chunk_len <= 8 so value must fit in u64. // chunk_len <= 8 so value must fit in u64.
self.push_to_last(value.try_into().expect(UNEXPECTED_OVERFLOW), chunk_len); self.push_to_last(value.try_into().expect(UNEXPECTED_OVERFLOW), chunk_len);
}; }
result
} }
/// Returns the keccak256 hash of the buffer. The output hash is interpreted /// Returns the keccak256 hash of the buffer. The output hash is interpreted

View File

@ -27,7 +27,7 @@ impl ResultReaderToMerkleVerification<T> of ResultReaderToMerkleVerificationTrai
fn leaf_hash(mut reader: Reader) -> Result<u256, super::reader::Error> { fn leaf_hash(mut reader: Reader) -> Result<u256, super::reader::Error> {
let mut hasher = HasherImpl::new(); let mut hasher = HasherImpl::new();
hasher.push_u8(MERKLE_LEAF_PREFIX); hasher.push_u8(MERKLE_LEAF_PREFIX);
hasher.push_reader(ref reader)?; hasher.push_reader(ref reader);
let hash = hasher.finalize() / ONE_SHIFT_96; let hash = hasher.finalize() / ONE_SHIFT_96;
Result::Ok(hash) Result::Ok(hash)
} }

View File

@ -382,7 +382,7 @@ mod pyth {
let wh_proof_size = reader.read_u16(); let wh_proof_size = reader.read_u16();
let wh_proof = reader.read_byte_array(wh_proof_size.into()); let wh_proof = reader.read_byte_array(wh_proof_size.into());
let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() }; let wormhole = IWormholeDispatcher { contract_address: self.wormhole_address.read() };
let vm = wormhole.parse_and_verify_vm(wh_proof).map_err()?; let vm = wormhole.parse_and_verify_vm(wh_proof);
let source = DataSource { let source = DataSource {
emitter_chain_id: vm.emitter_chain_id, emitter_address: vm.emitter_address emitter_chain_id: vm.emitter_chain_id, emitter_address: vm.emitter_address

View File

@ -4,10 +4,8 @@ use pyth::util::UnwrapWithFelt252;
#[starknet::interface] #[starknet::interface]
pub trait IWormhole<T> { pub trait IWormhole<T> {
fn submit_new_guardian_set( fn submit_new_guardian_set(ref self: T, set_index: u32, guardians: Array<felt252>);
ref self: T, set_index: u32, guardians: Array<felt252> fn parse_and_verify_vm(self: @T, encoded_vm: ByteArray) -> VM;
) -> Result<(), SubmitNewGuardianSetError>;
fn parse_and_verify_vm(self: @T, encoded_vm: ByteArray) -> Result<VM, ParseAndVerifyVmError>;
} }
#[derive(Drop, Debug, Clone, Serde)] #[derive(Drop, Debug, Clone, Serde)]
@ -127,16 +125,6 @@ mod wormhole {
use pyth::hash::{Hasher, HasherImpl}; use pyth::hash::{Hasher, HasherImpl};
use pyth::util::{ONE_SHIFT_160, UNEXPECTED_OVERFLOW}; use pyth::util::{ONE_SHIFT_160, UNEXPECTED_OVERFLOW};
#[generate_trait]
impl ResultReaderToWormhole<T> of ResultReaderToWormholeTrait<T> {
fn map_err(self: Result<T, pyth::reader::Error>) -> Result<T, ParseAndVerifyVmError> {
match self {
Result::Ok(v) => Result::Ok(v),
Result::Err(err) => Result::Err(ParseAndVerifyVmError::Reader(err)),
}
}
}
#[derive(Drop, Debug, Clone, Serde, starknet::Store)] #[derive(Drop, Debug, Clone, Serde, starknet::Store)]
struct GuardianSet { struct GuardianSet {
num_guardians: usize, num_guardians: usize,
@ -159,29 +147,24 @@ mod wormhole {
) { ) {
self.owner.write(owner); self.owner.write(owner);
let set_index = 0; let set_index = 0;
store_guardian_set(ref self, set_index, @initial_guardians).unwrap_with_felt252(); store_guardian_set(ref self, set_index, @initial_guardians);
} }
fn store_guardian_set( fn store_guardian_set(ref self: ContractState, set_index: u32, guardians: @Array<felt252>) {
ref self: ContractState, set_index: u32, guardians: @Array<felt252>
) -> Result<(), SubmitNewGuardianSetError> {
if guardians.len() == 0 { if guardians.len() == 0 {
return Result::Err(SubmitNewGuardianSetError::NoGuardiansSpecified.into()); panic_with_felt252(SubmitNewGuardianSetError::NoGuardiansSpecified.into());
} }
if guardians.len() >= 256 { if guardians.len() >= 256 {
return Result::Err(SubmitNewGuardianSetError::TooManyGuardians.into()); panic_with_felt252(SubmitNewGuardianSetError::TooManyGuardians.into());
} }
let mut i = 0; let mut i = 0;
let mut result = Result::Ok(());
while i < guardians.len() { while i < guardians.len() {
if *guardians.at(i) == 0 { if *guardians.at(i) == 0 {
result = Result::Err(SubmitNewGuardianSetError::InvalidGuardianKey.into()); panic_with_felt252(SubmitNewGuardianSetError::InvalidGuardianKey.into());
break;
} }
i += 1; i += 1;
}; };
result?;
let set = GuardianSet { num_guardians: guardians.len(), expiration_time: 0 }; let set = GuardianSet { num_guardians: guardians.len(), expiration_time: 0 };
self.guardian_sets.write(set_index, set); self.guardian_sets.write(set_index, set);
@ -195,7 +178,6 @@ mod wormhole {
i += 1; i += 1;
}; };
self.current_guardian_set_index.write(set_index); self.current_guardian_set_index.write(set_index);
Result::Ok(())
} }
fn expire_guardian_set(ref self: ContractState, set_index: u32, now: u64) { fn expire_guardian_set(ref self: ContractState, set_index: u32, now: u64) {
@ -208,41 +190,37 @@ mod wormhole {
impl WormholeImpl of IWormhole<ContractState> { impl WormholeImpl of IWormhole<ContractState> {
fn submit_new_guardian_set( fn submit_new_guardian_set(
ref self: ContractState, set_index: u32, guardians: Array<felt252> ref self: ContractState, set_index: u32, guardians: Array<felt252>
) -> Result<(), SubmitNewGuardianSetError> { ) {
let execution_info = get_execution_info().unbox(); let execution_info = get_execution_info().unbox();
if self.owner.read() != execution_info.caller_address { if self.owner.read() != execution_info.caller_address {
return Result::Err(SubmitNewGuardianSetError::AccessDenied); panic_with_felt252(SubmitNewGuardianSetError::AccessDenied.into());
} }
let current_set_index = self.current_guardian_set_index.read(); let current_set_index = self.current_guardian_set_index.read();
if set_index != current_set_index + 1 { if set_index != current_set_index + 1 {
return Result::Err(SubmitNewGuardianSetError::InvalidGuardianSetSequence.into()); panic_with_felt252(SubmitNewGuardianSetError::InvalidGuardianSetSequence.into());
} }
store_guardian_set(ref self, set_index, @guardians)?; store_guardian_set(ref self, set_index, @guardians);
expire_guardian_set( expire_guardian_set(
ref self, current_set_index, execution_info.block_info.unbox().block_timestamp ref self, current_set_index, execution_info.block_info.unbox().block_timestamp
); );
Result::Ok(())
} }
fn parse_and_verify_vm( fn parse_and_verify_vm(self: @ContractState, encoded_vm: ByteArray) -> VM {
self: @ContractState, encoded_vm: ByteArray let (vm, body_hash) = parse_vm(encoded_vm);
) -> Result<VM, ParseAndVerifyVmError> {
let (vm, body_hash) = parse_vm(encoded_vm)?;
let guardian_set = self.guardian_sets.read(vm.guardian_set_index); let guardian_set = self.guardian_sets.read(vm.guardian_set_index);
if guardian_set.num_guardians == 0 { if guardian_set.num_guardians == 0 {
return Result::Err(ParseAndVerifyVmError::InvalidGuardianSetIndex); panic_with_felt252(ParseAndVerifyVmError::InvalidGuardianSetIndex.into());
} }
if vm.guardian_set_index != self.current_guardian_set_index.read() if vm.guardian_set_index != self.current_guardian_set_index.read()
&& guardian_set.expiration_time < get_block_timestamp() { && guardian_set.expiration_time < get_block_timestamp() {
return Result::Err(ParseAndVerifyVmError::GuardianSetExpired); panic_with_felt252(ParseAndVerifyVmError::GuardianSetExpired.into());
} }
if vm.signatures.len() < quorum(guardian_set.num_guardians) { if vm.signatures.len() < quorum(guardian_set.num_guardians) {
return Result::Err(ParseAndVerifyVmError::NoQuorum); panic_with_felt252(ParseAndVerifyVmError::NoQuorum.into());
} }
let mut signatures_clone = vm.signatures.clone(); let mut signatures_clone = vm.signatures.clone();
let mut last_index = Option::None; let mut last_index = Option::None;
let mut result = Result::Ok(());
loop { loop {
let signature = match signatures_clone.pop_front() { let signature = match signatures_clone.pop_front() {
Option::Some(v) => { v }, Option::Some(v) => { v },
@ -252,8 +230,7 @@ mod wormhole {
match last_index { match last_index {
Option::Some(last_index) => { Option::Some(last_index) => {
if *(@signature).guardian_index <= last_index { if *(@signature).guardian_index <= last_index {
result = Result::Err(ParseAndVerifyVmError::InvalidSignatureOrder); panic_with_felt252(ParseAndVerifyVmError::InvalidSignatureOrder.into());
break;
} }
}, },
Option::None => {}, Option::None => {},
@ -261,43 +238,33 @@ mod wormhole {
last_index = Option::Some(*(@signature).guardian_index); last_index = Option::Some(*(@signature).guardian_index);
if signature.guardian_index.into() >= guardian_set.num_guardians { if signature.guardian_index.into() >= guardian_set.num_guardians {
result = Result::Err(ParseAndVerifyVmError::InvalidGuardianIndex); panic_with_felt252(ParseAndVerifyVmError::InvalidGuardianIndex.into());
break;
} }
let guardian_key = self let guardian_key = self
.guardian_keys .guardian_keys
.read((vm.guardian_set_index, signature.guardian_index)); .read((vm.guardian_set_index, signature.guardian_index));
let r = verify_signature(body_hash, signature.signature, guardian_key); verify_signature(body_hash, signature.signature, guardian_key);
if r.is_err() {
result = r;
break;
}
}; };
result?; vm
Result::Ok(vm)
} }
} }
fn parse_signature(ref reader: Reader) -> Result<GuardianSignature, ParseAndVerifyVmError> { fn parse_signature(ref reader: Reader) -> GuardianSignature {
let guardian_index = reader.read_u8(); let guardian_index = reader.read_u8();
let r = reader.read_u256(); let r = reader.read_u256();
let s = reader.read_u256(); let s = reader.read_u256();
let recovery_id = reader.read_u8(); let recovery_id = reader.read_u8();
let y_parity = (recovery_id % 2) > 0; let y_parity = (recovery_id % 2) > 0;
let signature = GuardianSignature { GuardianSignature { guardian_index, signature: Signature { r, s, y_parity } }
guardian_index, signature: Signature { r, s, y_parity }
};
Result::Ok(signature)
} }
fn parse_vm(encoded_vm: ByteArray) -> Result<(VM, u256), ParseAndVerifyVmError> { fn parse_vm(encoded_vm: ByteArray) -> (VM, u256) {
let mut reader = ReaderImpl::new(encoded_vm); let mut reader = ReaderImpl::new(encoded_vm);
let version = reader.read_u8(); let version = reader.read_u8();
if version != 1 { if version != 1 {
return Result::Err(ParseAndVerifyVmError::VmVersionIncompatible); panic_with_felt252(ParseAndVerifyVmError::VmVersionIncompatible.into());
} }
let guardian_set_index = reader.read_u32(); let guardian_set_index = reader.read_u32();
@ -305,22 +272,14 @@ mod wormhole {
let mut i = 0; let mut i = 0;
let mut signatures = array![]; let mut signatures = array![];
let mut result = Result::Ok(());
while i < sig_count { while i < sig_count {
match parse_signature(ref reader) { signatures.append(parse_signature(ref reader));
Result::Ok(signature) => { signatures.append(signature); },
Result::Err(err) => {
result = Result::Err(err);
break;
},
}
i += 1; i += 1;
}; };
result?;
let mut reader_for_hash = reader.clone(); let mut reader_for_hash = reader.clone();
let mut hasher = HasherImpl::new(); let mut hasher = HasherImpl::new();
hasher.push_reader(ref reader_for_hash).map_err()?; hasher.push_reader(ref reader_for_hash);
let body_hash1 = hasher.finalize(); let body_hash1 = hasher.finalize();
let mut hasher2 = HasherImpl::new(); let mut hasher2 = HasherImpl::new();
hasher2.push_u256(body_hash1); hasher2.push_u256(body_hash1);
@ -347,32 +306,25 @@ mod wormhole {
consistency_level, consistency_level,
payload, payload,
}; };
Result::Ok((vm, body_hash2)) (vm, body_hash2)
} }
fn verify_signature( fn verify_signature(body_hash: u256, signature: Signature, guardian_key: u256,) {
body_hash: u256, signature: Signature, guardian_key: u256,
) -> Result<(), ParseAndVerifyVmError> {
let point: Secp256k1Point = recover_public_key(body_hash, signature) let point: Secp256k1Point = recover_public_key(body_hash, signature)
.ok_or(ParseAndVerifyVmError::InvalidSignature)?; .expect(ParseAndVerifyVmError::InvalidSignature.into());
let address = eth_address(point)?; let address = eth_address(point);
assert(guardian_key != 0, SubmitNewGuardianSetError::InvalidGuardianKey.into()); assert(guardian_key != 0, SubmitNewGuardianSetError::InvalidGuardianKey.into());
if address != guardian_key { if address != guardian_key {
return Result::Err(ParseAndVerifyVmError::InvalidSignature); panic_with_felt252(ParseAndVerifyVmError::InvalidSignature.into());
} }
Result::Ok(())
} }
fn eth_address(point: Secp256k1Point) -> Result<u256, ParseAndVerifyVmError> { fn eth_address(point: Secp256k1Point) -> u256 {
let (x, y) = match point.get_coordinates() { let (x, y) = point.get_coordinates().expect(ParseAndVerifyVmError::InvalidSignature.into());
Result::Ok(v) => { v },
Result::Err(_) => { return Result::Err(ParseAndVerifyVmError::InvalidSignature); },
};
let mut hasher = HasherImpl::new(); let mut hasher = HasherImpl::new();
hasher.push_u256(x); hasher.push_u256(x);
hasher.push_u256(y); hasher.push_u256(y);
let address = hasher.finalize() % ONE_SHIFT_160; hasher.finalize() % ONE_SHIFT_160
Result::Ok(address)
} }
} }

View File

@ -11,7 +11,7 @@ fn test_parse_and_verify_vm_works() {
let owner = 'owner'.try_into().unwrap(); let owner = 'owner'.try_into().unwrap();
let dispatcher = deploy_and_init(owner); let dispatcher = deploy_and_init(owner);
let vm = dispatcher.parse_and_verify_vm(good_vm1()).unwrap(); let vm = dispatcher.parse_and_verify_vm(good_vm1());
assert!(vm.version == 1); assert!(vm.version == 1);
assert!(vm.guardian_set_index == 3); assert!(vm.guardian_set_index == 3);
assert!(vm.signatures.len() == 13); assert!(vm.signatures.len() == 13);
@ -36,26 +36,13 @@ fn test_parse_and_verify_vm_works() {
#[test] #[test]
#[fuzzer(runs: 100, seed: 0)] #[fuzzer(runs: 100, seed: 0)]
#[should_panic(expected: ('any_expected',))] #[should_panic]
fn test_parse_and_verify_vm_rejects_corrupted_vm(pos: usize, random1: usize, random2: usize) { fn test_parse_and_verify_vm_rejects_corrupted_vm(pos: usize, random1: usize, random2: usize) {
let owner = 'owner'.try_into().unwrap(); let owner = 'owner'.try_into().unwrap();
let dispatcher = deploy_and_init(owner); let dispatcher = deploy_and_init(owner);
let r = dispatcher.parse_and_verify_vm(corrupted_vm(pos, random1, random2)); let vm = dispatcher.parse_and_verify_vm(corrupted_vm(pos, random1, random2));
match r { println!("no error, output: {:?}", vm);
Result::Ok(v) => { println!("no error, output: {:?}", v); },
Result::Err(err) => {
if err == ParseAndVerifyVmError::InvalidSignature
|| err == ParseAndVerifyVmError::InvalidGuardianIndex
|| err == ParseAndVerifyVmError::InvalidGuardianSetIndex
|| err == ParseAndVerifyVmError::VmVersionIncompatible
|| err == ParseAndVerifyVmError::Reader(pyth::reader::Error::UnexpectedEndOfInput) {
panic_with_felt252('any_expected');
} else {
panic_with_felt252(err.into());
}
},
}
} }
#[test] #[test]
@ -64,7 +51,7 @@ fn test_submit_guardian_set_rejects_wrong_owner() {
let owner = 'owner'.try_into().unwrap(); let owner = 'owner'.try_into().unwrap();
let dispatcher = deploy(owner, guardian_set1()); let dispatcher = deploy(owner, guardian_set1());
start_prank(CheatTarget::One(dispatcher.contract_address), 'baddy'.try_into().unwrap()); start_prank(CheatTarget::One(dispatcher.contract_address), 'baddy'.try_into().unwrap());
dispatcher.submit_new_guardian_set(1, guardian_set1()).unwrap_with_felt252(); dispatcher.submit_new_guardian_set(1, guardian_set1());
} }
#[test] #[test]
@ -74,8 +61,8 @@ fn test_submit_guardian_set_rejects_wrong_index() {
let dispatcher = deploy(owner, guardian_set1()); let dispatcher = deploy(owner, guardian_set1());
start_prank(CheatTarget::One(dispatcher.contract_address), owner.try_into().unwrap()); start_prank(CheatTarget::One(dispatcher.contract_address), owner.try_into().unwrap());
dispatcher.submit_new_guardian_set(1, guardian_set1()).unwrap_with_felt252(); dispatcher.submit_new_guardian_set(1, guardian_set1());
dispatcher.submit_new_guardian_set(3, guardian_set3()).unwrap_with_felt252(); dispatcher.submit_new_guardian_set(3, guardian_set3());
} }
#[test] #[test]
@ -92,7 +79,7 @@ fn test_submit_guardian_set_rejects_empty() {
let dispatcher = deploy(owner, guardian_set1()); let dispatcher = deploy(owner, guardian_set1());
start_prank(CheatTarget::One(dispatcher.contract_address), owner.try_into().unwrap()); start_prank(CheatTarget::One(dispatcher.contract_address), owner.try_into().unwrap());
dispatcher.submit_new_guardian_set(1, array![]).unwrap_with_felt252(); dispatcher.submit_new_guardian_set(1, array![]);
} }
fn deploy(owner: ContractAddress, guardians: Array<felt252>) -> IWormholeDispatcher { fn deploy(owner: ContractAddress, guardians: Array<felt252>) -> IWormholeDispatcher {
@ -113,9 +100,9 @@ pub fn deploy_and_init(owner: ContractAddress) -> IWormholeDispatcher {
let dispatcher = deploy(owner, guardian_set1()); let dispatcher = deploy(owner, guardian_set1());
start_prank(CheatTarget::One(dispatcher.contract_address), owner.try_into().unwrap()); start_prank(CheatTarget::One(dispatcher.contract_address), owner.try_into().unwrap());
dispatcher.submit_new_guardian_set(1, guardian_set1()).unwrap_with_felt252(); dispatcher.submit_new_guardian_set(1, guardian_set1());
dispatcher.submit_new_guardian_set(2, guardian_set2()).unwrap_with_felt252(); dispatcher.submit_new_guardian_set(2, guardian_set2());
dispatcher.submit_new_guardian_set(3, guardian_set3()).unwrap_with_felt252(); dispatcher.submit_new_guardian_set(3, guardian_set3());
stop_prank(CheatTarget::One(dispatcher.contract_address)); stop_prank(CheatTarget::One(dispatcher.contract_address));
dispatcher dispatcher