CCQ: High level validation helpers + slight refactor

This commit is contained in:
Dirk Brink 2023-11-15 09:47:52 -08:00 committed by Jeff Schroeder
parent 7a2a19c31b
commit 3c243e0e87
3 changed files with 123 additions and 49 deletions

View File

@ -10,12 +10,11 @@ import "./QueryResponse.sol";
error InvalidOwner();
// @dev for the onlyOwner modifier
error InvalidCaller();
error InvalidContractAddress();
error InvalidCalldata();
error InvalidWormholeAddress();
error InvalidForeignChainID();
error ObsoleteUpdate();
error StaleUpdate();
error UnexpectedCallData();
error UnexpectedResultLength();
error UnexpectedResultMismatch();
@ -32,14 +31,13 @@ contract QueryDemo is QueryResponse {
}
address private immutable owner;
address private immutable wormhole;
uint16 private immutable myChainID;
mapping(uint16 => ChainEntry) private counters;
uint16[] private foreignChainIDs;
bytes4 GetMyCounter = bytes4(hex"916d5743");
constructor(address _owner, address _wormhole, uint16 _myChainID) {
constructor(address _owner, address _wormhole, uint16 _myChainID) QueryResponse(_wormhole) {
if (_owner == address(0)) {
revert InvalidOwner();
}
@ -48,7 +46,7 @@ contract QueryDemo is QueryResponse {
if (_wormhole == address(0)) {
revert InvalidWormholeAddress();
}
wormhole = _wormhole;
myChainID = _myChainID;
counters[_myChainID] = ChainEntry(_myChainID, address(this), 0, 0, 0);
}
@ -87,7 +85,7 @@ contract QueryDemo is QueryResponse {
// @notice Takes the cross chain query response for the other counters, stores the results for the other chains, and updates the counter for this chain.
function updateCounters(bytes memory response, IWormhole.Signature[] memory signatures) public {
uint256 adjustedBlockTime;
ParsedQueryResponse memory r = parseAndVerifyQueryResponse(address(wormhole), response, signatures);
ParsedQueryResponse memory r = parseAndVerifyQueryResponse(response, signatures);
uint numResponses = r.responses.length;
if (numResponses != foreignChainIDs.length) {
revert UnexpectedResultLength();
@ -101,33 +99,24 @@ contract QueryDemo is QueryResponse {
}
EthCallQueryResponse memory eqr = parseEthCallQueryResponse(r.responses[i]);
if (eqr.blockNum <= chainEntry.blockNum) {
revert ObsoleteUpdate();
}
// wormhole time is in microseconds, timestamp is in seconds
adjustedBlockTime = eqr.blockTime / 1_000_000;
if (adjustedBlockTime <= block.timestamp - 300) {
revert StaleUpdate();
}
// Validate that update is not obsolete
validateBlockNum(eqr.blockNum, chainEntry.blockNum, block.number);
// Validate that update is not stale
validateBlockTime(eqr.blockTime, block.timestamp - 300, block.timestamp);
if (eqr.result.length != 1) {
revert UnexpectedResultMismatch();
}
if (eqr.result[0].contractAddress != chainEntry.contractAddress) {
revert InvalidContractAddress();
}
// Validate addresses and function signatures
address[] memory validAddresses = new address[](1);
bytes4[] memory validFunctionSignatures = new bytes4[](1);
validAddresses[0] = chainEntry.contractAddress;
validFunctionSignatures[0] = GetMyCounter;
// TODO: Is there an easier way to verify that the call data is correct!
bytes memory callData = eqr.result[0].callData;
bytes4 result;
assembly {
result := mload(add(callData, 32))
}
if (result != GetMyCounter) {
revert UnexpectedCallData();
}
validateMultipleEthCallData(eqr.result, validAddresses, validFunctionSignatures);
require(eqr.result[0].result.length == 32, "result is not a uint256");

View File

@ -64,6 +64,7 @@ struct EthCallData {
}
// Custom errors
error EmptyWormholeAddress();
error InvalidResponseVersion();
error VersionMismatch();
error ZeroQueries();
@ -73,11 +74,18 @@ error RequestTypeMismatch();
error UnsupportedQueryType();
error UnexpectedNumberOfResults();
error InvalidPayloadLength(uint256 received, uint256 expected);
error InvalidContractAddress();
error InvalidFunctionSignature();
error InvalidChainId();
error InvalidBlockNum();
error InvalidBlockTime();
// @dev QueryResponse is a library that implements the parsing and verification of Cross Chain Query (CCQ) responses.
abstract contract QueryResponse {
using BytesParsing for bytes;
IWormhole public immutable wormhole;
bytes public constant responsePrefix = bytes("query_response_0000000000000000000|");
uint8 public constant VERSION = 1;
uint8 public constant QT_ETH_CALL = 1;
@ -85,6 +93,14 @@ abstract contract QueryResponse {
uint8 public constant QT_ETH_CALL_WITH_FINALITY = 3;
uint8 public constant QT_MAX = 4; // Keep this last
constructor(address _wormhole) {
if (_wormhole == address(0)) {
revert EmptyWormholeAddress();
}
wormhole = IWormhole(_wormhole);
}
/// @dev getResponseHash computes the hash of the specified query response.
function getResponseHash(bytes memory response) public pure returns (bytes32) {
return keccak256(response);
@ -96,8 +112,8 @@ abstract contract QueryResponse {
}
/// @dev parseAndVerifyQueryResponse verifies the query response and returns the parsed response.
function parseAndVerifyQueryResponse(address wormhole, bytes memory response, IWormhole.Signature[] memory signatures) public view returns (ParsedQueryResponse memory r) {
verifyQueryResponseSignatures(wormhole, response, signatures);
function parseAndVerifyQueryResponse(bytes memory response, IWormhole.Signature[] memory signatures) public view returns (ParsedQueryResponse memory r) {
verifyQueryResponseSignatures(response, signatures);
uint index = 0;
@ -341,13 +357,80 @@ abstract contract QueryResponse {
checkLength(pcr.response, respIdx);
}
/// @dev validateBlockTime validates that the parsed block time is in an acceptable range
/// @param _blockTime Wormhole block time in MICROseconds
/// @param _minBlockTime Minium block time in seconds
/// @param _maxBlockTime Maximum block time in seconds
function validateBlockTime(uint64 _blockTime, uint256 _minBlockTime, uint256 _maxBlockTime) public pure {
uint256 blockTimeInSeconds = _blockTime / 1_000_000; // Rounds down
if (blockTimeInSeconds < _minBlockTime || _blockTime > _maxBlockTime) {
revert InvalidBlockTime();
}
}
/// @dev validateBlockNum validates that the parsed blockNum is in an acceptable range
function validateBlockNum(uint64 _blockNum, uint256 _minBlockNum, uint256 _maxBlockNum) public pure {
if (_blockNum < _minBlockNum || _blockNum > _maxBlockNum) {
revert InvalidBlockNum();
}
}
/// @dev validateChainId validates that the parsed chainId is one of an array of chainIds we expect
function validateChainId(uint16 chainId, uint16[] memory _validChainIds) public pure {
for (uint256 i = 0; i < _validChainIds.length; ++i) {
if (chainId == _validChainIds[i]) {
revert InvalidChainId();
}
}
}
/// @dev validateMutlipleEthCallData validates that each EthCallData in an array comes from a function signature and contract address we expect
function validateMultipleEthCallData(EthCallData[] memory r, address[] memory _expectedContractAddresses, bytes4[] memory _expectedFunctionSignatures) public pure {
for (uint256 i = 0; i < r.length; ++i) {
validateEthCallData(r[i], _expectedContractAddresses, _expectedFunctionSignatures);
}
}
/// @dev validateEthCallData validates that EthCallData comes from a function signature and contract address we expect
function validateEthCallData(EthCallData memory r, address[] memory _expectedContractAddresses, bytes4[] memory _expectedFunctionSignatures) public pure {
// An empty array means we accept all addresses/function signatures
bool validContractAddress = _expectedContractAddresses.length == 0 ? true : false;
bool validFunctionSignature = _expectedFunctionSignatures.length == 0 ? true : false;
// Check that the contract address called in the request is expected
for (uint256 i = 0; i < _expectedContractAddresses.length; ++i) {
if (r.contractAddress == _expectedContractAddresses[i]) {
validContractAddress = true;
break;
}
}
// Early exit to save gas
if (!validContractAddress) {
revert InvalidContractAddress();
}
// Check that the function signature called is expected
for (uint256 i = 0; i < _expectedFunctionSignatures.length; ++i) {
(bytes4 funcSig,) = r.callData.asBytes4Unchecked(0);
if (funcSig == _expectedFunctionSignatures[i]) {
validFunctionSignature = true;
break;
}
}
if (!validFunctionSignature) {
revert InvalidFunctionSignature();
}
}
/**
* @dev verifyQueryResponseSignatures verifies the signatures on a query response. It calls into the Wormhole contract.
* IWormhole.Signature expects the last byte to be bumped by 27
* see https://github.com/wormhole-foundation/wormhole/blob/637b1ee657de7de05f783cbb2078dd7d8bfda4d0/ethereum/contracts/Messages.sol#L174
*/
function verifyQueryResponseSignatures(address _wormhole, bytes memory response, IWormhole.Signature[] memory signatures) public view {
IWormhole wormhole = IWormhole(_wormhole);
function verifyQueryResponseSignatures(bytes memory response, IWormhole.Signature[] memory signatures) public view {
// It might be worth adding a verifyCurrentQuorum call on the core bridge so that there is only 1 cross call instead of 4.
uint32 gsi = wormhole.getCurrentGuardianSetIndex();
IWormhole.GuardianSet memory guardianSet = wormhole.getGuardianSet(gsi);

View File

@ -11,7 +11,9 @@ import "../../contracts/Wormhole.sol";
import "forge-std/Test.sol";
// @dev A non-abstract QueryResponse contract
contract QueryResponseContract is QueryResponse { }
contract QueryResponseContract is QueryResponse {
constructor(address _wormhole) QueryResponse(_wormhole) {}
}
contract TestQueryResponse is Test {
// Some happy case defaults
@ -35,7 +37,7 @@ contract TestQueryResponse is Test {
function setUp() public {
wormhole = deployWormholeForTest();
queryResponse = new QueryResponseContract();
queryResponse = new QueryResponseContract(address(wormhole));
}
uint16 constant TEST_CHAIN_ID = 2;
@ -121,7 +123,7 @@ contract TestQueryResponse is Test {
(uint8 sigV, bytes32 sigR, bytes32 sigS) = getSignature(resp);
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
queryResponse.verifyQueryResponseSignatures(address(wormhole), resp, signatures);
queryResponse.verifyQueryResponseSignatures(resp, signatures);
// TODO: There are no assertions for this test
}
@ -130,7 +132,7 @@ contract TestQueryResponse is Test {
(uint8 sigV, bytes32 sigR, bytes32 sigS) = getSignature(resp);
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
ParsedQueryResponse memory r = queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
ParsedQueryResponse memory r = queryResponse.parseAndVerifyQueryResponse(resp, signatures);
assertEq(r.version, 1);
assertEq(r.senderChainId, 0);
assertEq(r.requestId, hex"ff0c222dc9e3655ec38e212e9792bf1860356d1277462b6bf747db865caca6fc08e6317b64ee3245264e371146b1d315d38c867fe1f69614368dc4430bb560f200");
@ -306,7 +308,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert(InvalidResponseVersion.selector);
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzSenderChainId(uint16 _senderChainId) public {
@ -318,7 +320,7 @@ contract TestQueryResponse is Test {
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
// This could revert for multiple reasons. But the checkLength to ensure all the bytes are consumed is the backstop.
vm.expectRevert();
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzSignatureHappyCase(bytes memory _signature) public {
@ -329,7 +331,7 @@ contract TestQueryResponse is Test {
(uint8 sigV, bytes32 sigR, bytes32 sigS) = getSignature(resp);
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
ParsedQueryResponse memory r = queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
ParsedQueryResponse memory r = queryResponse.parseAndVerifyQueryResponse(resp, signatures);
assertEq(r.requestId, _signature);
}
@ -343,7 +345,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert();
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzQueryRequestLen(uint32 _queryRequestLen, bytes calldata _perChainQueries) public {
@ -355,7 +357,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert();
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzQueryRequestVersion(uint8 _version, uint8 _queryRequestVersion) public {
@ -366,7 +368,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert();
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzQueryRequestNonce(uint32 _queryRequestNonce) public {
@ -374,7 +376,7 @@ contract TestQueryResponse is Test {
(uint8 sigV, bytes32 sigR, bytes32 sigS) = getSignature(resp);
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
ParsedQueryResponse memory r = queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
ParsedQueryResponse memory r = queryResponse.parseAndVerifyQueryResponse(resp, signatures);
assertEq(r.nonce, _queryRequestNonce);
}
@ -387,7 +389,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert();
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzChainIds(uint16 _requestChainId, uint16 _responseChainId, uint256 _requestQueryType) public {
@ -401,7 +403,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert(ChainIdMismatch.selector);
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzMistmatchedRequestType(uint256 _requestQueryType, uint256 _responseQueryType) public {
@ -416,7 +418,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert(RequestTypeMismatch.selector);
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzUnsupportedRequestType(uint8 _requestQueryType) public {
@ -429,7 +431,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert(UnsupportedQueryType.selector);
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_parseAndVerifyQueryResponse_fuzzQueryBytesLength(uint32 _queryLength) public {
@ -442,7 +444,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert();
queryResponse.parseAndVerifyQueryResponse(address(wormhole), resp, signatures);
queryResponse.parseAndVerifyQueryResponse(resp, signatures);
}
function testFuzz_verifyQueryResponseSignatures_validSignature(bytes calldata resp) public view {
@ -450,7 +452,7 @@ contract TestQueryResponse is Test {
(uint8 sigV, bytes32 sigR, bytes32 sigS) = getSignature(resp);
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
queryResponse.verifyQueryResponseSignatures(address(wormhole), resp, signatures);
queryResponse.verifyQueryResponseSignatures(resp, signatures);
}
function testFuzz_verifyQueryResponseSignatures_invalidSignature(bytes calldata resp, uint256 privateKey) public {
@ -463,7 +465,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert("VM signature invalid");
queryResponse.verifyQueryResponseSignatures(address(wormhole), resp, signatures);
queryResponse.verifyQueryResponseSignatures(resp, signatures);
}
function testFuzz_verifyQueryResponseSignatures_validSignatureWrongPrefix(bytes calldata responsePrefix) public {
@ -476,7 +478,7 @@ contract TestQueryResponse is Test {
IWormhole.Signature[] memory signatures = new IWormhole.Signature[](1);
signatures[0] = IWormhole.Signature({r: sigR, s: sigS, v: sigV, guardianIndex: sigGuardianIndex});
vm.expectRevert("VM signature invalid");
queryResponse.verifyQueryResponseSignatures(address(wormhole), resp, signatures);
queryResponse.verifyQueryResponseSignatures(resp, signatures);
}
}