diff --git a/target_chains/near/receiver/src/governance.rs b/target_chains/near/receiver/src/governance.rs index d3a18125..11dcc0c4 100644 --- a/target_chains/near/receiver/src/governance.rs +++ b/target_chains/near/receiver/src/governance.rs @@ -380,11 +380,6 @@ impl Pyth { } } - #[private] - pub fn set_upgrade_hash(&mut self, codehash: [u8; 32]) { - self.codehash = codehash; - } - #[private] #[handle_result] pub fn authorize_gov_source_transfer(&mut self, claim_vaa: Vec) -> Result<(), Error> { @@ -469,11 +464,14 @@ impl Pyth { /// this method as a normal public method. #[handle_result] pub(crate) fn upgrade(&mut self, new_code: Vec) -> Result { - let signature = env::sha256(&new_code); + let signature = TryInto::<[u8; 32]>::try_into(env::sha256(&new_code)).unwrap(); + let default = <[u8; 32] as Default>::default(); + ensure!(signature != default, UnauthorizedUpgrade); ensure!(signature == self.codehash, UnauthorizedUpgrade); Ok(Promise::new(env::current_account_id()) .deploy_contract(new_code) + .then(Self::ext(env::current_account_id()).migrate()) .then(Self::ext(env::current_account_id()).refund_upgrade( env::predecessor_account_id(), env::attached_deposit(), @@ -502,6 +500,10 @@ impl Pyth { .then_some(()) .ok_or(UnknownSource(source.emitter)) } + + pub fn set_upgrade_hash(&mut self, codehash: [u8; 32]) { + self.codehash = codehash; + } } #[cfg(test)] @@ -541,7 +543,6 @@ mod tests { let mut contract = Pyth::new( near_sdk::AccountId::new_unchecked("pyth.near".to_owned()), - [0; 32], Source::default(), Source::default(), 0.into(), @@ -561,7 +562,6 @@ mod tests { let mut contract = Pyth::new( near_sdk::AccountId::new_unchecked("pyth.near".to_owned()), - [0; 32], Source::default(), Source::default(), 0.into(), @@ -580,7 +580,6 @@ mod tests { let mut contract = Pyth::new( near_sdk::AccountId::new_unchecked("pyth.near".to_owned()), - [0; 32], Source::default(), Source::default(), 0.into(), @@ -599,7 +598,6 @@ mod tests { let mut contract = Pyth::new( near_sdk::AccountId::new_unchecked("pyth.near".to_owned()), - [0; 32], Source::default(), Source::default(), 0.into(), diff --git a/target_chains/near/receiver/src/lib.rs b/target_chains/near/receiver/src/lib.rs index 3772b2a2..8cbf2b2f 100644 --- a/target_chains/near/receiver/src/lib.rs +++ b/target_chains/near/receiver/src/lib.rs @@ -123,7 +123,6 @@ impl Pyth { #[allow(clippy::new_without_default)] pub fn new( wormhole: AccountId, - codehash: [u8; 32], initial_source: Source, gov_source: Source, update_fee: U128, @@ -140,14 +139,44 @@ impl Pyth { gov_source, sources, wormhole, - codehash, + codehash: Default::default(), update_fee: update_fee.into(), } } + #[private] #[init(ignore_state)] pub fn migrate() -> Self { - let state: Self = env::state_read().expect("Failed to read state"); + // This currently deserializes and produces the same state, I.E migration is a no-op to the + // current state. We only update the codehash to prevent re-upgrading. + // + // In the case where we want to actually migrate to a new state, we can do this by defining + // the old State struct here and then deserializing into that, then migrating into the new + // state, example code for the future reader: + // + // ```rust + // pub fn migrate() -> Self { + // pub struct OldPyth { + // sources: UnorderedSet, + // gov_source: Source, + // executed_governance_vaa: u64, + // executed_governance_change_vaa: u64, + // prices: UnorderedMap, + // wormhole: AccountId, + // codehash: [u8; 32], + // stale_threshold: Duration, + // update_fee: u128, + // } + // + // // Construct new Pyth State from old, perform any migrations needed. + // let old: OldPyth = env::state_read().expect("Failed to read state"); + // Self { + // ... + // } + // } + // ``` + let mut state: Self = env::state_read().expect("Failed to read state"); + state.codehash = Default::default(); state } diff --git a/target_chains/near/receiver/src/tests.rs b/target_chains/near/receiver/src/tests.rs index 2c78f8b2..f5285a65 100644 --- a/target_chains/near/receiver/src/tests.rs +++ b/target_chains/near/receiver/src/tests.rs @@ -16,7 +16,6 @@ mod tests { fn create_contract() -> Pyth { Pyth::new( "wormhole.near".parse().unwrap(), - [0; 32], Source::default(), Source::default(), 1.into(),