From f94dceb1bcd361d1c53c88ed41dd306f057fad7e Mon Sep 17 00:00:00 2001 From: Ali Behjati Date: Wed, 3 May 2023 16:08:53 +0200 Subject: [PATCH] [eth] Add Pyth Accumulator (#776) This PR adds the support WormholeMerkle accumulator message to the ethereum contract while still supporting the old message format. The code is not optimized yet and with more optimizations we can achieve a better gas usage. Currently based on the gas benchmark below it has a 18% improvement with a single price feed. Although the cost of updating 5 feeds in the same batch is higher than the current approach but in reality the chances that all 5 feeds be in the same batch is very low. --- package-lock.json | 2 +- .../contracts/libraries/MerkleTree.sol | 145 ++++++ .../contracts/contracts/pyth/Pyth.sol | 21 +- .../contracts/pyth/PythAccumulator.sol | 319 +++++++++++++ .../contracts/forge-test/GasBenchmark.t.sol | 164 +++++-- .../Pyth.WormholeMerkleAccumulator.t.sol | 427 ++++++++++++++++++ .../ethereum/contracts/forge-test/Pyth.t.sol | 2 +- .../forge-test/VerificationExperiments.t.sol | 4 +- .../forge-test/utils/PythTestUtils.t.sol | 112 ++++- .../forge-test/utils/RandTestUtils.t.sol | 4 + target_chains/ethereum/contracts/package.json | 2 +- 11 files changed, 1155 insertions(+), 47 deletions(-) create mode 100644 target_chains/ethereum/contracts/contracts/libraries/MerkleTree.sol create mode 100644 target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol create mode 100644 target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol diff --git a/package-lock.json b/package-lock.json index e4efb5db..b265e9c4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -49691,7 +49691,7 @@ }, "target_chains/ethereum/contracts": { "name": "@pythnetwork/pyth-evm-contract", - "version": "1.2.0", + "version": "1.3.0-alpha", "license": "ISC", "dependencies": { "@certusone/wormhole-sdk": "^0.9.9", diff --git a/target_chains/ethereum/contracts/contracts/libraries/MerkleTree.sol b/target_chains/ethereum/contracts/contracts/libraries/MerkleTree.sol new file mode 100644 index 00000000..17f5edce --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/libraries/MerkleTree.sol @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "./external/UnsafeBytesLib.sol"; + +/** + * @dev This library provides methods to construct and verify Merkle Tree proofs efficiently. + * + */ + +library MerkleTree { + uint8 constant MERKLE_LEAF_PREFIX = 0; + uint8 constant MERKLE_NODE_PREFIX = 1; + uint8 constant MERKLE_EMPTY_LEAF_PREFIX = 2; + + function hash(bytes memory input) internal pure returns (bytes20) { + return bytes20(keccak256(input)); + } + + function emptyLeafHash() internal pure returns (bytes20) { + return hash(abi.encodePacked(MERKLE_EMPTY_LEAF_PREFIX)); + } + + function leafHash(bytes memory data) internal pure returns (bytes20) { + return hash(abi.encodePacked(MERKLE_LEAF_PREFIX, data)); + } + + function nodeHash( + bytes20 childA, + bytes20 childB + ) internal pure returns (bytes20) { + if (childA > childB) { + (childA, childB) = (childB, childA); + } + return hash(abi.encodePacked(MERKLE_NODE_PREFIX, childA, childB)); + } + + /// @notice Verify Merkle Tree proof for given leaf data. + /// @dev To optimize gas usage, this method doesn't take the proof as a bytes array + /// but rather takes the encoded proof and the offset of the proof in the + /// encoded proof array possibly containing multiple proofs. Also, the method + /// does not perform any check on the boundry of the `encodedProof` and the + /// `proofOffset` parameters. It is the caller's responsibility to ensure + /// that the `encodedProof` is long enough to contain the proof and the + /// `proofOffset` is not out of bound. + function isProofValid( + bytes memory encodedProof, + uint proofOffset, + bytes20 root, + bytes memory leafData + ) internal pure returns (bool valid, uint endOffset) { + unchecked { + bytes20 currentDigest = MerkleTree.leafHash(leafData); + + uint8 proofSize = UnsafeBytesLib.toUint8(encodedProof, proofOffset); + proofOffset += 1; + + for (uint i = 0; i < proofSize; i++) { + bytes20 siblingDigest = bytes20( + UnsafeBytesLib.toAddress(encodedProof, proofOffset) + ); + proofOffset += 20; + + currentDigest = MerkleTree.nodeHash( + currentDigest, + siblingDigest + ); + } + + valid = currentDigest == root; + endOffset = proofOffset; + } + } + + /// @notice Construct Merkle Tree proofs for given list of messages. + /// @dev This function is only used for testing purposes and is not efficient + /// for production use-cases. + /// + /// This method creates a merkle tree with leaf size of (2^depth) with the + /// messages as leafs (in the same given order) and returns the root digest + /// and the proofs for each message. If the number of messages is not a power + /// of 2, the tree is padded with empty messages. + function constructProofs( + bytes[] memory messages, + uint8 depth + ) internal pure returns (bytes20 root, bytes[] memory proofs) { + require((1 << depth) >= messages.length, "depth too small"); + + bytes20[] memory tree = new bytes20[]((1 << (depth + 1))); + + // The tree is structured as follows: + // 1 + // 2 3 + // 4 5 6 7 + // ... + // In this structure the parent of node x is x//2 and the children + // of node x are x*2 and x*2 + 1. Also, the sibling of the node x + // is x^1. The root is at index 1 and index 0 is not used. + + // Filling the leaf hashes + bytes20 cachedEmptyLeafHash = emptyLeafHash(); + + for (uint i = 0; i < (1 << depth); i++) { + if (i < messages.length) { + tree[(1 << depth) + i] = leafHash(messages[i]); + } else { + tree[(1 << depth) + i] = cachedEmptyLeafHash; + } + } + + // Filling the node hashes from bottom to top + for (uint k = depth; k > 0; k--) { + uint level = k - 1; + uint levelNumNodes = (1 << level); + for (uint i = 0; i < levelNumNodes; i++) { + uint id = (1 << level) + i; + tree[id] = nodeHash(tree[id * 2], tree[id * 2 + 1]); + } + } + + root = tree[1]; + + proofs = new bytes[](messages.length); + + for (uint i = 0; i < messages.length; i++) { + // depth is the number of sibling nodes in the path from the leaf to the root + proofs[i] = abi.encodePacked(depth); + + uint idx = (1 << depth) + i; + + // This loop iterates through the leaf and its parents + // and keeps adding the sibling of the current node to the proof. + while (idx > 1) { + proofs[i] = abi.encodePacked( + proofs[i], + tree[idx ^ 1] // Sibling of this node + ); + + // Jump to parent + idx /= 2; + } + } + } +} diff --git a/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol b/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol index cb3753b3..06a07009 100644 --- a/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol +++ b/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol @@ -1,4 +1,3 @@ -// contracts/Bridge.sol // SPDX-License-Identifier: Apache 2 pragma solidity ^0.8.0; @@ -8,11 +7,17 @@ import "@pythnetwork/pyth-sdk-solidity/AbstractPyth.sol"; import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol"; +import "./PythAccumulator.sol"; import "./PythGetters.sol"; import "./PythSetters.sol"; import "./PythInternalStructs.sol"; -abstract contract Pyth is PythGetters, PythSetters, AbstractPyth { +abstract contract Pyth is + PythGetters, + PythSetters, + AbstractPyth, + PythAccumulator +{ function _initialize( address wormhole, uint16[] calldata dataSourceEmitterChainIds, @@ -66,11 +71,19 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth { function updatePriceFeeds( bytes[] calldata updateData ) public payable override { + // TODO: Is this fee model still good for accumulator? uint requiredFee = getUpdateFee(updateData); if (msg.value < requiredFee) revert PythErrors.InsufficientFee(); for (uint i = 0; i < updateData.length; ) { - updatePriceBatchFromVm(updateData[i]); + if ( + updateData[i].length > 4 && + UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC + ) { + updatePricesUsingAccumulator(updateData[i]); + } else { + updatePriceBatchFromVm(updateData[i]); + } unchecked { i++; @@ -536,6 +549,6 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth { } function version() public pure returns (string memory) { - return "1.2.0"; + return "1.3.0-alpha"; } } diff --git a/target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol b/target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol new file mode 100644 index 00000000..7e204e92 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "../libraries/external/UnsafeBytesLib.sol"; +import "@pythnetwork/pyth-sdk-solidity/AbstractPyth.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; + +import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol"; +import "./PythGetters.sol"; +import "./PythSetters.sol"; +import "./PythInternalStructs.sol"; + +import "../libraries/MerkleTree.sol"; + +abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth { + uint32 constant ACCUMULATOR_MAGIC = 0x504e4155; // Stands for PNAU (Pyth Network Accumulator Update) + uint32 constant ACCUMULATOR_WORMHOLE_MAGIC = 0x41555756; // Stands for AUWV (Accumulator Update Wormhole Verficiation) + uint8 constant MINIMUM_ALLOWED_MINOR_VERSION = 0; + uint8 constant MAJOR_VERSION = 1; + + enum UpdateType { + WormholeMerkle + } + + enum MessageType { + PriceFeed + } + + // This method is also used by batch attestation but moved here + // as the batch attestation will deprecate soon. + function parseAndVerifyPythVM( + bytes memory encodedVm + ) internal view returns (IWormhole.VM memory vm) { + { + bool valid; + (vm, valid, ) = wormhole().parseAndVerifyVM(encodedVm); + if (!valid) revert PythErrors.InvalidWormholeVaa(); + } + + if (!isValidDataSource(vm.emitterChainId, vm.emitterAddress)) + revert PythErrors.InvalidUpdateDataSource(); + } + + function updatePricesUsingAccumulator( + bytes calldata accumulatorUpdate + ) internal { + unchecked { + uint offset = 0; + + { + uint32 magic = UnsafeBytesLib.toUint32( + accumulatorUpdate, + offset + ); + offset += 4; + + if (magic != ACCUMULATOR_MAGIC) + revert PythErrors.InvalidUpdateData(); + + uint8 majorVersion = UnsafeBytesLib.toUint8( + accumulatorUpdate, + offset + ); + offset += 1; + + if (majorVersion != MAJOR_VERSION) + revert PythErrors.InvalidUpdateData(); + + uint8 minorVersion = UnsafeBytesLib.toUint8( + accumulatorUpdate, + offset + ); + offset += 1; + + // Minor versions are forward compatible, so we only check + // that the minor version is not less than the minimum allowed + if (minorVersion < MINIMUM_ALLOWED_MINOR_VERSION) + revert PythErrors.InvalidUpdateData(); + + // This field ensure that we can add headers in the future + // without breaking the contract (future compatibility) + uint8 trailingHeaderSize = UnsafeBytesLib.toUint8( + accumulatorUpdate, + offset + ); + offset += 1; + + // We use another offset for the trailing header and in the end add the + // offset by trailingHeaderSize to skip the future headers. + // + // An example would be like this: + // uint trailingHeaderOffset = offset + // uint x = UnsafeBytesLib.ToUint8(accumulatorUpdate, trailingHeaderOffset) + // trailingHeaderOffset += 1 + + offset += trailingHeaderSize; + } + + UpdateType updateType = UpdateType( + UnsafeBytesLib.toUint8(accumulatorUpdate, offset) + ); + offset += 1; + + if (accumulatorUpdate.length < offset) + revert PythErrors.InvalidUpdateData(); + + if (updateType == UpdateType.WormholeMerkle) { + updatePricesUsingWormholeMerkle( + UnsafeBytesLib.slice( + accumulatorUpdate, + offset, + accumulatorUpdate.length - offset + ) + ); + } else { + revert PythErrors.InvalidUpdateData(); + } + } + } + + function updatePricesUsingWormholeMerkle(bytes memory encoded) private { + unchecked { + uint offset = 0; + + uint16 whProofSize = UnsafeBytesLib.toUint16(encoded, offset); + offset += 2; + + bytes20 digest; + + { + IWormhole.VM memory vm = parseAndVerifyPythVM( + UnsafeBytesLib.slice(encoded, offset, whProofSize) + ); + offset += whProofSize; + + // TODO: Do we need to emit an update for accumulator update? If so what should we emit? + // emit AccumulatorUpdate(vm.chainId, vm.sequence); + + bytes memory encodedPayload = vm.payload; + uint payloadoffset = 0; + + { + uint32 magic = UnsafeBytesLib.toUint32( + encodedPayload, + payloadoffset + ); + payloadoffset += 4; + + if (magic != ACCUMULATOR_WORMHOLE_MAGIC) + revert PythErrors.InvalidUpdateData(); + + UpdateType updateType = UpdateType( + UnsafeBytesLib.toUint8(encodedPayload, payloadoffset) + ); + payloadoffset += 1; + + if (updateType != UpdateType.WormholeMerkle) + revert PythErrors.InvalidUpdateData(); + + // This field is not used + // uint32 storageIndex = UnsafeBytesLib.toUint32(encodedPayload, payloadoffset); + payloadoffset += 4; + + digest = bytes20( + UnsafeBytesLib.toAddress(encodedPayload, payloadoffset) + ); + payloadoffset += 20; + + // We don't check equality to enable future compatibility. + if (payloadoffset > encodedPayload.length) + revert PythErrors.InvalidUpdateData(); + } + } + + uint8 numUpdates = UnsafeBytesLib.toUint8(encoded, offset); + offset += 1; + + for (uint i = 0; i < numUpdates; i++) { + offset = verifyAndUpdatePriceFeedFromMerkleProof( + digest, + encoded, + offset + ); + } + + if (offset != encoded.length) revert PythErrors.InvalidUpdateData(); + } + } + + function verifyAndUpdatePriceFeedFromMerkleProof( + bytes20 digest, + bytes memory encoded, + uint offset + ) private returns (uint endOffset) { + unchecked { + uint16 messageSize = UnsafeBytesLib.toUint16(encoded, offset); + offset += 2; + + bytes memory encodedMessage = UnsafeBytesLib.slice( + encoded, + offset, + messageSize + ); + offset += messageSize; + + bool valid; + (valid, offset) = MerkleTree.isProofValid( + encoded, + offset, + digest, + encodedMessage + ); + + if (!valid) { + revert PythErrors.InvalidUpdateData(); + } + + parseAndProcessMessage(encodedMessage); + + return offset; + } + } + + function parsePriceFeedMessage( + bytes memory encodedPriceFeed + ) + private + pure + returns ( + PythInternalStructs.PriceInfo memory priceInfo, + bytes32 priceId + ) + { + unchecked { + uint offset = 0; + + priceId = UnsafeBytesLib.toBytes32(encodedPriceFeed, offset); + offset += 32; + + priceInfo.price = int64( + UnsafeBytesLib.toUint64(encodedPriceFeed, offset) + ); + offset += 8; + + priceInfo.conf = UnsafeBytesLib.toUint64(encodedPriceFeed, offset); + offset += 8; + + priceInfo.expo = int32( + UnsafeBytesLib.toUint32(encodedPriceFeed, offset) + ); + offset += 4; + + // Publish time is i64 in some environments due to the standard in that + // environment. This would not cause any problem because since the signed + // integer is represented in two's complement, the value would be the same + // in both cases (for a million year at least) + priceInfo.publishTime = UnsafeBytesLib.toUint64( + encodedPriceFeed, + offset + ); + offset += 8; + + // We do not store this field because it is not used on the latest feed queries. + // uint64 prevPublishTime = UnsafeBytesLib.toUint64(encodedPriceFeed, offset); + offset += 8; + + priceInfo.emaPrice = int64( + UnsafeBytesLib.toUint64(encodedPriceFeed, offset) + ); + offset += 8; + + priceInfo.emaConf = UnsafeBytesLib.toUint64( + encodedPriceFeed, + offset + ); + offset += 8; + + // We don't check equality to enable future compatibility. + if (offset > encodedPriceFeed.length) + revert PythErrors.InvalidUpdateData(); + } + } + + function parseAndProcessMessage(bytes memory encodedMessage) private { + unchecked { + MessageType messageType = MessageType( + UnsafeBytesLib.toUint8(encodedMessage, 0) + ); + + if (messageType == MessageType.PriceFeed) { + ( + PythInternalStructs.PriceInfo memory info, + bytes32 priceId + ) = parsePriceFeedMessage( + UnsafeBytesLib.slice( + encodedMessage, + 1, + encodedMessage.length - 1 + ) + ); + + uint64 latestPublishTime = latestPriceInfoPublishTime(priceId); + + if (info.publishTime > latestPublishTime) { + setLatestPriceInfo(priceId, info); + emit PriceFeedUpdate( + priceId, + info.publishTime, + info.price, + info.conf + ); + } + } else { + revert PythErrors.InvalidUpdateData(); + } + } + } +} diff --git a/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol b/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol index 24f5c306..232f09a8 100644 --- a/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol +++ b/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol @@ -22,22 +22,30 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { // We use 5 prices to form a batch of 5 prices, close to our mainnet transactions. uint8 constant NUM_PRICES = 5; + // We will have less than 512 price for a foreseeable future. + uint8 constant MERKLE_TREE_DEPTH = 9; + IPyth public pyth; bytes32[] priceIds; // Cached prices are populated in the setUp PythStructs.Price[] cachedPrices; - bytes[] cachedPricesUpdateData; - uint cachedPricesUpdateFee; + bytes[] cachedPricesWhBatchUpdateData; + uint cachedPricesWhBatchUpdateFee; uint64[] cachedPricesPublishTimes; + bytes[][] cachedPricesWhMerkleUpdateData; // i th element contains the update data for the first i prices + uint[] cachedPricesWhMerkleUpdateFee; // i th element contains the update fee for the first i prices + // Fresh prices are different prices that can be used // as a fresh price to update the prices PythStructs.Price[] freshPrices; - bytes[] freshPricesUpdateData; - uint freshPricesUpdateFee; + bytes[] freshPricesWhBatchUpdateData; + uint freshPricesWhBatchUpdateFee; uint64[] freshPricesPublishTimes; + bytes[][] freshPricesWhMerkleUpdateData; // i th element contains the update data for the first i prices + uint[] freshPricesWhMerkleUpdateFee; // i th element contains the update fee for the first i prices uint64 sequence; uint randSeed; @@ -76,21 +84,37 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { ) ); freshPricesPublishTimes.push(publishTime); + + // Generate Wormhole Merkle update data and fee for the first i th prices + ( + bytes[] memory updateData, + uint updateFee + ) = generateWhMerkleUpdateDataAndFee(cachedPrices); + + cachedPricesWhMerkleUpdateData.push(updateData); + cachedPricesWhMerkleUpdateFee.push(updateFee); + + (updateData, updateFee) = generateWhMerkleUpdateDataAndFee( + freshPrices + ); + + freshPricesWhMerkleUpdateData.push(updateData); + freshPricesWhMerkleUpdateFee.push(updateFee); } // Populate the contract with the initial prices ( - cachedPricesUpdateData, - cachedPricesUpdateFee - ) = generateUpdateDataAndFee(cachedPrices); - pyth.updatePriceFeeds{value: cachedPricesUpdateFee}( - cachedPricesUpdateData + cachedPricesWhBatchUpdateData, + cachedPricesWhBatchUpdateFee + ) = generateWhBatchUpdateDataAndFee(cachedPrices); + pyth.updatePriceFeeds{value: cachedPricesWhBatchUpdateFee}( + cachedPricesWhBatchUpdateData ); ( - freshPricesUpdateData, - freshPricesUpdateFee - ) = generateUpdateDataAndFee(freshPrices); + freshPricesWhBatchUpdateData, + freshPricesWhBatchUpdateFee + ) = generateWhBatchUpdateDataAndFee(freshPrices); } function getRand() internal returns (uint val) { @@ -98,10 +122,10 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { val = uint(keccak256(abi.encode(randSeed))); } - function generateUpdateDataAndFee( + function generateWhBatchUpdateDataAndFee( PythStructs.Price[] memory prices ) internal returns (bytes[] memory updateData, uint updateFee) { - bytes memory vaa = generatePriceFeedUpdateVAA( + bytes memory vaa = generateWhBatchUpdate( pricesToPriceAttestations(priceIds, prices), sequence, NUM_GUARDIAN_SIGNERS @@ -115,35 +139,109 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { updateFee = pyth.getUpdateFee(updateData); } - function testBenchmarkUpdatePriceFeedsFresh() public { - pyth.updatePriceFeeds{value: freshPricesUpdateFee}( - freshPricesUpdateData + function generateWhMerkleUpdateDataAndFee( + PythStructs.Price[] memory prices + ) internal returns (bytes[] memory updateData, uint updateFee) { + updateData = new bytes[](1); + + updateData[0] = generateWhMerkleUpdate( + pricesToPriceFeedMessages(priceIds, prices), + MERKLE_TREE_DEPTH, + NUM_GUARDIAN_SIGNERS + ); + + updateFee = pyth.getUpdateFee(updateData); + } + + function testBenchmarkUpdatePriceFeedsWhBatchFresh() public { + pyth.updatePriceFeeds{value: freshPricesWhBatchUpdateFee}( + freshPricesWhBatchUpdateData ); } - function testBenchmarkUpdatePriceFeedsNotFresh() public { - pyth.updatePriceFeeds{value: cachedPricesUpdateFee}( - cachedPricesUpdateData + function testBenchmarkUpdatePriceFeedsWhBatchNotFresh() public { + pyth.updatePriceFeeds{value: cachedPricesWhBatchUpdateFee}( + cachedPricesWhBatchUpdateData ); } - function testBenchmarkUpdatePriceFeedsIfNecessaryFresh() public { + function testBenchmarkUpdatePriceFeedsWhMerkle1FeedFresh() public { + pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[0]}( + freshPricesWhMerkleUpdateData[0] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle2FeedsFresh() public { + pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[1]}( + freshPricesWhMerkleUpdateData[1] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle3FeedsFresh() public { + pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[2]}( + freshPricesWhMerkleUpdateData[2] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle4FeedsFresh() public { + pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[3]}( + freshPricesWhMerkleUpdateData[3] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle5FeedsFresh() public { + pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[4]}( + freshPricesWhMerkleUpdateData[4] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle1FeedNotFresh() public { + pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[0]}( + cachedPricesWhMerkleUpdateData[0] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle2FeedsNotFresh() public { + pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[1]}( + cachedPricesWhMerkleUpdateData[1] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle3FeedsNotFresh() public { + pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[2]}( + cachedPricesWhMerkleUpdateData[2] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle4FeedsNotFresh() public { + pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[3]}( + cachedPricesWhMerkleUpdateData[3] + ); + } + + function testBenchmarkUpdatePriceFeedsWhMerkle5FeedsNotFresh() public { + pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[4]}( + cachedPricesWhMerkleUpdateData[4] + ); + } + + function testBenchmarkUpdatePriceFeedsIfNecessaryWhBatchFresh() public { // Since the prices have advanced, the publishTimes are newer than one in // the contract and hence, the call should succeed. - pyth.updatePriceFeedsIfNecessary{value: freshPricesUpdateFee}( - freshPricesUpdateData, + pyth.updatePriceFeedsIfNecessary{value: freshPricesWhBatchUpdateFee}( + freshPricesWhBatchUpdateData, priceIds, freshPricesPublishTimes ); } - function testBenchmarkUpdatePriceFeedsIfNecessaryNotFresh() public { + function testBenchmarkUpdatePriceFeedsIfNecessaryWhBatchNotFresh() public { // Since the price is not advanced, the publishTimes are the same as the // ones in the contract. vm.expectRevert(PythErrors.NoFreshUpdate.selector); - pyth.updatePriceFeedsIfNecessary{value: cachedPricesUpdateFee}( - cachedPricesUpdateData, + pyth.updatePriceFeedsIfNecessary{value: cachedPricesWhBatchUpdateFee}( + cachedPricesWhBatchUpdateData, priceIds, cachedPricesPublishTimes ); @@ -153,8 +251,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { bytes32[] memory ids = new bytes32[](1); ids[0] = priceIds[0]; - pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}( - freshPricesUpdateData, + pyth.parsePriceFeedUpdates{value: freshPricesWhBatchUpdateFee}( + freshPricesWhBatchUpdateData, ids, 0, 50 @@ -166,8 +264,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { ids[0] = priceIds[0]; ids[1] = priceIds[1]; - pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}( - freshPricesUpdateData, + pyth.parsePriceFeedUpdates{value: freshPricesWhBatchUpdateFee}( + freshPricesWhBatchUpdateData, ids, 0, 50 @@ -181,8 +279,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { ids[0] = priceIds[0]; vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector); - pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}( - freshPricesUpdateData, + pyth.parsePriceFeedUpdates{value: freshPricesWhBatchUpdateFee}( + freshPricesWhBatchUpdateData, ids, 50, 100 @@ -206,6 +304,6 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { } function testBenchmarkGetUpdateFee() public view { - pyth.getUpdateFee(freshPricesUpdateData); + pyth.getUpdateFee(freshPricesWhBatchUpdateData); } } diff --git a/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol b/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol new file mode 100644 index 00000000..773f0740 --- /dev/null +++ b/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import "forge-std/Test.sol"; + +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; +import "./utils/WormholeTestUtils.t.sol"; +import "./utils/PythTestUtils.t.sol"; +import "./utils/RandTestUtils.t.sol"; + +import "../contracts/libraries/MerkleTree.sol"; + +contract PythWormholeMerkleAccumulatorTest is + Test, + WormholeTestUtils, + PythTestUtils, + RandTestUtils +{ + IPyth public pyth; + + function setUp() public { + pyth = IPyth(setUpPyth(setUpWormhole(1))); + } + + function assertPriceFeedMessageStored( + PriceFeedMessage memory priceFeedMessage + ) internal { + PythStructs.Price memory aggregatePrice = pyth.getPriceUnsafe( + priceFeedMessage.priceId + ); + assertEq(aggregatePrice.price, priceFeedMessage.price); + assertEq(aggregatePrice.conf, priceFeedMessage.conf); + assertEq(aggregatePrice.expo, priceFeedMessage.expo); + assertEq(aggregatePrice.publishTime, priceFeedMessage.publishTime); + + PythStructs.Price memory emaPrice = pyth.getEmaPriceUnsafe( + priceFeedMessage.priceId + ); + assertEq(emaPrice.price, priceFeedMessage.emaPrice); + assertEq(emaPrice.conf, priceFeedMessage.emaConf); + assertEq(emaPrice.expo, priceFeedMessage.expo); + assertEq(emaPrice.publishTime, priceFeedMessage.publishTime); + } + + function generateRandomPriceFeedMessage( + uint numPriceFeeds + ) internal returns (PriceFeedMessage[] memory priceFeedMessages) { + priceFeedMessages = new PriceFeedMessage[](numPriceFeeds); + for (uint i = 0; i < numPriceFeeds; i++) { + priceFeedMessages[i] = PriceFeedMessage({ + priceId: getRandBytes32(), + price: getRandInt64(), + conf: getRandUint64(), + expo: getRandInt32(), + publishTime: getRandUint64(), + prevPublishTime: getRandUint64(), + emaPrice: getRandInt64(), + emaConf: getRandUint64() + }); + } + } + + function createWormholeMerkleUpdateData( + PriceFeedMessage[] memory priceFeedMessages + ) internal returns (bytes[] memory updateData, uint updateFee) { + updateData = new bytes[](1); + + uint8 depth = 0; + while ((1 << depth) < priceFeedMessages.length) { + depth++; + } + + depth += getRandUint8() % 3; + + updateData[0] = generateWhMerkleUpdate(priceFeedMessages, depth, 1); + + updateFee = pyth.getUpdateFee(updateData); + } + + /// Testing update price feeds method using wormhole merkle update type. + function testUpdatePriceFeedWithWormholeMerkleWorks(uint seed) public { + setRandSeed(seed); + + uint numPriceFeeds = (getRandUint() % 10) + 1; + PriceFeedMessage[] + memory priceFeedMessages = generateRandomPriceFeedMessage( + numPriceFeeds + ); + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + + pyth.updatePriceFeeds{value: updateFee}(updateData); + + for (uint i = 0; i < numPriceFeeds; i++) { + assertPriceFeedMessageStored(priceFeedMessages[i]); + } + + // Update the prices again with the same data should work + pyth.updatePriceFeeds{value: updateFee}(updateData); + + for (uint i = 0; i < numPriceFeeds; i++) { + assertPriceFeedMessageStored(priceFeedMessages[i]); + } + + // Update the prices again with updated data should update the prices + for (uint i = 0; i < numPriceFeeds; i++) { + priceFeedMessages[i].price = getRandInt64(); + priceFeedMessages[i].conf = getRandUint64(); + priceFeedMessages[i].expo = getRandInt32(); + + // Increase the publish time if it is not causing an overflow + if (priceFeedMessages[i].publishTime != type(uint64).max) { + priceFeedMessages[i].publishTime += 1; + } + priceFeedMessages[i].emaPrice = getRandInt64(); + priceFeedMessages[i].emaConf = getRandUint64(); + } + + (updateData, updateFee) = createWormholeMerkleUpdateData( + priceFeedMessages + ); + + pyth.updatePriceFeeds{value: updateFee}(updateData); + + for (uint i = 0; i < numPriceFeeds; i++) { + assertPriceFeedMessageStored(priceFeedMessages[i]); + } + } + + function testUpdatePriceFeedWithWormholeMerkleWorksOnMultiUpdate() public { + PriceFeedMessage[] + memory priceFeedMessages1 = generateRandomPriceFeedMessage(2); + PriceFeedMessage[] + memory priceFeedMessages2 = generateRandomPriceFeedMessage(2); + + // Make the 2nd message of the second update the same as the 1st message of the first update + priceFeedMessages2[1].priceId = priceFeedMessages1[0].priceId; + // Adjust the timestamps so the second timestamp is greater than the first + priceFeedMessages1[0].publishTime = 5; + priceFeedMessages2[1].publishTime = 10; + + bytes[] memory updateData = new bytes[](2); + + uint8 depth = 1; // 2 messages + uint8 numSigners = 1; + updateData[0] = generateWhMerkleUpdate( + priceFeedMessages1, + depth, + numSigners + ); + updateData[1] = generateWhMerkleUpdate( + priceFeedMessages2, + depth, + numSigners + ); + + uint updateFee = pyth.getUpdateFee(updateData); + + pyth.updatePriceFeeds{value: updateFee}(updateData); + + assertPriceFeedMessageStored(priceFeedMessages1[1]); + assertPriceFeedMessageStored(priceFeedMessages2[0]); + assertPriceFeedMessageStored(priceFeedMessages2[1]); + } + + function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateSingleCall() + public + { + PriceFeedMessage[] + memory priceFeedMessages1 = generateRandomPriceFeedMessage(1); + PriceFeedMessage[] + memory priceFeedMessages2 = generateRandomPriceFeedMessage(1); + + // Make the price ids the same + priceFeedMessages2[0].priceId = priceFeedMessages1[0].priceId; + // Adjust the timestamps so the second timestamp is smaller than the first + // so it doesn't get stored. + priceFeedMessages1[0].publishTime = 10; + priceFeedMessages2[0].publishTime = 5; + + bytes[] memory updateData = new bytes[](2); + + uint8 depth = 0; // 1 messages + uint8 numSigners = 1; + + updateData[0] = generateWhMerkleUpdate( + priceFeedMessages1, + depth, + numSigners + ); + updateData[1] = generateWhMerkleUpdate( + priceFeedMessages2, + depth, + numSigners + ); + + uint updateFee = pyth.getUpdateFee(updateData); + + pyth.updatePriceFeeds{value: updateFee}(updateData); + + assertPriceFeedMessageStored(priceFeedMessages1[0]); + } + + function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateMultiCall() + public + { + PriceFeedMessage[] + memory priceFeedMessages1 = generateRandomPriceFeedMessage(1); + PriceFeedMessage[] + memory priceFeedMessages2 = generateRandomPriceFeedMessage(1); + + // Make the price ids the same + priceFeedMessages2[0].priceId = priceFeedMessages1[0].priceId; + // Adjust the timestamps so the second timestamp is smaller than the first + // so it doesn't get stored. + priceFeedMessages1[0].publishTime = 10; + priceFeedMessages2[0].publishTime = 5; + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages1); + pyth.updatePriceFeeds{value: updateFee}(updateData); + assertPriceFeedMessageStored(priceFeedMessages1[0]); + + (updateData, updateFee) = createWormholeMerkleUpdateData( + priceFeedMessages2 + ); + pyth.updatePriceFeeds{value: updateFee}(updateData); + // Make sure that the old value is still stored + assertPriceFeedMessageStored(priceFeedMessages1[0]); + } + + function isNotMatch( + bytes memory a, + bytes memory b + ) public pure returns (bool) { + return keccak256(a) != keccak256(b); + } + + /// @notice This method creates a forged invalid wormhole update data. + /// The caller should pass the forgeItem as string and if it matches the + /// expected value, that item will be forged to be invalid. + function createAndForgeWormholeMerkleUpdateData( + bytes memory forgeItem + ) public returns (bytes[] memory updateData, uint updateFee) { + uint numPriceFeeds = 10; + PriceFeedMessage[] + memory priceFeedMessages = generateRandomPriceFeedMessage( + numPriceFeeds + ); + + bytes[] memory encodedPriceFeedMessages = encodePriceFeedMessages( + priceFeedMessages + ); + + (bytes20 rootDigest, bytes[] memory proofs) = MerkleTree + .constructProofs(encodedPriceFeedMessages, 4); // 4 is the depth of the tree (enough for 16 messages) + + bytes memory wormholePayload; + unchecked { + wormholePayload = abi.encodePacked( + isNotMatch(forgeItem, "whMagic") + ? uint32(0x41555756) + : uint32(0x41555750), + isNotMatch(forgeItem, "whUpdateType") + ? uint8(PythAccumulator.UpdateType.WormholeMerkle) + : uint8(PythAccumulator.UpdateType.WormholeMerkle) + 1, + uint32(0), // Storage index, not used in target networks + isNotMatch(forgeItem, "rootDigest") + ? rootDigest + : bytes20(uint160(rootDigest) + 1) + ); + } + + bytes memory wormholeMerkleVaa = generateVaa( + 0, + isNotMatch(forgeItem, "whSourceChain") + ? SOURCE_EMITTER_CHAIN_ID + : SOURCE_EMITTER_CHAIN_ID + 1, + isNotMatch(forgeItem, "whSourceAddress") + ? SOURCE_EMITTER_ADDRESS + : bytes32( + 0x71f8dcb863d176e2c420ad6610cf687359612b6fb392e0642b0ca6b1f186aa00 + ), + 0, + wormholePayload, + 1 // num signers + ); + + updateData = new bytes[](1); + + updateData[0] = abi.encodePacked( + isNotMatch(forgeItem, "headerMagic") + ? uint32(0x504e4155) + : uint32(0x504e4150), // PythAccumulator.ACCUMULATOR_MAGIC + isNotMatch(forgeItem, "headerMajorVersion") ? uint8(1) : uint8(2), // major version + uint8(0), // minor version + uint8(0), // trailing header size + uint8(PythAccumulator.UpdateType.WormholeMerkle), + uint16(wormholeMerkleVaa.length), + wormholeMerkleVaa, + uint8(priceFeedMessages.length) + ); + + for (uint i = 0; i < priceFeedMessages.length; i++) { + updateData[0] = abi.encodePacked( + updateData[0], + uint16(encodedPriceFeedMessages[i].length), + encodedPriceFeedMessages[i], + isNotMatch(forgeItem, "proofItem") ? proofs[i] : proofs[0] + ); + } + + updateFee = pyth.getUpdateFee(updateData); + } + + function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAAPayloadMagic() + public + { + // In this test the Wormhole accumulator magic is wrong and the update gets reverted. + ( + bytes[] memory updateData, + uint updateFee + ) = createAndForgeWormholeMerkleUpdateData("whMagic"); + + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + pyth.updatePriceFeeds{value: updateFee}(updateData); + } + + function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAAPayloadUpdateType() + public + { + // In this test the Wormhole accumulator magic is wrong and the update gets + // reverted. + + ( + bytes[] memory updateData, + uint updateFee + ) = createAndForgeWormholeMerkleUpdateData("whUpdateType"); + vm.expectRevert(); // Reason: Conversion into non-existent enum type. However it + // was not possible to check the revert reason in the test. + pyth.updatePriceFeeds{value: updateFee}(updateData); + } + + function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAASource() + public + { + // In this test the Wormhole message source is wrong. + ( + bytes[] memory updateData, + uint updateFee + ) = createAndForgeWormholeMerkleUpdateData("whSourceAddress"); + vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector); + pyth.updatePriceFeeds{value: updateFee}(updateData); + (updateData, updateFee) = createAndForgeWormholeMerkleUpdateData( + "whSourceChain" + ); + vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector); + pyth.updatePriceFeeds{value: updateFee}(updateData); + } + + function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongRootDigest() + public + { + // In this test the Wormhole merkle proof digest is wrong + ( + bytes[] memory updateData, + uint updateFee + ) = createAndForgeWormholeMerkleUpdateData("rootDigest"); + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + pyth.updatePriceFeeds{value: updateFee}(updateData); + } + + function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongProofItem() + public + { + // In this test all Wormhole merkle proof items are the first item proof + ( + bytes[] memory updateData, + uint updateFee + ) = createAndForgeWormholeMerkleUpdateData("proofItem"); + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + pyth.updatePriceFeeds{value: updateFee}(updateData); + } + + function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongHeader() + public + { + // In this test the message headers are wrong + ( + bytes[] memory updateData, + uint updateFee + ) = createAndForgeWormholeMerkleUpdateData("headerMagic"); + vm.expectRevert(); // The revert reason is not deterministic because when it doesn't match it goes through + // the old approach. + pyth.updatePriceFeeds{value: updateFee}(updateData); + + (updateData, updateFee) = createAndForgeWormholeMerkleUpdateData( + "headerMajorVersion" + ); + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + pyth.updatePriceFeeds{value: updateFee}(updateData); + } + + function testUpdatePriceFeedWithWormholeMerkleRevertsIfUpdateFeeIsNotPaid() + public + { + uint numPriceFeeds = (getRandUint() % 10) + 1; + PriceFeedMessage[] + memory priceFeedMessages = generateRandomPriceFeedMessage( + numPriceFeeds + ); + (bytes[] memory updateData, ) = createWormholeMerkleUpdateData( + priceFeedMessages + ); + + vm.expectRevert(PythErrors.InsufficientFee.selector); + pyth.updatePriceFeeds{value: 0}(updateData); + } +} diff --git a/target_chains/ethereum/contracts/forge-test/Pyth.t.sol b/target_chains/ethereum/contracts/forge-test/Pyth.t.sol index f1fe76a0..0fcaf7bb 100644 --- a/target_chains/ethereum/contracts/forge-test/Pyth.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pyth.t.sol @@ -77,7 +77,7 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils { batchAttestations[j - i] = attestations[j]; } - updateData[i / batchSize] = generatePriceFeedUpdateVAA( + updateData[i / batchSize] = generateWhBatchUpdate( batchAttestations, 0, 1 diff --git a/target_chains/ethereum/contracts/forge-test/VerificationExperiments.t.sol b/target_chains/ethereum/contracts/forge-test/VerificationExperiments.t.sol index e4caa4d2..9ccba5fb 100644 --- a/target_chains/ethereum/contracts/forge-test/VerificationExperiments.t.sol +++ b/target_chains/ethereum/contracts/forge-test/VerificationExperiments.t.sol @@ -186,7 +186,7 @@ contract VerificationExperiments is function generateWormholeUpdateDataAndFee( PythStructs.Price[] memory prices ) internal returns (bytes[] memory updateData, uint updateFee) { - bytes memory vaa = generatePriceFeedUpdateVAA( + bytes memory vaa = generateWhBatchUpdate( pricesToPriceAttestations(priceIds, prices), sequence, NUM_GUARDIAN_SIGNERS @@ -310,7 +310,7 @@ contract VerificationExperiments is return ThresholdUpdate(signature, data); } - function testWormholeBatchUpdate() public { + function testWhBatchUpdate() public { pyth.updatePriceFeeds{value: freshPricesUpdateFee}( freshPricesUpdateData ); diff --git a/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol b/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol index d6436b88..df32f54f 100644 --- a/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol +++ b/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol @@ -4,6 +4,10 @@ pragma solidity ^0.8.0; import "../../contracts/pyth/PythUpgradable.sol"; import "../../contracts/pyth/PythInternalStructs.sol"; +import "../../contracts/pyth/PythAccumulator.sol"; + +import "../../contracts/libraries/MerkleTree.sol"; + import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; import "@pythnetwork/pyth-sdk-solidity/IPythEvents.sol"; @@ -74,9 +78,89 @@ abstract contract PythTestUtils is Test, WormholeTestUtils { uint64 prevConf; } + struct PriceFeedMessage { + bytes32 priceId; + int64 price; + uint64 conf; + int32 expo; + uint64 publishTime; + uint64 prevPublishTime; + int64 emaPrice; + uint64 emaConf; + } + + function encodePriceFeedMessages( + PriceFeedMessage[] memory priceFeedMessages + ) internal pure returns (bytes[] memory encodedPriceFeedMessages) { + encodedPriceFeedMessages = new bytes[](priceFeedMessages.length); + + for (uint i = 0; i < priceFeedMessages.length; i++) { + encodedPriceFeedMessages[i] = abi.encodePacked( + uint8(PythAccumulator.MessageType.PriceFeed), + priceFeedMessages[i].priceId, + priceFeedMessages[i].price, + priceFeedMessages[i].conf, + priceFeedMessages[i].expo, + priceFeedMessages[i].publishTime, + priceFeedMessages[i].prevPublishTime, + priceFeedMessages[i].emaPrice, + priceFeedMessages[i].emaConf + ); + } + } + + function generateWhMerkleUpdate( + PriceFeedMessage[] memory priceFeedMessages, + uint8 depth, + uint8 numSigners + ) internal returns (bytes memory whMerkleUpdateData) { + bytes[] memory encodedPriceFeedMessages = encodePriceFeedMessages( + priceFeedMessages + ); + + (bytes20 rootDigest, bytes[] memory proofs) = MerkleTree + .constructProofs(encodedPriceFeedMessages, depth); + + bytes memory wormholePayload = abi.encodePacked( + uint32(0x41555756), // PythAccumulator.ACCUMULATOR_WORMHOLE_MAGIC + uint8(PythAccumulator.UpdateType.WormholeMerkle), + uint32(0), // Storage index, not used in target networks + rootDigest + ); + + bytes memory wormholeMerkleVaa = generateVaa( + 0, + SOURCE_EMITTER_CHAIN_ID, + SOURCE_EMITTER_ADDRESS, + 0, + wormholePayload, + numSigners + ); + + whMerkleUpdateData = abi.encodePacked( + uint32(0x504e4155), // PythAccumulator.ACCUMULATOR_MAGIC + uint8(1), // major version + uint8(0), // minor version + uint8(0), // trailing header size + uint8(PythAccumulator.UpdateType.WormholeMerkle), + uint16(wormholeMerkleVaa.length), + wormholeMerkleVaa, + uint8(priceFeedMessages.length) + ); + + for (uint i = 0; i < priceFeedMessages.length; i++) { + whMerkleUpdateData = abi.encodePacked( + whMerkleUpdateData, + uint16(encodedPriceFeedMessages[i].length), + encodedPriceFeedMessages[i], + proofs[i] + ); + } + } + // Generates byte-encoded payload for the given price attestations. You can use this to mock wormhole // call using `vm.mockCall` and return a VM struct with this payload. - // You can use generatePriceFeedUpdateVAA to generate a VAA for a price update. + // You can use generatePriceFeedUpdate to generate a VAA for a price update. function generatePriceFeedUpdatePayload( PriceAttestation[] memory attestations ) public pure returns (bytes memory payload) { @@ -124,7 +208,7 @@ abstract contract PythTestUtils is Test, WormholeTestUtils { // Generates a VAA for the given attestations. // This method calls generatePriceFeedUpdatePayload and then creates a VAA with it. // The VAAs generated from this method use block timestamp as their timestamp. - function generatePriceFeedUpdateVAA( + function generateWhBatchUpdate( PriceAttestation[] memory attestations, uint64 sequence, uint8 numSigners @@ -170,6 +254,24 @@ abstract contract PythTestUtils is Test, WormholeTestUtils { attestations[i].prevConf = prices[i].conf; } } + + function pricesToPriceFeedMessages( + bytes32[] memory priceIds, + PythStructs.Price[] memory prices + ) public returns (PriceFeedMessage[] memory priceFeedMessages) { + assertGe(priceIds.length, prices.length); + priceFeedMessages = new PriceFeedMessage[](prices.length); + + for (uint i = 0; i < prices.length; ++i) { + priceFeedMessages[i].priceId = priceIds[i]; + priceFeedMessages[i].price = prices[i].price; + priceFeedMessages[i].conf = prices[i].conf; + priceFeedMessages[i].expo = prices[i].expo; + priceFeedMessages[i].publishTime = uint64(prices[i].publishTime); + priceFeedMessages[i].emaPrice = prices[i].price; + priceFeedMessages[i].emaConf = prices[i].conf; + } + } } contract PythTestUtilsTest is @@ -178,7 +280,7 @@ contract PythTestUtilsTest is PythTestUtils, IPythEvents { - function testGeneratePriceFeedUpdateVAAWorks() public { + function testGenerateWhBatchUpdateWorks() public { IPyth pyth = IPyth( setUpPyth( setUpWormhole( @@ -200,7 +302,7 @@ contract PythTestUtilsTest is 1 // Publish time ); - bytes memory vaa = generatePriceFeedUpdateVAA( + bytes memory vaa = generateWhBatchUpdate( pricesToPriceAttestations(priceIds, prices), 1, // Sequence 1 // No. Signers @@ -211,7 +313,7 @@ contract PythTestUtilsTest is uint updateFee = pyth.getUpdateFee(updateData); - vm.expectEmit(true, true, false, true); + vm.expectEmit(true, false, false, true); emit PriceFeedUpdate(priceIds[0], 1, 100, 10); pyth.updatePriceFeeds{value: updateFee}(updateData); diff --git a/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol b/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol index b6524466..eec2894d 100644 --- a/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol +++ b/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol @@ -38,4 +38,8 @@ contract RandTestUtils is Test { function getRandInt32() internal returns (int32) { return int32(getRandUint32()); } + + function getRandUint8() internal returns (uint8) { + return uint8(getRandUint()); + } } diff --git a/target_chains/ethereum/contracts/package.json b/target_chains/ethereum/contracts/package.json index 2796642b..3fb40d34 100644 --- a/target_chains/ethereum/contracts/package.json +++ b/target_chains/ethereum/contracts/package.json @@ -1,6 +1,6 @@ { "name": "@pythnetwork/pyth-evm-contract", - "version": "1.2.0", + "version": "1.3.0-alpha", "description": "", "private": "true", "devDependencies": {