diff --git a/ethereum/contracts/Getters.sol b/ethereum/contracts/Getters.sol index 661ada1da..511374433 100644 --- a/ethereum/contracts/Getters.sol +++ b/ethereum/contracts/Getters.sol @@ -28,8 +28,7 @@ contract Getters is State { function chainId() public view returns (uint16) { if (evmChainId() != block.chainid) { - // reduce the likelihood of forked chain ID collisions - return type(uint16).max - 32 + uint16(block.chainid % 32); + return type(uint16).max; } return _state.provider.chainId; } diff --git a/ethereum/contracts/Governance.sol b/ethereum/contracts/Governance.sol index ec25bfa77..67366e951 100644 --- a/ethereum/contracts/Governance.sol +++ b/ethereum/contracts/Governance.sol @@ -37,7 +37,8 @@ abstract contract Governance is GovernanceStructs, Messages, Setters, ERC1967Upg require(upgrade.module == module, "Invalid Module"); // Verify the VAA is for this chain - require(upgrade.chain == chainId(), "Invalid Chain"); + uint16 chainId = chainId(); + require(upgrade.chain == chainId && chainId != type(uint16).max, "Invalid Chain"); // Record the governance action as consumed setGovernanceActionConsumed(vm.hash); @@ -138,6 +139,34 @@ abstract contract Governance is GovernanceStructs, Messages, Setters, ERC1967Upg recipient.transfer(transfer.amount); } + /** + * @dev Updates the `chainId` and `evmChainId` on a forked chain via Governance VAA/VM + */ + function submitRecoverChainId(bytes memory _vm) public { + require(chainId() == type(uint16).max, "invalid chain"); + + Structs.VM memory vm = parseVM(_vm); + + // Verify the VAA is valid before processing it + (bool isValid, string memory reason) = verifyGovernanceVM(vm); + require(isValid, reason); + + GovernanceStructs.RecoverChainId memory rci = parseRecoverChainId(vm.payload); + + // Verify the VAA is for this module + require(rci.module == module, "invalid Module"); + + // Verify the VAA is for this chain + require(rci.evmChainId == block.chainid, "invalid EVM Chain"); + + // Record the governance action as consumed to prevent reentry + setGovernanceActionConsumed(vm.hash); + + // Update the chainIds + setEvmChainId(rci.evmChainId); + setChainId(rci.newChainId); + } + /** * @dev Upgrades the `currentImplementation` with a `newImplementation` */ diff --git a/ethereum/contracts/GovernanceStructs.sol b/ethereum/contracts/GovernanceStructs.sol index 7e4e05d6f..88d9be2f1 100644 --- a/ethereum/contracts/GovernanceStructs.sol +++ b/ethereum/contracts/GovernanceStructs.sol @@ -52,6 +52,14 @@ contract GovernanceStructs { bytes32 recipient; } + struct RecoverChainId { + bytes32 module; + uint8 action; + + uint256 evmChainId; + uint16 newChainId; + } + /// @dev Parse a contract upgrade (action 1) with minimal validation function parseContractUpgrade(bytes memory encodedUpgrade) public pure returns (ContractUpgrade memory cu) { uint index = 0; @@ -151,4 +159,25 @@ contract GovernanceStructs { require(encodedTransferFees.length == index, "invalid TransferFees"); } + + /// @dev Parse a recoverChainId (action 5) with minimal validation + function parseRecoverChainId(bytes memory encodedRecoverChainId) public pure returns (RecoverChainId memory rci) { + uint index = 0; + + rci.module = encodedRecoverChainId.toBytes32(index); + index += 32; + + rci.action = encodedRecoverChainId.toUint8(index); + index += 1; + + require(rci.action == 5, "invalid RecoverChainId"); + + rci.evmChainId = encodedRecoverChainId.toUint256(index); + index += 32; + + rci.newChainId = encodedRecoverChainId.toUint16(index); + index += 2; + + require(encodedRecoverChainId.length == index, "invalid RecoverChainId"); + } } \ No newline at end of file diff --git a/ethereum/contracts/bridge/Bridge.sol b/ethereum/contracts/bridge/Bridge.sol index df4664516..626fd389d 100644 --- a/ethereum/contracts/bridge/Bridge.sol +++ b/ethereum/contracts/bridge/Bridge.sol @@ -20,15 +20,15 @@ import "./token/TokenImplementation.sol"; contract Bridge is BridgeGovernance, ReentrancyGuard { using BytesLib for bytes; - modifier onlyEvmChainId() { - require(evmChainId() == block.chainid, "invalid evmChainId"); + modifier noFork() { + require(evmChainId() == block.chainid, "bad fork"); _; } /* * @dev Produce a AssetMeta message for a given token */ - function attestToken(address tokenAddress, uint32 nonce) public payable onlyEvmChainId returns (uint64 sequence) { + function attestToken(address tokenAddress, uint32 nonce) public payable returns (uint64 sequence) { // decimals, symbol & token are not part of the core ERC20 token standard, so we need to support contracts that dont implement them (,bytes memory queriedDecimals) = tokenAddress.staticcall(abi.encodeWithSignature("decimals()")); (,bytes memory queriedSymbol) = tokenAddress.staticcall(abi.encodeWithSignature("symbol()")); @@ -71,7 +71,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { bytes32 recipient, uint256 arbiterFee, uint32 nonce - ) public payable onlyEvmChainId returns (uint64 sequence) { + ) public payable returns (uint64 sequence) { BridgeStructs.TransferResult memory transferResult = _wrapAndTransferETH(arbiterFee); sequence = logTransfer( @@ -103,7 +103,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { bytes32 recipient, uint32 nonce, bytes memory payload - ) public payable onlyEvmChainId returns (uint64 sequence) { + ) public payable returns (uint64 sequence) { BridgeStructs.TransferResult memory transferResult = _wrapAndTransferETH(0); sequence = logTransferWithPayload( @@ -163,7 +163,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { bytes32 recipient, uint256 arbiterFee, uint32 nonce - ) public payable nonReentrant onlyEvmChainId returns (uint64 sequence) { + ) public payable nonReentrant returns (uint64 sequence) { BridgeStructs.TransferResult memory transferResult = _transferTokens( token, amount, @@ -200,7 +200,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { bytes32 recipient, uint32 nonce, bytes memory payload - ) public payable nonReentrant onlyEvmChainId returns (uint64 sequence) { + ) public payable nonReentrant returns (uint64 sequence) { BridgeStructs.TransferResult memory transferResult = _transferTokens( token, amount, @@ -359,7 +359,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { finality() ); } - function updateWrapped(bytes memory encodedVm) external onlyEvmChainId returns (address token) { + function updateWrapped(bytes memory encodedVm) external returns (address token) { (IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVm); require(valid, reason); @@ -379,7 +379,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { return wrapped; } - function createWrapped(bytes memory encodedVm) external onlyEvmChainId returns (address token) { + function createWrapped(bytes memory encodedVm) external returns (address token) { (IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVm); require(valid, reason); @@ -489,7 +489,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { } // Execute a Transfer message - function _completeTransfer(bytes memory encodedVm, bool unwrapWETH) internal onlyEvmChainId returns (bytes memory) { + function _completeTransfer(bytes memory encodedVm, bool unwrapWETH) internal returns (bytes memory) { (IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVm); require(valid, reason); @@ -581,7 +581,7 @@ contract Bridge is BridgeGovernance, ReentrancyGuard { setOutstandingBridged(token, outstandingBridged(token) - normalizedAmount); } - function verifyBridgeVM(IWormhole.VM memory vm) internal view returns (bool){ + function verifyBridgeVM(IWormhole.VM memory vm) internal view noFork returns (bool){ if (bridgeContracts(vm.emitterChainId) == vm.emitterAddress) { return true; } diff --git a/ethereum/contracts/bridge/BridgeGetters.sol b/ethereum/contracts/bridge/BridgeGetters.sol index 35b6df317..d6bf4e84f 100644 --- a/ethereum/contracts/bridge/BridgeGetters.sol +++ b/ethereum/contracts/bridge/BridgeGetters.sol @@ -28,8 +28,7 @@ contract BridgeGetters is BridgeState { function chainId() public view returns (uint16){ if (evmChainId() != block.chainid) { - // reduce the likelihood of forked chain ID collisions - return type(uint16).max - 32 + uint16(block.chainid % 32); + return type(uint16).max; } return _state.provider.chainId; } diff --git a/ethereum/contracts/bridge/BridgeGovernance.sol b/ethereum/contracts/bridge/BridgeGovernance.sol index c577d5d52..fd0dce803 100644 --- a/ethereum/contracts/bridge/BridgeGovernance.sol +++ b/ethereum/contracts/bridge/BridgeGovernance.sol @@ -48,11 +48,33 @@ contract BridgeGovernance is BridgeGetters, BridgeSetters, ERC1967Upgrade { BridgeStructs.UpgradeContract memory implementation = parseUpgrade(vm.payload); - require(implementation.chainId == chainId(), "wrong chain id"); + uint16 chainId = chainId(); + require(implementation.chainId == chainId && chainId != type(uint16).max, "wrong chain id"); upgradeImplementation(address(uint160(uint256(implementation.newContract)))); } + /** + * @dev Updates the `chainId` and `evmChainId` on a forked chain via Governance VAA/VM + */ + function submitRecoverChainId(bytes memory encodedVM) public { + require(chainId() == type(uint16).max, "invalid chain"); + + (IWormhole.VM memory vm, bool valid, string memory reason) = verifyGovernanceVM(encodedVM); + require(valid, reason); + + setGovernanceActionConsumed(vm.hash); + + BridgeStructs.RecoverChainId memory rci = parseRecoverChainId(vm.payload); + + // Verify the VAA is for this chain + require(rci.evmChainId == block.chainid, "invalid EVM Chain"); + + // Update the chainIds + setEvmChainId(rci.evmChainId); + setChainId(rci.newChainId); + } + function verifyGovernanceVM(bytes memory encodedVM) internal view returns (IWormhole.VM memory parsedVM, bool isValid, string memory invalidReason){ (IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVM); @@ -96,11 +118,11 @@ contract BridgeGovernance is BridgeGetters, BridgeSetters, ERC1967Upgrade { chain.module = encoded.toBytes32(index); index += 32; - require(chain.module == module, "invalid RegisterChain: wrong module"); + require(chain.module == module, "wrong module"); chain.action = encoded.toUint8(index); index += 1; - require(chain.action == 1, "invalid RegisterChain: wrong action"); + require(chain.action == 1, "wrong action"); chain.chainId = encoded.toUint16(index); index += 2; @@ -113,7 +135,7 @@ contract BridgeGovernance is BridgeGetters, BridgeSetters, ERC1967Upgrade { chain.emitterAddress = encoded.toBytes32(index); index += 32; - require(encoded.length == index, "invalid RegisterChain: wrong length"); + require(encoded.length == index, "wrong length"); } function parseUpgrade(bytes memory encoded) public pure returns (BridgeStructs.UpgradeContract memory chain) { @@ -123,11 +145,11 @@ contract BridgeGovernance is BridgeGetters, BridgeSetters, ERC1967Upgrade { chain.module = encoded.toBytes32(index); index += 32; - require(chain.module == module, "invalid UpgradeContract: wrong module"); + require(chain.module == module, "wrong module"); chain.action = encoded.toUint8(index); index += 1; - require(chain.action == 2, "invalid UpgradeContract: wrong action"); + require(chain.action == 2, "wrong action"); chain.chainId = encoded.toUint16(index); index += 2; @@ -137,6 +159,27 @@ contract BridgeGovernance is BridgeGetters, BridgeSetters, ERC1967Upgrade { chain.newContract = encoded.toBytes32(index); index += 32; - require(encoded.length == index, "invalid UpgradeContract: wrong length"); + require(encoded.length == index, "wrong length"); + } + + /// @dev Parse a recoverChainId (action 5) with minimal validation + function parseRecoverChainId(bytes memory encodedRecoverChainId) public pure returns (BridgeStructs.RecoverChainId memory rci) { + uint index = 0; + + rci.module = encodedRecoverChainId.toBytes32(index); + index += 32; + require(rci.module == module, "wrong module"); + + rci.action = encodedRecoverChainId.toUint8(index); + index += 1; + require(rci.action == 5, "wrong action"); + + rci.evmChainId = encodedRecoverChainId.toUint256(index); + index += 32; + + rci.newChainId = encodedRecoverChainId.toUint16(index); + index += 2; + + require(encodedRecoverChainId.length == index, "wrong length"); } } diff --git a/ethereum/contracts/bridge/BridgeStructs.sol b/ethereum/contracts/bridge/BridgeStructs.sol index c69ef8459..b5f5ff3cb 100644 --- a/ethereum/contracts/bridge/BridgeStructs.sol +++ b/ethereum/contracts/bridge/BridgeStructs.sol @@ -95,4 +95,17 @@ contract BridgeStructs { // Address of the new contract bytes32 newContract; } + + struct RecoverChainId { + // Governance Header + // module: "TokenBridge" left-padded + bytes32 module; + // governance action: 5 + uint8 action; + + // EIP-155 Chain ID + uint256 evmChainId; + // Chain ID + uint16 newChainId; + } } diff --git a/ethereum/contracts/nft/NFTBridge.sol b/ethereum/contracts/nft/NFTBridge.sol index 30b8dc984..97772b5e7 100644 --- a/ethereum/contracts/nft/NFTBridge.sol +++ b/ethereum/contracts/nft/NFTBridge.sol @@ -19,13 +19,13 @@ import "./token/NFTImplementation.sol"; contract NFTBridge is NFTBridgeGovernance { using BytesLib for bytes; - modifier onlyEvmChainId() { - require(evmChainId() == block.chainid, "invalid evmChainId"); + modifier noFork() { + require(evmChainId() == block.chainid, "bad fork"); _; } // Initiate a Transfer - function transferNFT(address token, uint256 tokenID, uint16 recipientChain, bytes32 recipient, uint32 nonce) public payable onlyEvmChainId returns (uint64 sequence) { + function transferNFT(address token, uint256 tokenID, uint16 recipientChain, bytes32 recipient, uint32 nonce) public payable returns (uint64 sequence) { // determine token parameters uint16 tokenChain; bytes32 tokenAddress; @@ -102,7 +102,7 @@ contract NFTBridge is NFTBridgeGovernance { }(nonce, encoded, finality()); } - function completeTransfer(bytes memory encodedVm) public onlyEvmChainId { + function completeTransfer(bytes memory encodedVm) public { _completeTransfer(encodedVm); } @@ -197,7 +197,7 @@ contract NFTBridge is NFTBridgeGovernance { setWrappedAsset(tokenChain, tokenAddress, token); } - function verifyBridgeVM(IWormhole.VM memory vm) internal view returns (bool){ + function verifyBridgeVM(IWormhole.VM memory vm) internal view noFork returns (bool){ if (bridgeContracts(vm.emitterChainId) == vm.emitterAddress) { return true; } diff --git a/ethereum/contracts/nft/NFTBridgeGetters.sol b/ethereum/contracts/nft/NFTBridgeGetters.sol index 58e7f8bbf..12546b5e3 100644 --- a/ethereum/contracts/nft/NFTBridgeGetters.sol +++ b/ethereum/contracts/nft/NFTBridgeGetters.sol @@ -28,8 +28,7 @@ contract NFTBridgeGetters is NFTBridgeState { function chainId() public view returns (uint16){ if (evmChainId() != block.chainid) { - // reduce the likelihood of forked chain ID collisions - return type(uint16).max - 32 + uint16(block.chainid % 32); + return type(uint16).max; } return _state.provider.chainId; } diff --git a/ethereum/contracts/nft/NFTBridgeGovernance.sol b/ethereum/contracts/nft/NFTBridgeGovernance.sol index 4199eed8a..060d9a900 100644 --- a/ethereum/contracts/nft/NFTBridgeGovernance.sol +++ b/ethereum/contracts/nft/NFTBridgeGovernance.sol @@ -46,11 +46,33 @@ contract NFTBridgeGovernance is NFTBridgeGetters, NFTBridgeSetters, ERC1967Upgra NFTBridgeStructs.UpgradeContract memory implementation = parseUpgrade(vm.payload); - require(implementation.chainId == chainId(), "wrong chain id"); + uint16 chainId = chainId(); + require(implementation.chainId == chainId && chainId != type(uint16).max, "wrong chain id"); upgradeImplementation(address(uint160(uint256(implementation.newContract)))); } + /** + * @dev Updates the `chainId` and `evmChainId` on a forked chain via Governance VAA/VM + */ + function submitRecoverChainId(bytes memory encodedVM) public { + require(chainId() == type(uint16).max, "invalid chain"); + + (IWormhole.VM memory vm, bool valid, string memory reason) = verifyGovernanceVM(encodedVM); + require(valid, reason); + + setGovernanceActionConsumed(vm.hash); + + NFTBridgeStructs.RecoverChainId memory rci = parseRecoverChainId(vm.payload); + + // Verify the VAA is for this chain + require(rci.evmChainId == block.chainid, "invalid EVM Chain"); + + // Update the chainIds + setEvmChainId(rci.evmChainId); + setChainId(rci.newChainId); + } + function verifyGovernanceVM(bytes memory encodedVM) internal view returns (IWormhole.VM memory parsedVM, bool isValid, string memory invalidReason){ (IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVM); @@ -136,4 +158,25 @@ contract NFTBridgeGovernance is NFTBridgeGetters, NFTBridgeSetters, ERC1967Upgra require(encoded.length == index, "invalid UpgradeContract: wrong length"); } + + /// @dev Parse a recoverChainId (action 5) with minimal validation + function parseRecoverChainId(bytes memory encodedRecoverChainId) public pure returns (NFTBridgeStructs.RecoverChainId memory rci) { + uint index = 0; + + rci.module = encodedRecoverChainId.toBytes32(index); + index += 32; + require(rci.module == module, "invalid RecoverChainId: wrong module"); + + rci.action = encodedRecoverChainId.toUint8(index); + index += 1; + require(rci.action == 5, "invalid RecoverChainId: wrong action"); + + rci.evmChainId = encodedRecoverChainId.toUint256(index); + index += 32; + + rci.newChainId = encodedRecoverChainId.toUint16(index); + index += 2; + + require(encodedRecoverChainId.length == index, "invalid RecoverChainId"); + } } diff --git a/ethereum/contracts/nft/NFTBridgeStructs.sol b/ethereum/contracts/nft/NFTBridgeStructs.sol index ea5b98b43..e05024265 100644 --- a/ethereum/contracts/nft/NFTBridgeStructs.sol +++ b/ethereum/contracts/nft/NFTBridgeStructs.sol @@ -51,4 +51,17 @@ contract NFTBridgeStructs { // Address of the new contract bytes32 newContract; } + + struct RecoverChainId { + // Governance Header + // module: "TokenBridge" left-padded + bytes32 module; + // governance action: 5 + uint8 action; + + // EIP-155 Chain ID + uint256 evmChainId; + // Chain ID + uint16 newChainId; + } } diff --git a/ethereum/test/wormhole.js b/ethereum/test/wormhole.js index 0a32eae33..85f76ea38 100644 --- a/ethereum/test/wormhole.js +++ b/ethereum/test/wormhole.js @@ -18,6 +18,7 @@ const actionContractUpgrade = "01" const actionGuardianSetUpgrade = "02" const actionMessageFee = "03" const actionTransferFee = "04" +const actionRecoverChainId = "05" const ImplementationFullABI = jsonfile.readFileSync("build/contracts/Implementation.json").abi @@ -613,6 +614,97 @@ contract("Wormhole", function () { assert.ok(isUpgraded); }) + it("should revert smart contract upgrades with the bad fork chain ID (uint16 max)", async function () { + const initialized = new web3.eth.Contract(ImplementationFullABI, Wormhole.address); + const accounts = await web3.eth.getAccounts(); + + const mock = await MockImplementation.new(); + + const timestamp = 1000; + const nonce = 1001; + const emitterChainId = testGovernanceChainId; + const emitterAddress = testGovernanceContract + + data = [ + // Core + core, + // Action 1 (Contract Upgrade) + actionContractUpgrade, + // ChainID (max uint16 - bad fork) + web3.eth.abi.encodeParameter("uint16", 65535).substring(2 + (64 - 4)), + // New Contract Address + web3.eth.abi.encodeParameter("address", mock.address).substring(2), + ].join('') + + const vm = await signAndEncodeVM( + timestamp, + nonce, + emitterChainId, + emitterAddress, + 0, + data, + [ + testSigner1PK, + testSigner2PK, + testSigner3PK + ], + 1, + 2 + ); + + try { + await initialized.methods.submitContractUpgrade("0x" + vm).send({ + value: 0, + from: accounts[0], + gasLimit: 1000000 + }); + assert.fail("contract upgrade for bad fork accepted") + } catch (e) { + assert.equal(e.data[Object.keys(e.data)[0]].reason, "Invalid Chain") + } + }) + + it("should revert recover chain ID governance packets on supported (non-bad-fork) chains", async function () { + const initialized = new web3.eth.Contract(ImplementationFullABI, Wormhole.address); + const accounts = await web3.eth.getAccounts(); + + data = [ + // Core + core, + // Action 5 (Recover Chain ID) + actionRecoverChainId, + // EvmChainID + web3.eth.abi.encodeParameter("uint256", 1).substring(2), + // NewChainID + web3.eth.abi.encodeParameter("uint16", testChainId).substring(2 + (64 - 4)), + ].join('') + + const vm = await signAndEncodeVM( + 0, + 0, + testGovernanceChainId, + testGovernanceContract, + 0, + data, + [ + testSigner1PK, + ], + 0, + 2 + ); + + try { + await initialized.methods.submitRecoverChainId("0x" + vm).send({ + value: 0, + from: accounts[0], + gasLimit: 1000000 + }); + assert.fail("recover chain ID governance packet on supported chain accepted") + } catch (e) { + assert.equal(e.data[Object.keys(e.data)[0]].reason, "invalid chain") + } + }) + it("should revert governance packets from old guardian set", async function () { const initialized = new web3.eth.Contract(ImplementationFullABI, Wormhole.address); const accounts = await web3.eth.getAccounts(); @@ -657,7 +749,7 @@ contract("Wormhole", function () { } }) - it("should time out old gardians", async function () { + it("should time out old guardians", async function () { const initialized = new web3.eth.Contract(ImplementationFullABI, Wormhole.address); const timestamp = 1000;