[eth] Add parsePriceFeedUpdates method (#392)

* Add the implementation with tests and benchmark

* Refactor the contract to reduce redundancy

* Reduce optimization runs as the contract was huge

It has 177 more gas usage on some benchmark tests
This commit is contained in:
Ali Behjati 2022-11-23 17:16:28 +01:00 committed by GitHub
parent 598b0dde1b
commit 275c7b8d1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 837 additions and 178 deletions

View File

@ -23,12 +23,7 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
}
function updatePriceBatchFromVm(bytes calldata encodedVm) private {
(IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVm);
require(valid, reason);
require(verifyPythVM(vm), "invalid data source chain/emitter ID");
parseAndProcessBatchPriceAttestation(vm);
parseAndProcessBatchPriceAttestation(parseAndVerifyBatchAttestationVM(encodedVm));
}
function updatePriceFeeds(bytes[] calldata updateData) public override payable {
@ -62,65 +57,39 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
// operations have proper require.
unchecked {
bytes memory encoded = vm.payload;
uint index = 0;
// Check header
{
uint32 magic = UnsafeBytesLib.toUint32(encoded, index);
index += 4;
require(magic == 0x50325748, "invalid magic value");
uint16 versionMajor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMajor == 3, "invalid version major, expected 3");
uint16 versionMinor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMinor >= 0, "invalid version minor, expected 0 or more");
uint16 hdrSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// NOTE(2022-04-19): Currently, only payloadId comes after
// hdrSize. Future extra header fields must be read using a
// separate offset to respect hdrSize, i.e.:
//
// uint hdrIndex = 0;
// bpa.header.payloadId = UnsafeBytesLib.toUint8(encoded, index + hdrIndex);
// hdrIndex += 1;
//
// bpa.header.someNewField = UnsafeBytesLib.toUint32(encoded, index + hdrIndex);
// hdrIndex += 4;
//
// // Skip remaining unknown header bytes
// index += bpa.header.hdrSize;
uint8 payloadId = UnsafeBytesLib.toUint8(encoded, index);
// Skip remaining unknown header bytes
index += hdrSize;
// Payload ID of 2 required for batch headerBa
require(payloadId == 2, "invalid payload ID, expected 2 for BatchPriceAttestation");
}
// Parse the number of attestations
uint16 nAttestations = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Parse the attestation size
uint16 attestationSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Given the message is valid the arithmetic below should not overflow, and
// even if it overflows then the require would fail.
require(encoded.length == (index + (attestationSize * nAttestations)), "invalid BatchPriceAttestation size");
PythInternalStructs.PriceInfo memory info;
bytes32 priceId;
(uint index, uint nAttestations, uint attestationSize) =
parseBatchAttestationHeader(encoded);
// Deserialize each attestation
for (uint j=0; j < nAttestations; j++) {
(PythInternalStructs.PriceInfo memory info, bytes32 priceId) = parseSingleAttestationFromBatch(encoded, index, attestationSize);
// Respect specified attestation size for forward-compat
index += attestationSize;
// Store the attestation
uint64 latestPublishTime = latestPriceInfoPublishTime(priceId);
if(info.publishTime > latestPublishTime) {
setLatestPriceInfo(priceId, info);
emit PriceFeedUpdate(priceId, info.publishTime, info.price, info.conf);
}
}
emit BatchPriceFeedUpdate(vm.emitterChainId, vm.sequence);
}
}
function parseSingleAttestationFromBatch(
bytes memory encoded,
uint index,
uint attestationSize
) internal pure returns (
PythInternalStructs.PriceInfo memory info,
bytes32 priceId
) {
unchecked {
// NOTE: We don't advance the global index immediately.
// attestationIndex is an attestation-local offset used
// for readability and easier debugging.
@ -188,22 +157,7 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
}
}
require(attestationIndex <= attestationSize, "INTERNAL: Consumed more than `attestationSize` bytes");
// Respect specified attestation size for forward-compat
index += attestationSize;
// Store the attestation
uint64 latestPublishTime = latestPriceInfoPublishTime(priceId);
if(info.publishTime > latestPublishTime) {
setLatestPriceInfo(priceId, info);
emit PriceFeedUpdate(priceId, info.publishTime, info.price, info.conf);
}
}
emit BatchPriceFeedUpdate(vm.emitterChainId, vm.sequence);
}
}
@ -267,14 +221,168 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
require(price.publishTime != 0, "price feed for the given id is not pushed or does not exist");
}
function parseBatchAttestationHeader(
bytes memory encoded
) internal pure returns (
uint index,
uint nAttestations,
uint attestationSize
) {
unchecked {
index = 0;
// Check header
{
uint32 magic = UnsafeBytesLib.toUint32(encoded, index);
index += 4;
require(magic == 0x50325748, "invalid magic value");
uint16 versionMajor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMajor == 3, "invalid version major, expected 3");
uint16 versionMinor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMinor >= 0, "invalid version minor, expected 0 or more");
uint16 hdrSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// NOTE(2022-04-19): Currently, only payloadId comes after
// hdrSize. Future extra header fields must be read using a
// separate offset to respect hdrSize, i.e.:
//
// uint hdrIndex = 0;
// bpa.header.payloadId = UnsafeBytesLib.toUint8(encoded, index + hdrIndex);
// hdrIndex += 1;
//
// bpa.header.someNewField = UnsafeBytesLib.toUint32(encoded, index + hdrIndex);
// hdrIndex += 4;
//
// // Skip remaining unknown header bytes
// index += bpa.header.hdrSize;
uint8 payloadId = UnsafeBytesLib.toUint8(encoded, index);
// Skip remaining unknown header bytes
index += hdrSize;
// Payload ID of 2 required for batch headerBa
require(payloadId == 2, "invalid payload ID, expected 2 for BatchPriceAttestation");
}
// Parse the number of attestations
nAttestations = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Parse the attestation size
attestationSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Given the message is valid the arithmetic below should not overflow, and
// even if it overflows then the require would fail.
require(encoded.length == (index + (attestationSize * nAttestations)), "invalid BatchPriceAttestation size");
}
}
function parseAndVerifyBatchAttestationVM(
bytes calldata encodedVm
) internal view returns (
IWormhole.VM memory vm
) {
{
bool valid;
string memory reason;
(vm, valid, reason) = wormhole().parseAndVerifyVM(encodedVm);
require(valid, reason);
}
require(verifyPythVM(vm), "invalid data source chain/emitter ID");
}
function parsePriceFeedUpdates(
bytes[] calldata updateData,
bytes32[] calldata priceIds,
uint64 minPublishTime,
uint64 maxPublishTime
) external payable override returns (PythStructs.PriceFeed[] memory priceFeeds) {
// TODO: To be implemented soon.
revert("unimplemented");
unchecked {
{
uint requiredFee = getUpdateFee(updateData);
require(msg.value >= requiredFee, "insufficient paid fee amount");
}
priceFeeds = new PythStructs.PriceFeed[](priceIds.length);
for(uint i = 0; i < updateData.length; i++) {
bytes memory encoded;
{
IWormhole.VM memory vm = parseAndVerifyBatchAttestationVM(updateData[i]);
encoded = vm.payload;
}
(uint index, uint nAttestations, uint attestationSize) =
parseBatchAttestationHeader(encoded);
// Deserialize each attestation
for (uint j=0; j < nAttestations; j++) {
// NOTE: We don't advance the global index immediately.
// attestationIndex is an attestation-local offset used
// for readability and easier debugging.
uint attestationIndex = 0;
// Unused bytes32 product id
attestationIndex += 32;
bytes32 priceId = UnsafeBytesLib.toBytes32(encoded, index + attestationIndex);
// Check whether the caller requested for this data.
uint k = 0;
for(; k < priceIds.length; k++) {
if (priceIds[k] == priceId) {
break;
}
}
// If priceFeed[k].id != 0 then it means that there was a valid
// update for priceIds[k] and we don't need to process this one.
if (k == priceIds.length || priceFeeds[k].id != 0) {
index += attestationSize;
continue;
}
(PythInternalStructs.PriceInfo memory info, ) = parseSingleAttestationFromBatch(encoded, index, attestationSize);
require(info.publishTime != 0, "price feed for the given id is not pushed or does not exist");
priceFeeds[k].id = priceId;
priceFeeds[k].price.price = info.price;
priceFeeds[k].price.conf = info.conf;
priceFeeds[k].price.expo = info.expo;
priceFeeds[k].price.publishTime = uint(info.publishTime);
priceFeeds[k].emaPrice.price = info.emaPrice;
priceFeeds[k].emaPrice.conf = info.emaConf;
priceFeeds[k].emaPrice.expo = info.expo;
priceFeeds[k].emaPrice.publishTime = uint(info.publishTime);
// Check the publish time of the price is within the given range
// if it is not, then set the id to 0 to indicate that this price id
// still does not have a valid price feed. This will allow other updates
// for this price id to be processed.
if (priceFeeds[k].price.publishTime < minPublishTime ||
priceFeeds[k].price.publishTime > maxPublishTime) {
priceFeeds[k].id = 0;
}
index += attestationSize;
}
}
for (uint k = 0; k < priceIds.length; k++) {
require(priceFeeds[k].id != 0,
"1 or more price feeds are not found in the updateData or they are out of the given time range");
}
}
}
function queryPriceFeed(bytes32 id) public view override returns (PythStructs.PriceFeed memory priceFeed){

View File

@ -86,8 +86,7 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
function generateUpdateDataAndFee(PythStructs.Price[] memory prices) internal returns (bytes[] memory updateData, uint updateFee) {
bytes memory vaa = generatePriceFeedUpdateVAA(
priceIds,
prices,
pricesToPriceAttestations(priceIds, prices),
sequence,
NUM_GUARDIAN_SIGNERS
);
@ -122,6 +121,29 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
pyth.updatePriceFeedsIfNecessary{value: cachedPricesUpdateFee}(cachedPricesUpdateData, priceIds, cachedPricesPublishTimes);
}
function testBenchmarkParsePriceFeedUpdatesForOnePriceFeed() public {
bytes32[] memory ids = new bytes32[](1);
ids[0] = priceIds[0];
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}(freshPricesUpdateData, ids, 0, 50);
}
function testBenchmarkParsePriceFeedUpdatesForTwoPriceFeed() public {
bytes32[] memory ids = new bytes32[](2);
ids[0] = priceIds[0];
ids[1] = priceIds[1];
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}(freshPricesUpdateData, ids, 0, 50);
}
function testBenchmarkParsePriceFeedUpdatesForOnePriceFeedNotWithinRange() public {
bytes32[] memory ids = new bytes32[](1);
ids[0] = priceIds[0];
vm.expectRevert(bytes("1 or more price feeds are not found in the updateData or they are out of the given time range"));
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}(freshPricesUpdateData, ids, 50, 100);
}
function testBenchmarkGetPrice() public {
// Set the block timestamp to 0. As prices have < 10 timestamp and staleness
// is set to 60 seconds, the getPrice should work as expected.

View File

@ -0,0 +1,444 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "forge-std/Test.sol";
import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "./utils/WormholeTestUtils.t.sol";
import "./utils/PythTestUtils.t.sol";
import "./utils/RandTestUtils.t.sol";
contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils {
IPyth public pyth;
// -1 is equal to 0x111111 which is the biggest uint if converted back
uint64 constant MAX_UINT64 = uint64(int64(-1));
function setUp() public {
pyth = IPyth(setUpPyth(setUpWormhole(1)));
}
function generateRandomPriceAttestations(
uint length
) internal returns (
bytes32[] memory priceIds,
PriceAttestation[] memory attestations
) {
attestations = new PriceAttestation[](length);
priceIds = new bytes32[](length);
for(uint i = 0; i < length; i++) {
attestations[i].productId = getRandBytes32();
attestations[i].priceId = bytes32(i+1); // price ids should be non-zero and unique
attestations[i].price = getRandInt64();
attestations[i].conf = getRandUint64();
attestations[i].expo = getRandInt32();
attestations[i].emaPrice = getRandInt64();
attestations[i].emaConf = getRandUint64();
attestations[i].status = PriceAttestationStatus(getRandUint() % 2);
attestations[i].numPublishers = getRandUint32();
attestations[i].maxNumPublishers = getRandUint32();
attestations[i].attestationTime = getRandUint64();
attestations[i].publishTime = getRandUint64();
attestations[i].prevPublishTime = getRandUint64();
attestations[i].price = getRandInt64();
attestations[i].conf = getRandUint64();
priceIds[i] = attestations[i].priceId;
}
}
// This method divides attestations into a couple of batches and creates
// updateData for them. It returns the updateData and the updateFee
function createBatchedUpdateDataFromAttestations(
PriceAttestation[] memory attestations
) internal returns (bytes[] memory updateData, uint updateFee) {
uint batchSize = 1 + getRandUint() % attestations.length;
uint numBatches = (attestations.length + batchSize - 1) / batchSize;
updateData = new bytes[](numBatches);
for(uint i = 0; i < attestations.length; i += batchSize) {
uint len = batchSize;
if(attestations.length - i < len) {
len = attestations.length - i;
}
PriceAttestation[] memory batchAttestations = new PriceAttestation[](len);
for(uint j = i; j < i+len; j++) {
batchAttestations[j-i] = attestations[j];
}
updateData[i / batchSize] = generatePriceFeedUpdateVAA(
batchAttestations,
0,
1
);
}
updateFee = pyth.getUpdateFee(updateData);
}
/// Testing parsePriceFeedUpdates method.
function testParsePriceFeedUpdatesWorksWithTradingStatus(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
for(uint i = 0; i < numAttestations; i++) {
attestations[i].status = PriceAttestationStatus.Trading;
}
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
for(uint i = 0; i < numAttestations; i++) {
assertEq(priceFeeds[i].id, priceIds[i]);
assertEq(priceFeeds[i].price.price, attestations[i].price);
assertEq(priceFeeds[i].price.conf, attestations[i].conf);
assertEq(priceFeeds[i].price.expo, attestations[i].expo);
assertEq(priceFeeds[i].price.publishTime, attestations[i].publishTime);
assertEq(priceFeeds[i].emaPrice.price, attestations[i].emaPrice);
assertEq(priceFeeds[i].emaPrice.conf, attestations[i].emaConf);
assertEq(priceFeeds[i].emaPrice.expo, attestations[i].expo);
assertEq(priceFeeds[i].emaPrice.publishTime, attestations[i].publishTime);
}
}
function testParsePriceFeedUpdatesWorksWithUnknownStatus(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
for(uint i = 0; i < numAttestations; i++) {
attestations[i].status = PriceAttestationStatus.Unknown;
}
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
for(uint i = 0; i < numAttestations; i++) {
assertEq(priceFeeds[i].id, priceIds[i]);
assertEq(priceFeeds[i].price.price, attestations[i].prevPrice);
assertEq(priceFeeds[i].price.conf, attestations[i].prevConf);
assertEq(priceFeeds[i].price.expo, attestations[i].expo);
assertEq(priceFeeds[i].price.publishTime, attestations[i].prevPublishTime);
assertEq(priceFeeds[i].emaPrice.price, attestations[i].emaPrice);
assertEq(priceFeeds[i].emaPrice.conf, attestations[i].emaConf);
assertEq(priceFeeds[i].emaPrice.expo, attestations[i].expo);
assertEq(priceFeeds[i].emaPrice.publishTime, attestations[i].prevPublishTime);
}
}
function testParsePriceFeedUpdatesWorksWithRandomDistinctUpdatesInput(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 30;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
// Shuffle the attestations
for (uint i = 1; i < numAttestations; i++) {
uint swapWith = getRandUint() % (i+1);
(attestations[i], attestations[swapWith]) = (attestations[swapWith], attestations[i]);
(priceIds[i], priceIds[swapWith]) = (priceIds[swapWith], priceIds[i]);
}
// Select only first numSelectedAttestations. numSelectedAttestations will be in [0, numAttestations]
uint numSelectedAttestations = getRandUint() % (numAttestations + 1);
PriceAttestation[] memory selectedAttestations = new PriceAttestation[](numSelectedAttestations);
bytes32[] memory selectedPriceIds = new bytes32[](numSelectedAttestations);
for (uint i = 0; i < numSelectedAttestations; i++) {
selectedAttestations[i] = attestations[i];
selectedPriceIds[i] = priceIds[i];
}
// Only parse selected attestations
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
selectedPriceIds,
0,
MAX_UINT64
);
for(uint i = 0; i < numSelectedAttestations; i++) {
assertEq(priceFeeds[i].id, selectedPriceIds[i]);
assertEq(priceFeeds[i].price.expo, selectedAttestations[i].expo);
assertEq(priceFeeds[i].emaPrice.price, selectedAttestations[i].emaPrice);
assertEq(priceFeeds[i].emaPrice.conf, selectedAttestations[i].emaConf);
assertEq(priceFeeds[i].emaPrice.expo, selectedAttestations[i].expo);
if (selectedAttestations[i].status == PriceAttestationStatus.Trading) {
assertEq(priceFeeds[i].price.price, selectedAttestations[i].price);
assertEq(priceFeeds[i].price.conf, selectedAttestations[i].conf);
assertEq(priceFeeds[i].price.publishTime, selectedAttestations[i].publishTime);
assertEq(priceFeeds[i].emaPrice.publishTime, selectedAttestations[i].publishTime);
} else {
assertEq(priceFeeds[i].price.price, selectedAttestations[i].prevPrice);
assertEq(priceFeeds[i].price.conf, selectedAttestations[i].prevConf);
assertEq(priceFeeds[i].price.publishTime, selectedAttestations[i].prevPublishTime);
assertEq(priceFeeds[i].emaPrice.publishTime, selectedAttestations[i].prevPublishTime);
}
}
}
function testParsePriceFeedUpdatesWorksWithOverlappingWithinTimeRangeUpdates() public {
PriceAttestation[] memory attestations = new PriceAttestation[](2);
attestations[0].priceId = bytes32(uint(1));
attestations[0].status = PriceAttestationStatus.Trading;
attestations[0].price = 1000;
attestations[0].publishTime = 10;
attestations[1].priceId = bytes32(uint(1));
attestations[1].status = PriceAttestationStatus.Trading;
attestations[1].price = 2000;
attestations[1].publishTime = 20;
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint(1));
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
20
);
assertEq(priceFeeds.length, 1);
assertEq(priceFeeds[0].id, bytes32(uint(1)));
assertTrue(
(priceFeeds[0].price.price == 1000 && priceFeeds[0].price.publishTime == 10) ||
(priceFeeds[0].price.price == 2000 && priceFeeds[0].price.publishTime == 20)
);
}
function testParsePriceFeedUpdatesWorksWithOverlappingMixedTimeRangeUpdates() public {
PriceAttestation[] memory attestations = new PriceAttestation[](2);
attestations[0].priceId = bytes32(uint(1));
attestations[0].status = PriceAttestationStatus.Trading;
attestations[0].price = 1000;
attestations[0].publishTime = 10;
attestations[1].priceId = bytes32(uint(1));
attestations[1].status = PriceAttestationStatus.Trading;
attestations[1].price = 2000;
attestations[1].publishTime = 20;
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint(1));
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
5,
15
);
assertEq(priceFeeds.length, 1);
assertEq(priceFeeds[0].id, bytes32(uint(1)));
assertEq(priceFeeds[0].price.price, 1000);
assertEq(priceFeeds[0].price.publishTime, 10);
priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
15,
25
);
assertEq(priceFeeds.length, 1);
assertEq(priceFeeds[0].id, bytes32(uint(1)));
assertEq(priceFeeds[0].price.price, 2000);
assertEq(priceFeeds[0].price.publishTime, 20);
}
function testParsePriceFeedUpdatesRevertsIfUpdateFeeIsNotPaid() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
// Since attestations are not empty the fee should be at least 1
assertGe(updateFee, 1);
vm.expectRevert(bytes("insufficient paid fee amount"));
pyth.parsePriceFeedUpdates{value: updateFee-1}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfUpdateVAAIsInvalid(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
uint mutPos = getRandUint() % updateData[0].length;
// mutate the random position by 1 bit
updateData[0][mutPos] = bytes1(uint8(updateData[0][mutPos])^1);
// It might revert due to different wormhole errors
vm.expectRevert();
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfUpdateSourceChainIsInvalid() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
bytes[] memory updateData = new bytes[](1);
updateData[0] = generateVaa(
uint32(block.timestamp),
SOURCE_EMITTER_CHAIN_ID + 1,
SOURCE_EMITTER_ADDRESS,
1, // Sequence
generatePriceFeedUpdatePayload(attestations),
1 // Num signers
);
uint updateFee = pyth.getUpdateFee(updateData);
vm.expectRevert(bytes("invalid data source chain/emitter ID"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfUpdateSourceAddressIsInvalid() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
bytes[] memory updateData = new bytes[](1);
updateData[0] = generateVaa(
uint32(block.timestamp),
SOURCE_EMITTER_CHAIN_ID,
0x00000000000000000000000000000000000000000000000000000000000000aa, // Random wrong source address
1, // Sequence
generatePriceFeedUpdatePayload(attestations),
1 // Num signers
);
uint updateFee = pyth.getUpdateFee(updateData);
vm.expectRevert(bytes("invalid data source chain/emitter ID"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfPriceIdNotIncluded() public {
PriceAttestation[] memory attestations = new PriceAttestation[](1);
attestations[0].priceId = bytes32(uint(1));
attestations[0].status = PriceAttestationStatus.Trading;
attestations[0].price = 1000;
attestations[0].publishTime = 10;
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint(2));
vm.expectRevert(bytes("1 or more price feeds are not found in the updateData or they are out of the given time range"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdateRevertsIfPricesOutOfTimeRange() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
for(uint i = 0; i < numAttestations; i++) {
// Set status to Trading so publishTime is used
attestations[i].status = PriceAttestationStatus.Trading;
attestations[i].publishTime = uint64(100 + getRandUint() % 101); // All between [100, 200]
}
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
// Request for parse within the given time range should work
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
100,
200
);
// Request for parse after the time range should revert.
vm.expectRevert(bytes("1 or more price feeds are not found in the updateData or they are out of the given time range"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
300,
MAX_UINT64
);
}
}

View File

@ -53,44 +53,63 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
return address(pyth);
}
// Generates byte-encoded payload for the given prices. It sets the emaPrice the same
// as the given price. You can use this to mock wormhole call using `vm.mockCall` and
// return a VM struct with this payload.
/// Utilities to help generating price attestations and VAAs for them
enum PriceAttestationStatus {
Unknown,
Trading
}
struct PriceAttestation {
bytes32 productId;
bytes32 priceId;
int64 price;
uint64 conf;
int32 expo;
int64 emaPrice;
uint64 emaConf;
PriceAttestationStatus status;
uint32 numPublishers;
uint32 maxNumPublishers;
uint64 attestationTime;
uint64 publishTime;
uint64 prevPublishTime;
int64 prevPrice;
uint64 prevConf;
}
// Generates byte-encoded payload for the given price attestations. You can use this to mock wormhole
// call using `vm.mockCall` and return a VM struct with this payload.
// You can use generatePriceFeedUpdateVAA to generate a VAA for a price update.
function generatePriceFeedUpdatePayload(
bytes32[] memory priceIds,
PythStructs.Price[] memory prices
) public returns (bytes memory payload) {
assertEq(priceIds.length, prices.length);
PriceAttestation[] memory attestations
) public pure returns (bytes memory payload) {
bytes memory encodedAttestations = new bytes(0);
bytes memory attestations = new bytes(0);
for (uint i = 0; i < prices.length; ++i) {
for (uint i = 0; i < attestations.length; ++i) {
// encodePacked uses padding for arrays and we don't want it, so we manually concat them.
attestations = abi.encodePacked(
attestations,
priceIds[i], // Product ID, we use the same price Id. This field is not used.
priceIds[i], // Price ID,
prices[i].price, // Price
prices[i].conf, // Confidence
prices[i].expo, // Exponent
prices[i].price, // EMA price
prices[i].conf // EMA confidence
encodedAttestations = abi.encodePacked(
encodedAttestations,
attestations[i].productId,
attestations[i].priceId,
attestations[i].price,
attestations[i].conf,
attestations[i].expo,
attestations[i].emaPrice,
attestations[i].emaConf
);
// Breaking this in two encodePackes because of the limited EVM stack.
attestations = abi.encodePacked(
attestations,
uint8(1), // status = 1 = Trading
uint32(5), // Number of publishers. This field is not used.
uint32(10), // Maximum number of publishers. This field is not used.
uint64(prices[i].publishTime), // Attestation time. This field is not used.
uint64(prices[i].publishTime), // Publish time.
// Previous values are unused as status is trading. We use the same value
// to make sure the test is irrelevant of the logic of which price is chosen.
uint64(prices[i].publishTime), // Previous publish time.
prices[i].price, // Previous price
prices[i].conf // Previous confidence
encodedAttestations = abi.encodePacked(
encodedAttestations,
uint8(attestations[i].status),
attestations[i].numPublishers,
attestations[i].maxNumPublishers,
attestations[i].attestationTime,
attestations[i].publishTime,
attestations[i].prevPublishTime,
attestations[i].prevPrice,
attestations[i].prevConf
);
}
@ -100,24 +119,22 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
uint16(0), // Minor version
uint16(1), // Header size of 1 byte as it only contains payloadId
uint8(2), // Payload ID 2 means it's a batch price attestation
uint16(prices.length), // Number of attestations
uint16(attestations.length / prices.length), // Size of a single price attestation.
attestations
uint16(attestations.length), // Number of attestations
uint16(encodedAttestations.length / attestations.length), // Size of a single price attestation.
encodedAttestations
);
}
// Generates a VAA for the given prices.
// Generates a VAA for the given attestations.
// This method calls generatePriceFeedUpdatePayload and then creates a VAA with it.
// The VAAs generated from this method use block timestamp as their timestamp.
function generatePriceFeedUpdateVAA(
bytes32[] memory priceIds,
PythStructs.Price[] memory prices,
PriceAttestation[] memory attestations,
uint64 sequence,
uint8 numSigners
) public returns (bytes memory vaa) {
bytes memory payload = generatePriceFeedUpdatePayload(
priceIds,
prices
attestations
);
vaa = generateVaa(
@ -129,6 +146,36 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
numSigners
);
}
function pricesToPriceAttestations(
bytes32[] memory priceIds,
PythStructs.Price[] memory prices
) public returns (PriceAttestation[] memory attestations) {
assertEq(priceIds.length, prices.length);
attestations = new PriceAttestation[](prices.length);
for (uint i = 0; i < prices.length; ++i) {
// Product ID, we use the same price Id. This field is not used.
attestations[i].productId = priceIds[i];
attestations[i].priceId = priceIds[i];
attestations[i].price = prices[i].price;
attestations[i].conf = prices[i].conf;
attestations[i].expo = prices[i].expo;
// Same price and conf is used for emaPrice and emaConf
attestations[i].emaPrice = prices[i].price;
attestations[i].emaConf = prices[i].conf;
attestations[i].status = PriceAttestationStatus.Trading;
attestations[i].numPublishers = 5; // This field is not used
attestations[i].maxNumPublishers = 10; // This field is not used
attestations[i].attestationTime = uint64(prices[i].publishTime); // This field is not used
attestations[i].publishTime = uint64(prices[i].publishTime);
// Fields below are not used when status is Trading. just setting them to
// the same value as the prices.
attestations[i].prevPublishTime = uint64(prices[i].publishTime);
attestations[i].prevPrice = prices[i].price;
attestations[i].prevConf = prices[i].conf;
}
}
}
contract PythTestUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
@ -149,8 +196,7 @@ contract PythTestUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvent
);
bytes memory vaa = generatePriceFeedUpdateVAA(
priceIds,
prices,
pricesToPriceAttestations(priceIds, prices),
1, // Sequence
1 // No. Signers
);

View File

@ -0,0 +1,39 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "forge-std/Test.sol";
contract RandTestUtils is Test {
uint randSeed;
function setRandSeed(uint seed) internal {
randSeed = seed;
}
function getRandBytes32() internal returns (bytes32) {
unchecked { randSeed++; }
return keccak256(abi.encode(randSeed));
}
function getRandUint() internal returns (uint) {
return uint(getRandBytes32());
}
function getRandUint64() internal returns (uint64) {
return uint64(getRandUint());
}
function getRandInt64() internal returns (int64) {
return int64(getRandUint64());
}
function getRandUint32() internal returns (uint32) {
return uint32(getRandUint());
}
function getRandInt32() internal returns (int32) {
return int32(getRandUint32());
}
}

View File

@ -1,7 +1,7 @@
[profile.default]
solc_version = '0.8.4'
optimizer = true
optimizer_runs = 10000
optimizer_runs = 5000
src = 'contracts'
# We put the tests into the forge-test directory (instead of test) so that
# truffle doesn't try to build them

View File

@ -230,7 +230,7 @@ module.exports = {
settings: {
optimizer: {
enabled: true,
runs: 10000,
runs: 5000,
},
},
},