[eth] Add Pyth Accumulator (#776)

This PR adds the support WormholeMerkle accumulator message to the ethereum contract while still supporting the old message format. The code is not optimized yet and with more optimizations we can achieve a better gas usage. Currently based on the gas benchmark below it has a 18% improvement with a single price feed. Although the cost of updating 5 feeds in the same batch is higher than the current approach but in reality the chances that all 5 feeds be in the same batch is very low.
This commit is contained in:
Ali Behjati 2023-05-03 16:08:53 +02:00 committed by GitHub
parent e7b72bf5c3
commit f94dceb1bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1155 additions and 47 deletions

2
package-lock.json generated
View File

@ -49691,7 +49691,7 @@
}, },
"target_chains/ethereum/contracts": { "target_chains/ethereum/contracts": {
"name": "@pythnetwork/pyth-evm-contract", "name": "@pythnetwork/pyth-evm-contract",
"version": "1.2.0", "version": "1.3.0-alpha",
"license": "ISC", "license": "ISC",
"dependencies": { "dependencies": {
"@certusone/wormhole-sdk": "^0.9.9", "@certusone/wormhole-sdk": "^0.9.9",

View File

@ -0,0 +1,145 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "./external/UnsafeBytesLib.sol";
/**
* @dev This library provides methods to construct and verify Merkle Tree proofs efficiently.
*
*/
library MerkleTree {
uint8 constant MERKLE_LEAF_PREFIX = 0;
uint8 constant MERKLE_NODE_PREFIX = 1;
uint8 constant MERKLE_EMPTY_LEAF_PREFIX = 2;
function hash(bytes memory input) internal pure returns (bytes20) {
return bytes20(keccak256(input));
}
function emptyLeafHash() internal pure returns (bytes20) {
return hash(abi.encodePacked(MERKLE_EMPTY_LEAF_PREFIX));
}
function leafHash(bytes memory data) internal pure returns (bytes20) {
return hash(abi.encodePacked(MERKLE_LEAF_PREFIX, data));
}
function nodeHash(
bytes20 childA,
bytes20 childB
) internal pure returns (bytes20) {
if (childA > childB) {
(childA, childB) = (childB, childA);
}
return hash(abi.encodePacked(MERKLE_NODE_PREFIX, childA, childB));
}
/// @notice Verify Merkle Tree proof for given leaf data.
/// @dev To optimize gas usage, this method doesn't take the proof as a bytes array
/// but rather takes the encoded proof and the offset of the proof in the
/// encoded proof array possibly containing multiple proofs. Also, the method
/// does not perform any check on the boundry of the `encodedProof` and the
/// `proofOffset` parameters. It is the caller's responsibility to ensure
/// that the `encodedProof` is long enough to contain the proof and the
/// `proofOffset` is not out of bound.
function isProofValid(
bytes memory encodedProof,
uint proofOffset,
bytes20 root,
bytes memory leafData
) internal pure returns (bool valid, uint endOffset) {
unchecked {
bytes20 currentDigest = MerkleTree.leafHash(leafData);
uint8 proofSize = UnsafeBytesLib.toUint8(encodedProof, proofOffset);
proofOffset += 1;
for (uint i = 0; i < proofSize; i++) {
bytes20 siblingDigest = bytes20(
UnsafeBytesLib.toAddress(encodedProof, proofOffset)
);
proofOffset += 20;
currentDigest = MerkleTree.nodeHash(
currentDigest,
siblingDigest
);
}
valid = currentDigest == root;
endOffset = proofOffset;
}
}
/// @notice Construct Merkle Tree proofs for given list of messages.
/// @dev This function is only used for testing purposes and is not efficient
/// for production use-cases.
///
/// This method creates a merkle tree with leaf size of (2^depth) with the
/// messages as leafs (in the same given order) and returns the root digest
/// and the proofs for each message. If the number of messages is not a power
/// of 2, the tree is padded with empty messages.
function constructProofs(
bytes[] memory messages,
uint8 depth
) internal pure returns (bytes20 root, bytes[] memory proofs) {
require((1 << depth) >= messages.length, "depth too small");
bytes20[] memory tree = new bytes20[]((1 << (depth + 1)));
// The tree is structured as follows:
// 1
// 2 3
// 4 5 6 7
// ...
// In this structure the parent of node x is x//2 and the children
// of node x are x*2 and x*2 + 1. Also, the sibling of the node x
// is x^1. The root is at index 1 and index 0 is not used.
// Filling the leaf hashes
bytes20 cachedEmptyLeafHash = emptyLeafHash();
for (uint i = 0; i < (1 << depth); i++) {
if (i < messages.length) {
tree[(1 << depth) + i] = leafHash(messages[i]);
} else {
tree[(1 << depth) + i] = cachedEmptyLeafHash;
}
}
// Filling the node hashes from bottom to top
for (uint k = depth; k > 0; k--) {
uint level = k - 1;
uint levelNumNodes = (1 << level);
for (uint i = 0; i < levelNumNodes; i++) {
uint id = (1 << level) + i;
tree[id] = nodeHash(tree[id * 2], tree[id * 2 + 1]);
}
}
root = tree[1];
proofs = new bytes[](messages.length);
for (uint i = 0; i < messages.length; i++) {
// depth is the number of sibling nodes in the path from the leaf to the root
proofs[i] = abi.encodePacked(depth);
uint idx = (1 << depth) + i;
// This loop iterates through the leaf and its parents
// and keeps adding the sibling of the current node to the proof.
while (idx > 1) {
proofs[i] = abi.encodePacked(
proofs[i],
tree[idx ^ 1] // Sibling of this node
);
// Jump to parent
idx /= 2;
}
}
}
}

View File

@ -1,4 +1,3 @@
// contracts/Bridge.sol
// SPDX-License-Identifier: Apache 2 // SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0; pragma solidity ^0.8.0;
@ -8,11 +7,17 @@ import "@pythnetwork/pyth-sdk-solidity/AbstractPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol"; import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
import "./PythAccumulator.sol";
import "./PythGetters.sol"; import "./PythGetters.sol";
import "./PythSetters.sol"; import "./PythSetters.sol";
import "./PythInternalStructs.sol"; import "./PythInternalStructs.sol";
abstract contract Pyth is PythGetters, PythSetters, AbstractPyth { abstract contract Pyth is
PythGetters,
PythSetters,
AbstractPyth,
PythAccumulator
{
function _initialize( function _initialize(
address wormhole, address wormhole,
uint16[] calldata dataSourceEmitterChainIds, uint16[] calldata dataSourceEmitterChainIds,
@ -66,11 +71,19 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
function updatePriceFeeds( function updatePriceFeeds(
bytes[] calldata updateData bytes[] calldata updateData
) public payable override { ) public payable override {
// TODO: Is this fee model still good for accumulator?
uint requiredFee = getUpdateFee(updateData); uint requiredFee = getUpdateFee(updateData);
if (msg.value < requiredFee) revert PythErrors.InsufficientFee(); if (msg.value < requiredFee) revert PythErrors.InsufficientFee();
for (uint i = 0; i < updateData.length; ) { for (uint i = 0; i < updateData.length; ) {
updatePriceBatchFromVm(updateData[i]); if (
updateData[i].length > 4 &&
UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC
) {
updatePricesUsingAccumulator(updateData[i]);
} else {
updatePriceBatchFromVm(updateData[i]);
}
unchecked { unchecked {
i++; i++;
@ -536,6 +549,6 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
} }
function version() public pure returns (string memory) { function version() public pure returns (string memory) {
return "1.2.0"; return "1.3.0-alpha";
} }
} }

View File

@ -0,0 +1,319 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "../libraries/external/UnsafeBytesLib.sol";
import "@pythnetwork/pyth-sdk-solidity/AbstractPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
import "./PythGetters.sol";
import "./PythSetters.sol";
import "./PythInternalStructs.sol";
import "../libraries/MerkleTree.sol";
abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
uint32 constant ACCUMULATOR_MAGIC = 0x504e4155; // Stands for PNAU (Pyth Network Accumulator Update)
uint32 constant ACCUMULATOR_WORMHOLE_MAGIC = 0x41555756; // Stands for AUWV (Accumulator Update Wormhole Verficiation)
uint8 constant MINIMUM_ALLOWED_MINOR_VERSION = 0;
uint8 constant MAJOR_VERSION = 1;
enum UpdateType {
WormholeMerkle
}
enum MessageType {
PriceFeed
}
// This method is also used by batch attestation but moved here
// as the batch attestation will deprecate soon.
function parseAndVerifyPythVM(
bytes memory encodedVm
) internal view returns (IWormhole.VM memory vm) {
{
bool valid;
(vm, valid, ) = wormhole().parseAndVerifyVM(encodedVm);
if (!valid) revert PythErrors.InvalidWormholeVaa();
}
if (!isValidDataSource(vm.emitterChainId, vm.emitterAddress))
revert PythErrors.InvalidUpdateDataSource();
}
function updatePricesUsingAccumulator(
bytes calldata accumulatorUpdate
) internal {
unchecked {
uint offset = 0;
{
uint32 magic = UnsafeBytesLib.toUint32(
accumulatorUpdate,
offset
);
offset += 4;
if (magic != ACCUMULATOR_MAGIC)
revert PythErrors.InvalidUpdateData();
uint8 majorVersion = UnsafeBytesLib.toUint8(
accumulatorUpdate,
offset
);
offset += 1;
if (majorVersion != MAJOR_VERSION)
revert PythErrors.InvalidUpdateData();
uint8 minorVersion = UnsafeBytesLib.toUint8(
accumulatorUpdate,
offset
);
offset += 1;
// Minor versions are forward compatible, so we only check
// that the minor version is not less than the minimum allowed
if (minorVersion < MINIMUM_ALLOWED_MINOR_VERSION)
revert PythErrors.InvalidUpdateData();
// This field ensure that we can add headers in the future
// without breaking the contract (future compatibility)
uint8 trailingHeaderSize = UnsafeBytesLib.toUint8(
accumulatorUpdate,
offset
);
offset += 1;
// We use another offset for the trailing header and in the end add the
// offset by trailingHeaderSize to skip the future headers.
//
// An example would be like this:
// uint trailingHeaderOffset = offset
// uint x = UnsafeBytesLib.ToUint8(accumulatorUpdate, trailingHeaderOffset)
// trailingHeaderOffset += 1
offset += trailingHeaderSize;
}
UpdateType updateType = UpdateType(
UnsafeBytesLib.toUint8(accumulatorUpdate, offset)
);
offset += 1;
if (accumulatorUpdate.length < offset)
revert PythErrors.InvalidUpdateData();
if (updateType == UpdateType.WormholeMerkle) {
updatePricesUsingWormholeMerkle(
UnsafeBytesLib.slice(
accumulatorUpdate,
offset,
accumulatorUpdate.length - offset
)
);
} else {
revert PythErrors.InvalidUpdateData();
}
}
}
function updatePricesUsingWormholeMerkle(bytes memory encoded) private {
unchecked {
uint offset = 0;
uint16 whProofSize = UnsafeBytesLib.toUint16(encoded, offset);
offset += 2;
bytes20 digest;
{
IWormhole.VM memory vm = parseAndVerifyPythVM(
UnsafeBytesLib.slice(encoded, offset, whProofSize)
);
offset += whProofSize;
// TODO: Do we need to emit an update for accumulator update? If so what should we emit?
// emit AccumulatorUpdate(vm.chainId, vm.sequence);
bytes memory encodedPayload = vm.payload;
uint payloadoffset = 0;
{
uint32 magic = UnsafeBytesLib.toUint32(
encodedPayload,
payloadoffset
);
payloadoffset += 4;
if (magic != ACCUMULATOR_WORMHOLE_MAGIC)
revert PythErrors.InvalidUpdateData();
UpdateType updateType = UpdateType(
UnsafeBytesLib.toUint8(encodedPayload, payloadoffset)
);
payloadoffset += 1;
if (updateType != UpdateType.WormholeMerkle)
revert PythErrors.InvalidUpdateData();
// This field is not used
// uint32 storageIndex = UnsafeBytesLib.toUint32(encodedPayload, payloadoffset);
payloadoffset += 4;
digest = bytes20(
UnsafeBytesLib.toAddress(encodedPayload, payloadoffset)
);
payloadoffset += 20;
// We don't check equality to enable future compatibility.
if (payloadoffset > encodedPayload.length)
revert PythErrors.InvalidUpdateData();
}
}
uint8 numUpdates = UnsafeBytesLib.toUint8(encoded, offset);
offset += 1;
for (uint i = 0; i < numUpdates; i++) {
offset = verifyAndUpdatePriceFeedFromMerkleProof(
digest,
encoded,
offset
);
}
if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
}
}
function verifyAndUpdatePriceFeedFromMerkleProof(
bytes20 digest,
bytes memory encoded,
uint offset
) private returns (uint endOffset) {
unchecked {
uint16 messageSize = UnsafeBytesLib.toUint16(encoded, offset);
offset += 2;
bytes memory encodedMessage = UnsafeBytesLib.slice(
encoded,
offset,
messageSize
);
offset += messageSize;
bool valid;
(valid, offset) = MerkleTree.isProofValid(
encoded,
offset,
digest,
encodedMessage
);
if (!valid) {
revert PythErrors.InvalidUpdateData();
}
parseAndProcessMessage(encodedMessage);
return offset;
}
}
function parsePriceFeedMessage(
bytes memory encodedPriceFeed
)
private
pure
returns (
PythInternalStructs.PriceInfo memory priceInfo,
bytes32 priceId
)
{
unchecked {
uint offset = 0;
priceId = UnsafeBytesLib.toBytes32(encodedPriceFeed, offset);
offset += 32;
priceInfo.price = int64(
UnsafeBytesLib.toUint64(encodedPriceFeed, offset)
);
offset += 8;
priceInfo.conf = UnsafeBytesLib.toUint64(encodedPriceFeed, offset);
offset += 8;
priceInfo.expo = int32(
UnsafeBytesLib.toUint32(encodedPriceFeed, offset)
);
offset += 4;
// Publish time is i64 in some environments due to the standard in that
// environment. This would not cause any problem because since the signed
// integer is represented in two's complement, the value would be the same
// in both cases (for a million year at least)
priceInfo.publishTime = UnsafeBytesLib.toUint64(
encodedPriceFeed,
offset
);
offset += 8;
// We do not store this field because it is not used on the latest feed queries.
// uint64 prevPublishTime = UnsafeBytesLib.toUint64(encodedPriceFeed, offset);
offset += 8;
priceInfo.emaPrice = int64(
UnsafeBytesLib.toUint64(encodedPriceFeed, offset)
);
offset += 8;
priceInfo.emaConf = UnsafeBytesLib.toUint64(
encodedPriceFeed,
offset
);
offset += 8;
// We don't check equality to enable future compatibility.
if (offset > encodedPriceFeed.length)
revert PythErrors.InvalidUpdateData();
}
}
function parseAndProcessMessage(bytes memory encodedMessage) private {
unchecked {
MessageType messageType = MessageType(
UnsafeBytesLib.toUint8(encodedMessage, 0)
);
if (messageType == MessageType.PriceFeed) {
(
PythInternalStructs.PriceInfo memory info,
bytes32 priceId
) = parsePriceFeedMessage(
UnsafeBytesLib.slice(
encodedMessage,
1,
encodedMessage.length - 1
)
);
uint64 latestPublishTime = latestPriceInfoPublishTime(priceId);
if (info.publishTime > latestPublishTime) {
setLatestPriceInfo(priceId, info);
emit PriceFeedUpdate(
priceId,
info.publishTime,
info.price,
info.conf
);
}
} else {
revert PythErrors.InvalidUpdateData();
}
}
}
}

View File

@ -22,22 +22,30 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
// We use 5 prices to form a batch of 5 prices, close to our mainnet transactions. // We use 5 prices to form a batch of 5 prices, close to our mainnet transactions.
uint8 constant NUM_PRICES = 5; uint8 constant NUM_PRICES = 5;
// We will have less than 512 price for a foreseeable future.
uint8 constant MERKLE_TREE_DEPTH = 9;
IPyth public pyth; IPyth public pyth;
bytes32[] priceIds; bytes32[] priceIds;
// Cached prices are populated in the setUp // Cached prices are populated in the setUp
PythStructs.Price[] cachedPrices; PythStructs.Price[] cachedPrices;
bytes[] cachedPricesUpdateData; bytes[] cachedPricesWhBatchUpdateData;
uint cachedPricesUpdateFee; uint cachedPricesWhBatchUpdateFee;
uint64[] cachedPricesPublishTimes; uint64[] cachedPricesPublishTimes;
bytes[][] cachedPricesWhMerkleUpdateData; // i th element contains the update data for the first i prices
uint[] cachedPricesWhMerkleUpdateFee; // i th element contains the update fee for the first i prices
// Fresh prices are different prices that can be used // Fresh prices are different prices that can be used
// as a fresh price to update the prices // as a fresh price to update the prices
PythStructs.Price[] freshPrices; PythStructs.Price[] freshPrices;
bytes[] freshPricesUpdateData; bytes[] freshPricesWhBatchUpdateData;
uint freshPricesUpdateFee; uint freshPricesWhBatchUpdateFee;
uint64[] freshPricesPublishTimes; uint64[] freshPricesPublishTimes;
bytes[][] freshPricesWhMerkleUpdateData; // i th element contains the update data for the first i prices
uint[] freshPricesWhMerkleUpdateFee; // i th element contains the update fee for the first i prices
uint64 sequence; uint64 sequence;
uint randSeed; uint randSeed;
@ -76,21 +84,37 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
) )
); );
freshPricesPublishTimes.push(publishTime); freshPricesPublishTimes.push(publishTime);
// Generate Wormhole Merkle update data and fee for the first i th prices
(
bytes[] memory updateData,
uint updateFee
) = generateWhMerkleUpdateDataAndFee(cachedPrices);
cachedPricesWhMerkleUpdateData.push(updateData);
cachedPricesWhMerkleUpdateFee.push(updateFee);
(updateData, updateFee) = generateWhMerkleUpdateDataAndFee(
freshPrices
);
freshPricesWhMerkleUpdateData.push(updateData);
freshPricesWhMerkleUpdateFee.push(updateFee);
} }
// Populate the contract with the initial prices // Populate the contract with the initial prices
( (
cachedPricesUpdateData, cachedPricesWhBatchUpdateData,
cachedPricesUpdateFee cachedPricesWhBatchUpdateFee
) = generateUpdateDataAndFee(cachedPrices); ) = generateWhBatchUpdateDataAndFee(cachedPrices);
pyth.updatePriceFeeds{value: cachedPricesUpdateFee}( pyth.updatePriceFeeds{value: cachedPricesWhBatchUpdateFee}(
cachedPricesUpdateData cachedPricesWhBatchUpdateData
); );
( (
freshPricesUpdateData, freshPricesWhBatchUpdateData,
freshPricesUpdateFee freshPricesWhBatchUpdateFee
) = generateUpdateDataAndFee(freshPrices); ) = generateWhBatchUpdateDataAndFee(freshPrices);
} }
function getRand() internal returns (uint val) { function getRand() internal returns (uint val) {
@ -98,10 +122,10 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
val = uint(keccak256(abi.encode(randSeed))); val = uint(keccak256(abi.encode(randSeed)));
} }
function generateUpdateDataAndFee( function generateWhBatchUpdateDataAndFee(
PythStructs.Price[] memory prices PythStructs.Price[] memory prices
) internal returns (bytes[] memory updateData, uint updateFee) { ) internal returns (bytes[] memory updateData, uint updateFee) {
bytes memory vaa = generatePriceFeedUpdateVAA( bytes memory vaa = generateWhBatchUpdate(
pricesToPriceAttestations(priceIds, prices), pricesToPriceAttestations(priceIds, prices),
sequence, sequence,
NUM_GUARDIAN_SIGNERS NUM_GUARDIAN_SIGNERS
@ -115,35 +139,109 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
updateFee = pyth.getUpdateFee(updateData); updateFee = pyth.getUpdateFee(updateData);
} }
function testBenchmarkUpdatePriceFeedsFresh() public { function generateWhMerkleUpdateDataAndFee(
pyth.updatePriceFeeds{value: freshPricesUpdateFee}( PythStructs.Price[] memory prices
freshPricesUpdateData ) internal returns (bytes[] memory updateData, uint updateFee) {
updateData = new bytes[](1);
updateData[0] = generateWhMerkleUpdate(
pricesToPriceFeedMessages(priceIds, prices),
MERKLE_TREE_DEPTH,
NUM_GUARDIAN_SIGNERS
);
updateFee = pyth.getUpdateFee(updateData);
}
function testBenchmarkUpdatePriceFeedsWhBatchFresh() public {
pyth.updatePriceFeeds{value: freshPricesWhBatchUpdateFee}(
freshPricesWhBatchUpdateData
); );
} }
function testBenchmarkUpdatePriceFeedsNotFresh() public { function testBenchmarkUpdatePriceFeedsWhBatchNotFresh() public {
pyth.updatePriceFeeds{value: cachedPricesUpdateFee}( pyth.updatePriceFeeds{value: cachedPricesWhBatchUpdateFee}(
cachedPricesUpdateData cachedPricesWhBatchUpdateData
); );
} }
function testBenchmarkUpdatePriceFeedsIfNecessaryFresh() public { function testBenchmarkUpdatePriceFeedsWhMerkle1FeedFresh() public {
pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[0]}(
freshPricesWhMerkleUpdateData[0]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle2FeedsFresh() public {
pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[1]}(
freshPricesWhMerkleUpdateData[1]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle3FeedsFresh() public {
pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[2]}(
freshPricesWhMerkleUpdateData[2]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle4FeedsFresh() public {
pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[3]}(
freshPricesWhMerkleUpdateData[3]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle5FeedsFresh() public {
pyth.updatePriceFeeds{value: freshPricesWhMerkleUpdateFee[4]}(
freshPricesWhMerkleUpdateData[4]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle1FeedNotFresh() public {
pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[0]}(
cachedPricesWhMerkleUpdateData[0]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle2FeedsNotFresh() public {
pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[1]}(
cachedPricesWhMerkleUpdateData[1]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle3FeedsNotFresh() public {
pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[2]}(
cachedPricesWhMerkleUpdateData[2]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle4FeedsNotFresh() public {
pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[3]}(
cachedPricesWhMerkleUpdateData[3]
);
}
function testBenchmarkUpdatePriceFeedsWhMerkle5FeedsNotFresh() public {
pyth.updatePriceFeeds{value: cachedPricesWhMerkleUpdateFee[4]}(
cachedPricesWhMerkleUpdateData[4]
);
}
function testBenchmarkUpdatePriceFeedsIfNecessaryWhBatchFresh() public {
// Since the prices have advanced, the publishTimes are newer than one in // Since the prices have advanced, the publishTimes are newer than one in
// the contract and hence, the call should succeed. // the contract and hence, the call should succeed.
pyth.updatePriceFeedsIfNecessary{value: freshPricesUpdateFee}( pyth.updatePriceFeedsIfNecessary{value: freshPricesWhBatchUpdateFee}(
freshPricesUpdateData, freshPricesWhBatchUpdateData,
priceIds, priceIds,
freshPricesPublishTimes freshPricesPublishTimes
); );
} }
function testBenchmarkUpdatePriceFeedsIfNecessaryNotFresh() public { function testBenchmarkUpdatePriceFeedsIfNecessaryWhBatchNotFresh() public {
// Since the price is not advanced, the publishTimes are the same as the // Since the price is not advanced, the publishTimes are the same as the
// ones in the contract. // ones in the contract.
vm.expectRevert(PythErrors.NoFreshUpdate.selector); vm.expectRevert(PythErrors.NoFreshUpdate.selector);
pyth.updatePriceFeedsIfNecessary{value: cachedPricesUpdateFee}( pyth.updatePriceFeedsIfNecessary{value: cachedPricesWhBatchUpdateFee}(
cachedPricesUpdateData, cachedPricesWhBatchUpdateData,
priceIds, priceIds,
cachedPricesPublishTimes cachedPricesPublishTimes
); );
@ -153,8 +251,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
bytes32[] memory ids = new bytes32[](1); bytes32[] memory ids = new bytes32[](1);
ids[0] = priceIds[0]; ids[0] = priceIds[0];
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}( pyth.parsePriceFeedUpdates{value: freshPricesWhBatchUpdateFee}(
freshPricesUpdateData, freshPricesWhBatchUpdateData,
ids, ids,
0, 0,
50 50
@ -166,8 +264,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
ids[0] = priceIds[0]; ids[0] = priceIds[0];
ids[1] = priceIds[1]; ids[1] = priceIds[1];
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}( pyth.parsePriceFeedUpdates{value: freshPricesWhBatchUpdateFee}(
freshPricesUpdateData, freshPricesWhBatchUpdateData,
ids, ids,
0, 0,
50 50
@ -181,8 +279,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
ids[0] = priceIds[0]; ids[0] = priceIds[0];
vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector); vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}( pyth.parsePriceFeedUpdates{value: freshPricesWhBatchUpdateFee}(
freshPricesUpdateData, freshPricesWhBatchUpdateData,
ids, ids,
50, 50,
100 100
@ -206,6 +304,6 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
} }
function testBenchmarkGetUpdateFee() public view { function testBenchmarkGetUpdateFee() public view {
pyth.getUpdateFee(freshPricesUpdateData); pyth.getUpdateFee(freshPricesWhBatchUpdateData);
} }
} }

View File

@ -0,0 +1,427 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "forge-std/Test.sol";
import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "./utils/WormholeTestUtils.t.sol";
import "./utils/PythTestUtils.t.sol";
import "./utils/RandTestUtils.t.sol";
import "../contracts/libraries/MerkleTree.sol";
contract PythWormholeMerkleAccumulatorTest is
Test,
WormholeTestUtils,
PythTestUtils,
RandTestUtils
{
IPyth public pyth;
function setUp() public {
pyth = IPyth(setUpPyth(setUpWormhole(1)));
}
function assertPriceFeedMessageStored(
PriceFeedMessage memory priceFeedMessage
) internal {
PythStructs.Price memory aggregatePrice = pyth.getPriceUnsafe(
priceFeedMessage.priceId
);
assertEq(aggregatePrice.price, priceFeedMessage.price);
assertEq(aggregatePrice.conf, priceFeedMessage.conf);
assertEq(aggregatePrice.expo, priceFeedMessage.expo);
assertEq(aggregatePrice.publishTime, priceFeedMessage.publishTime);
PythStructs.Price memory emaPrice = pyth.getEmaPriceUnsafe(
priceFeedMessage.priceId
);
assertEq(emaPrice.price, priceFeedMessage.emaPrice);
assertEq(emaPrice.conf, priceFeedMessage.emaConf);
assertEq(emaPrice.expo, priceFeedMessage.expo);
assertEq(emaPrice.publishTime, priceFeedMessage.publishTime);
}
function generateRandomPriceFeedMessage(
uint numPriceFeeds
) internal returns (PriceFeedMessage[] memory priceFeedMessages) {
priceFeedMessages = new PriceFeedMessage[](numPriceFeeds);
for (uint i = 0; i < numPriceFeeds; i++) {
priceFeedMessages[i] = PriceFeedMessage({
priceId: getRandBytes32(),
price: getRandInt64(),
conf: getRandUint64(),
expo: getRandInt32(),
publishTime: getRandUint64(),
prevPublishTime: getRandUint64(),
emaPrice: getRandInt64(),
emaConf: getRandUint64()
});
}
}
function createWormholeMerkleUpdateData(
PriceFeedMessage[] memory priceFeedMessages
) internal returns (bytes[] memory updateData, uint updateFee) {
updateData = new bytes[](1);
uint8 depth = 0;
while ((1 << depth) < priceFeedMessages.length) {
depth++;
}
depth += getRandUint8() % 3;
updateData[0] = generateWhMerkleUpdate(priceFeedMessages, depth, 1);
updateFee = pyth.getUpdateFee(updateData);
}
/// Testing update price feeds method using wormhole merkle update type.
function testUpdatePriceFeedWithWormholeMerkleWorks(uint seed) public {
setRandSeed(seed);
uint numPriceFeeds = (getRandUint() % 10) + 1;
PriceFeedMessage[]
memory priceFeedMessages = generateRandomPriceFeedMessage(
numPriceFeeds
);
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
pyth.updatePriceFeeds{value: updateFee}(updateData);
for (uint i = 0; i < numPriceFeeds; i++) {
assertPriceFeedMessageStored(priceFeedMessages[i]);
}
// Update the prices again with the same data should work
pyth.updatePriceFeeds{value: updateFee}(updateData);
for (uint i = 0; i < numPriceFeeds; i++) {
assertPriceFeedMessageStored(priceFeedMessages[i]);
}
// Update the prices again with updated data should update the prices
for (uint i = 0; i < numPriceFeeds; i++) {
priceFeedMessages[i].price = getRandInt64();
priceFeedMessages[i].conf = getRandUint64();
priceFeedMessages[i].expo = getRandInt32();
// Increase the publish time if it is not causing an overflow
if (priceFeedMessages[i].publishTime != type(uint64).max) {
priceFeedMessages[i].publishTime += 1;
}
priceFeedMessages[i].emaPrice = getRandInt64();
priceFeedMessages[i].emaConf = getRandUint64();
}
(updateData, updateFee) = createWormholeMerkleUpdateData(
priceFeedMessages
);
pyth.updatePriceFeeds{value: updateFee}(updateData);
for (uint i = 0; i < numPriceFeeds; i++) {
assertPriceFeedMessageStored(priceFeedMessages[i]);
}
}
function testUpdatePriceFeedWithWormholeMerkleWorksOnMultiUpdate() public {
PriceFeedMessage[]
memory priceFeedMessages1 = generateRandomPriceFeedMessage(2);
PriceFeedMessage[]
memory priceFeedMessages2 = generateRandomPriceFeedMessage(2);
// Make the 2nd message of the second update the same as the 1st message of the first update
priceFeedMessages2[1].priceId = priceFeedMessages1[0].priceId;
// Adjust the timestamps so the second timestamp is greater than the first
priceFeedMessages1[0].publishTime = 5;
priceFeedMessages2[1].publishTime = 10;
bytes[] memory updateData = new bytes[](2);
uint8 depth = 1; // 2 messages
uint8 numSigners = 1;
updateData[0] = generateWhMerkleUpdate(
priceFeedMessages1,
depth,
numSigners
);
updateData[1] = generateWhMerkleUpdate(
priceFeedMessages2,
depth,
numSigners
);
uint updateFee = pyth.getUpdateFee(updateData);
pyth.updatePriceFeeds{value: updateFee}(updateData);
assertPriceFeedMessageStored(priceFeedMessages1[1]);
assertPriceFeedMessageStored(priceFeedMessages2[0]);
assertPriceFeedMessageStored(priceFeedMessages2[1]);
}
function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateSingleCall()
public
{
PriceFeedMessage[]
memory priceFeedMessages1 = generateRandomPriceFeedMessage(1);
PriceFeedMessage[]
memory priceFeedMessages2 = generateRandomPriceFeedMessage(1);
// Make the price ids the same
priceFeedMessages2[0].priceId = priceFeedMessages1[0].priceId;
// Adjust the timestamps so the second timestamp is smaller than the first
// so it doesn't get stored.
priceFeedMessages1[0].publishTime = 10;
priceFeedMessages2[0].publishTime = 5;
bytes[] memory updateData = new bytes[](2);
uint8 depth = 0; // 1 messages
uint8 numSigners = 1;
updateData[0] = generateWhMerkleUpdate(
priceFeedMessages1,
depth,
numSigners
);
updateData[1] = generateWhMerkleUpdate(
priceFeedMessages2,
depth,
numSigners
);
uint updateFee = pyth.getUpdateFee(updateData);
pyth.updatePriceFeeds{value: updateFee}(updateData);
assertPriceFeedMessageStored(priceFeedMessages1[0]);
}
function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateMultiCall()
public
{
PriceFeedMessage[]
memory priceFeedMessages1 = generateRandomPriceFeedMessage(1);
PriceFeedMessage[]
memory priceFeedMessages2 = generateRandomPriceFeedMessage(1);
// Make the price ids the same
priceFeedMessages2[0].priceId = priceFeedMessages1[0].priceId;
// Adjust the timestamps so the second timestamp is smaller than the first
// so it doesn't get stored.
priceFeedMessages1[0].publishTime = 10;
priceFeedMessages2[0].publishTime = 5;
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages1);
pyth.updatePriceFeeds{value: updateFee}(updateData);
assertPriceFeedMessageStored(priceFeedMessages1[0]);
(updateData, updateFee) = createWormholeMerkleUpdateData(
priceFeedMessages2
);
pyth.updatePriceFeeds{value: updateFee}(updateData);
// Make sure that the old value is still stored
assertPriceFeedMessageStored(priceFeedMessages1[0]);
}
function isNotMatch(
bytes memory a,
bytes memory b
) public pure returns (bool) {
return keccak256(a) != keccak256(b);
}
/// @notice This method creates a forged invalid wormhole update data.
/// The caller should pass the forgeItem as string and if it matches the
/// expected value, that item will be forged to be invalid.
function createAndForgeWormholeMerkleUpdateData(
bytes memory forgeItem
) public returns (bytes[] memory updateData, uint updateFee) {
uint numPriceFeeds = 10;
PriceFeedMessage[]
memory priceFeedMessages = generateRandomPriceFeedMessage(
numPriceFeeds
);
bytes[] memory encodedPriceFeedMessages = encodePriceFeedMessages(
priceFeedMessages
);
(bytes20 rootDigest, bytes[] memory proofs) = MerkleTree
.constructProofs(encodedPriceFeedMessages, 4); // 4 is the depth of the tree (enough for 16 messages)
bytes memory wormholePayload;
unchecked {
wormholePayload = abi.encodePacked(
isNotMatch(forgeItem, "whMagic")
? uint32(0x41555756)
: uint32(0x41555750),
isNotMatch(forgeItem, "whUpdateType")
? uint8(PythAccumulator.UpdateType.WormholeMerkle)
: uint8(PythAccumulator.UpdateType.WormholeMerkle) + 1,
uint32(0), // Storage index, not used in target networks
isNotMatch(forgeItem, "rootDigest")
? rootDigest
: bytes20(uint160(rootDigest) + 1)
);
}
bytes memory wormholeMerkleVaa = generateVaa(
0,
isNotMatch(forgeItem, "whSourceChain")
? SOURCE_EMITTER_CHAIN_ID
: SOURCE_EMITTER_CHAIN_ID + 1,
isNotMatch(forgeItem, "whSourceAddress")
? SOURCE_EMITTER_ADDRESS
: bytes32(
0x71f8dcb863d176e2c420ad6610cf687359612b6fb392e0642b0ca6b1f186aa00
),
0,
wormholePayload,
1 // num signers
);
updateData = new bytes[](1);
updateData[0] = abi.encodePacked(
isNotMatch(forgeItem, "headerMagic")
? uint32(0x504e4155)
: uint32(0x504e4150), // PythAccumulator.ACCUMULATOR_MAGIC
isNotMatch(forgeItem, "headerMajorVersion") ? uint8(1) : uint8(2), // major version
uint8(0), // minor version
uint8(0), // trailing header size
uint8(PythAccumulator.UpdateType.WormholeMerkle),
uint16(wormholeMerkleVaa.length),
wormholeMerkleVaa,
uint8(priceFeedMessages.length)
);
for (uint i = 0; i < priceFeedMessages.length; i++) {
updateData[0] = abi.encodePacked(
updateData[0],
uint16(encodedPriceFeedMessages[i].length),
encodedPriceFeedMessages[i],
isNotMatch(forgeItem, "proofItem") ? proofs[i] : proofs[0]
);
}
updateFee = pyth.getUpdateFee(updateData);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAAPayloadMagic()
public
{
// In this test the Wormhole accumulator magic is wrong and the update gets reverted.
(
bytes[] memory updateData,
uint updateFee
) = createAndForgeWormholeMerkleUpdateData("whMagic");
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAAPayloadUpdateType()
public
{
// In this test the Wormhole accumulator magic is wrong and the update gets
// reverted.
(
bytes[] memory updateData,
uint updateFee
) = createAndForgeWormholeMerkleUpdateData("whUpdateType");
vm.expectRevert(); // Reason: Conversion into non-existent enum type. However it
// was not possible to check the revert reason in the test.
pyth.updatePriceFeeds{value: updateFee}(updateData);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAASource()
public
{
// In this test the Wormhole message source is wrong.
(
bytes[] memory updateData,
uint updateFee
) = createAndForgeWormholeMerkleUpdateData("whSourceAddress");
vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
(updateData, updateFee) = createAndForgeWormholeMerkleUpdateData(
"whSourceChain"
);
vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongRootDigest()
public
{
// In this test the Wormhole merkle proof digest is wrong
(
bytes[] memory updateData,
uint updateFee
) = createAndForgeWormholeMerkleUpdateData("rootDigest");
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongProofItem()
public
{
// In this test all Wormhole merkle proof items are the first item proof
(
bytes[] memory updateData,
uint updateFee
) = createAndForgeWormholeMerkleUpdateData("proofItem");
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongHeader()
public
{
// In this test the message headers are wrong
(
bytes[] memory updateData,
uint updateFee
) = createAndForgeWormholeMerkleUpdateData("headerMagic");
vm.expectRevert(); // The revert reason is not deterministic because when it doesn't match it goes through
// the old approach.
pyth.updatePriceFeeds{value: updateFee}(updateData);
(updateData, updateFee) = createAndForgeWormholeMerkleUpdateData(
"headerMajorVersion"
);
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsIfUpdateFeeIsNotPaid()
public
{
uint numPriceFeeds = (getRandUint() % 10) + 1;
PriceFeedMessage[]
memory priceFeedMessages = generateRandomPriceFeedMessage(
numPriceFeeds
);
(bytes[] memory updateData, ) = createWormholeMerkleUpdateData(
priceFeedMessages
);
vm.expectRevert(PythErrors.InsufficientFee.selector);
pyth.updatePriceFeeds{value: 0}(updateData);
}
}

View File

@ -77,7 +77,7 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils {
batchAttestations[j - i] = attestations[j]; batchAttestations[j - i] = attestations[j];
} }
updateData[i / batchSize] = generatePriceFeedUpdateVAA( updateData[i / batchSize] = generateWhBatchUpdate(
batchAttestations, batchAttestations,
0, 0,
1 1

View File

@ -186,7 +186,7 @@ contract VerificationExperiments is
function generateWormholeUpdateDataAndFee( function generateWormholeUpdateDataAndFee(
PythStructs.Price[] memory prices PythStructs.Price[] memory prices
) internal returns (bytes[] memory updateData, uint updateFee) { ) internal returns (bytes[] memory updateData, uint updateFee) {
bytes memory vaa = generatePriceFeedUpdateVAA( bytes memory vaa = generateWhBatchUpdate(
pricesToPriceAttestations(priceIds, prices), pricesToPriceAttestations(priceIds, prices),
sequence, sequence,
NUM_GUARDIAN_SIGNERS NUM_GUARDIAN_SIGNERS
@ -310,7 +310,7 @@ contract VerificationExperiments is
return ThresholdUpdate(signature, data); return ThresholdUpdate(signature, data);
} }
function testWormholeBatchUpdate() public { function testWhBatchUpdate() public {
pyth.updatePriceFeeds{value: freshPricesUpdateFee}( pyth.updatePriceFeeds{value: freshPricesUpdateFee}(
freshPricesUpdateData freshPricesUpdateData
); );

View File

@ -4,6 +4,10 @@ pragma solidity ^0.8.0;
import "../../contracts/pyth/PythUpgradable.sol"; import "../../contracts/pyth/PythUpgradable.sol";
import "../../contracts/pyth/PythInternalStructs.sol"; import "../../contracts/pyth/PythInternalStructs.sol";
import "../../contracts/pyth/PythAccumulator.sol";
import "../../contracts/libraries/MerkleTree.sol";
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "@pythnetwork/pyth-sdk-solidity/IPythEvents.sol"; import "@pythnetwork/pyth-sdk-solidity/IPythEvents.sol";
@ -74,9 +78,89 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
uint64 prevConf; uint64 prevConf;
} }
struct PriceFeedMessage {
bytes32 priceId;
int64 price;
uint64 conf;
int32 expo;
uint64 publishTime;
uint64 prevPublishTime;
int64 emaPrice;
uint64 emaConf;
}
function encodePriceFeedMessages(
PriceFeedMessage[] memory priceFeedMessages
) internal pure returns (bytes[] memory encodedPriceFeedMessages) {
encodedPriceFeedMessages = new bytes[](priceFeedMessages.length);
for (uint i = 0; i < priceFeedMessages.length; i++) {
encodedPriceFeedMessages[i] = abi.encodePacked(
uint8(PythAccumulator.MessageType.PriceFeed),
priceFeedMessages[i].priceId,
priceFeedMessages[i].price,
priceFeedMessages[i].conf,
priceFeedMessages[i].expo,
priceFeedMessages[i].publishTime,
priceFeedMessages[i].prevPublishTime,
priceFeedMessages[i].emaPrice,
priceFeedMessages[i].emaConf
);
}
}
function generateWhMerkleUpdate(
PriceFeedMessage[] memory priceFeedMessages,
uint8 depth,
uint8 numSigners
) internal returns (bytes memory whMerkleUpdateData) {
bytes[] memory encodedPriceFeedMessages = encodePriceFeedMessages(
priceFeedMessages
);
(bytes20 rootDigest, bytes[] memory proofs) = MerkleTree
.constructProofs(encodedPriceFeedMessages, depth);
bytes memory wormholePayload = abi.encodePacked(
uint32(0x41555756), // PythAccumulator.ACCUMULATOR_WORMHOLE_MAGIC
uint8(PythAccumulator.UpdateType.WormholeMerkle),
uint32(0), // Storage index, not used in target networks
rootDigest
);
bytes memory wormholeMerkleVaa = generateVaa(
0,
SOURCE_EMITTER_CHAIN_ID,
SOURCE_EMITTER_ADDRESS,
0,
wormholePayload,
numSigners
);
whMerkleUpdateData = abi.encodePacked(
uint32(0x504e4155), // PythAccumulator.ACCUMULATOR_MAGIC
uint8(1), // major version
uint8(0), // minor version
uint8(0), // trailing header size
uint8(PythAccumulator.UpdateType.WormholeMerkle),
uint16(wormholeMerkleVaa.length),
wormholeMerkleVaa,
uint8(priceFeedMessages.length)
);
for (uint i = 0; i < priceFeedMessages.length; i++) {
whMerkleUpdateData = abi.encodePacked(
whMerkleUpdateData,
uint16(encodedPriceFeedMessages[i].length),
encodedPriceFeedMessages[i],
proofs[i]
);
}
}
// Generates byte-encoded payload for the given price attestations. You can use this to mock wormhole // Generates byte-encoded payload for the given price attestations. You can use this to mock wormhole
// call using `vm.mockCall` and return a VM struct with this payload. // call using `vm.mockCall` and return a VM struct with this payload.
// You can use generatePriceFeedUpdateVAA to generate a VAA for a price update. // You can use generatePriceFeedUpdate to generate a VAA for a price update.
function generatePriceFeedUpdatePayload( function generatePriceFeedUpdatePayload(
PriceAttestation[] memory attestations PriceAttestation[] memory attestations
) public pure returns (bytes memory payload) { ) public pure returns (bytes memory payload) {
@ -124,7 +208,7 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
// Generates a VAA for the given attestations. // Generates a VAA for the given attestations.
// This method calls generatePriceFeedUpdatePayload and then creates a VAA with it. // This method calls generatePriceFeedUpdatePayload and then creates a VAA with it.
// The VAAs generated from this method use block timestamp as their timestamp. // The VAAs generated from this method use block timestamp as their timestamp.
function generatePriceFeedUpdateVAA( function generateWhBatchUpdate(
PriceAttestation[] memory attestations, PriceAttestation[] memory attestations,
uint64 sequence, uint64 sequence,
uint8 numSigners uint8 numSigners
@ -170,6 +254,24 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
attestations[i].prevConf = prices[i].conf; attestations[i].prevConf = prices[i].conf;
} }
} }
function pricesToPriceFeedMessages(
bytes32[] memory priceIds,
PythStructs.Price[] memory prices
) public returns (PriceFeedMessage[] memory priceFeedMessages) {
assertGe(priceIds.length, prices.length);
priceFeedMessages = new PriceFeedMessage[](prices.length);
for (uint i = 0; i < prices.length; ++i) {
priceFeedMessages[i].priceId = priceIds[i];
priceFeedMessages[i].price = prices[i].price;
priceFeedMessages[i].conf = prices[i].conf;
priceFeedMessages[i].expo = prices[i].expo;
priceFeedMessages[i].publishTime = uint64(prices[i].publishTime);
priceFeedMessages[i].emaPrice = prices[i].price;
priceFeedMessages[i].emaConf = prices[i].conf;
}
}
} }
contract PythTestUtilsTest is contract PythTestUtilsTest is
@ -178,7 +280,7 @@ contract PythTestUtilsTest is
PythTestUtils, PythTestUtils,
IPythEvents IPythEvents
{ {
function testGeneratePriceFeedUpdateVAAWorks() public { function testGenerateWhBatchUpdateWorks() public {
IPyth pyth = IPyth( IPyth pyth = IPyth(
setUpPyth( setUpPyth(
setUpWormhole( setUpWormhole(
@ -200,7 +302,7 @@ contract PythTestUtilsTest is
1 // Publish time 1 // Publish time
); );
bytes memory vaa = generatePriceFeedUpdateVAA( bytes memory vaa = generateWhBatchUpdate(
pricesToPriceAttestations(priceIds, prices), pricesToPriceAttestations(priceIds, prices),
1, // Sequence 1, // Sequence
1 // No. Signers 1 // No. Signers
@ -211,7 +313,7 @@ contract PythTestUtilsTest is
uint updateFee = pyth.getUpdateFee(updateData); uint updateFee = pyth.getUpdateFee(updateData);
vm.expectEmit(true, true, false, true); vm.expectEmit(true, false, false, true);
emit PriceFeedUpdate(priceIds[0], 1, 100, 10); emit PriceFeedUpdate(priceIds[0], 1, 100, 10);
pyth.updatePriceFeeds{value: updateFee}(updateData); pyth.updatePriceFeeds{value: updateFee}(updateData);

View File

@ -38,4 +38,8 @@ contract RandTestUtils is Test {
function getRandInt32() internal returns (int32) { function getRandInt32() internal returns (int32) {
return int32(getRandUint32()); return int32(getRandUint32());
} }
function getRandUint8() internal returns (uint8) {
return uint8(getRandUint());
}
} }

View File

@ -1,6 +1,6 @@
{ {
"name": "@pythnetwork/pyth-evm-contract", "name": "@pythnetwork/pyth-evm-contract",
"version": "1.2.0", "version": "1.3.0-alpha",
"description": "", "description": "",
"private": "true", "private": "true",
"devDependencies": { "devDependencies": {