[eth] Add parsePriceFeedUpdates method (#392)

* Add the implementation with tests and benchmark

* Refactor the contract to reduce redundancy

* Reduce optimization runs as the contract was huge

It has 177 more gas usage on some benchmark tests
This commit is contained in:
Ali Behjati 2022-11-23 17:16:28 +01:00 committed by GitHub
parent 598b0dde1b
commit 275c7b8d1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 837 additions and 178 deletions

View File

@ -23,12 +23,7 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
} }
function updatePriceBatchFromVm(bytes calldata encodedVm) private { function updatePriceBatchFromVm(bytes calldata encodedVm) private {
(IWormhole.VM memory vm, bool valid, string memory reason) = wormhole().parseAndVerifyVM(encodedVm); parseAndProcessBatchPriceAttestation(parseAndVerifyBatchAttestationVM(encodedVm));
require(valid, reason);
require(verifyPythVM(vm), "invalid data source chain/emitter ID");
parseAndProcessBatchPriceAttestation(vm);
} }
function updatePriceFeeds(bytes[] calldata updateData) public override payable { function updatePriceFeeds(bytes[] calldata updateData) public override payable {
@ -61,135 +56,14 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
// a comment explaining why it is safe. Also, byteslib // a comment explaining why it is safe. Also, byteslib
// operations have proper require. // operations have proper require.
unchecked { unchecked {
bytes memory encoded = vm.payload; bytes memory encoded = vm.payload;
uint index = 0;
// Check header (uint index, uint nAttestations, uint attestationSize) =
{ parseBatchAttestationHeader(encoded);
uint32 magic = UnsafeBytesLib.toUint32(encoded, index);
index += 4;
require(magic == 0x50325748, "invalid magic value");
uint16 versionMajor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMajor == 3, "invalid version major, expected 3");
uint16 versionMinor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMinor >= 0, "invalid version minor, expected 0 or more");
uint16 hdrSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// NOTE(2022-04-19): Currently, only payloadId comes after
// hdrSize. Future extra header fields must be read using a
// separate offset to respect hdrSize, i.e.:
//
// uint hdrIndex = 0;
// bpa.header.payloadId = UnsafeBytesLib.toUint8(encoded, index + hdrIndex);
// hdrIndex += 1;
//
// bpa.header.someNewField = UnsafeBytesLib.toUint32(encoded, index + hdrIndex);
// hdrIndex += 4;
//
// // Skip remaining unknown header bytes
// index += bpa.header.hdrSize;
uint8 payloadId = UnsafeBytesLib.toUint8(encoded, index);
// Skip remaining unknown header bytes
index += hdrSize;
// Payload ID of 2 required for batch headerBa
require(payloadId == 2, "invalid payload ID, expected 2 for BatchPriceAttestation");
}
// Parse the number of attestations
uint16 nAttestations = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Parse the attestation size
uint16 attestationSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Given the message is valid the arithmetic below should not overflow, and
// even if it overflows then the require would fail.
require(encoded.length == (index + (attestationSize * nAttestations)), "invalid BatchPriceAttestation size");
PythInternalStructs.PriceInfo memory info;
bytes32 priceId;
// Deserialize each attestation // Deserialize each attestation
for (uint j=0; j < nAttestations; j++) { for (uint j=0; j < nAttestations; j++) {
// NOTE: We don't advance the global index immediately. (PythInternalStructs.PriceInfo memory info, bytes32 priceId) = parseSingleAttestationFromBatch(encoded, index, attestationSize);
// attestationIndex is an attestation-local offset used
// for readability and easier debugging.
uint attestationIndex = 0;
// Unused bytes32 product id
attestationIndex += 32;
priceId = UnsafeBytesLib.toBytes32(encoded, index + attestationIndex);
attestationIndex += 32;
info.price = int64(UnsafeBytesLib.toUint64(encoded, index + attestationIndex));
attestationIndex += 8;
info.conf = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
info.expo = int32(UnsafeBytesLib.toUint32(encoded, index + attestationIndex));
attestationIndex += 4;
info.emaPrice = int64(UnsafeBytesLib.toUint64(encoded, index + attestationIndex));
attestationIndex += 8;
info.emaConf = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
{
// Status is an enum (encoded as uint8) with the following values:
// 0 = UNKNOWN: The price feed is not currently updating for an unknown reason.
// 1 = TRADING: The price feed is updating as expected.
// 2 = HALTED: The price feed is not currently updating because trading in the product has been halted.
// 3 = AUCTION: The price feed is not currently updating because an auction is setting the price.
uint8 status = UnsafeBytesLib.toUint8(encoded, index + attestationIndex);
attestationIndex += 1;
// Unused uint32 numPublishers
attestationIndex += 4;
// Unused uint32 numPublishers
attestationIndex += 4;
// Unused uint64 attestationTime
attestationIndex += 8;
info.publishTime = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
if (status == 1) { // status == TRADING
attestationIndex += 24;
} else {
// If status is not trading then the latest available price is
// the previous price info that are passed here.
// Previous publish time
info.publishTime = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
// Previous price
info.price = int64(UnsafeBytesLib.toUint64(encoded, index + attestationIndex));
attestationIndex += 8;
// Previous confidence
info.conf = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
}
}
require(attestationIndex <= attestationSize, "INTERNAL: Consumed more than `attestationSize` bytes");
// Respect specified attestation size for forward-compat // Respect specified attestation size for forward-compat
index += attestationSize; index += attestationSize;
@ -207,6 +81,86 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
} }
} }
function parseSingleAttestationFromBatch(
bytes memory encoded,
uint index,
uint attestationSize
) internal pure returns (
PythInternalStructs.PriceInfo memory info,
bytes32 priceId
) {
unchecked {
// NOTE: We don't advance the global index immediately.
// attestationIndex is an attestation-local offset used
// for readability and easier debugging.
uint attestationIndex = 0;
// Unused bytes32 product id
attestationIndex += 32;
priceId = UnsafeBytesLib.toBytes32(encoded, index + attestationIndex);
attestationIndex += 32;
info.price = int64(UnsafeBytesLib.toUint64(encoded, index + attestationIndex));
attestationIndex += 8;
info.conf = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
info.expo = int32(UnsafeBytesLib.toUint32(encoded, index + attestationIndex));
attestationIndex += 4;
info.emaPrice = int64(UnsafeBytesLib.toUint64(encoded, index + attestationIndex));
attestationIndex += 8;
info.emaConf = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
{
// Status is an enum (encoded as uint8) with the following values:
// 0 = UNKNOWN: The price feed is not currently updating for an unknown reason.
// 1 = TRADING: The price feed is updating as expected.
// 2 = HALTED: The price feed is not currently updating because trading in the product has been halted.
// 3 = AUCTION: The price feed is not currently updating because an auction is setting the price.
uint8 status = UnsafeBytesLib.toUint8(encoded, index + attestationIndex);
attestationIndex += 1;
// Unused uint32 numPublishers
attestationIndex += 4;
// Unused uint32 numPublishers
attestationIndex += 4;
// Unused uint64 attestationTime
attestationIndex += 8;
info.publishTime = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
if (status == 1) { // status == TRADING
attestationIndex += 24;
} else {
// If status is not trading then the latest available price is
// the previous price info that are passed here.
// Previous publish time
info.publishTime = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
// Previous price
info.price = int64(UnsafeBytesLib.toUint64(encoded, index + attestationIndex));
attestationIndex += 8;
// Previous confidence
info.conf = UnsafeBytesLib.toUint64(encoded, index + attestationIndex);
attestationIndex += 8;
}
}
require(attestationIndex <= attestationSize, "INTERNAL: Consumed more than `attestationSize` bytes");
}
}
// This is an overwrite of the same method in AbstractPyth.sol // This is an overwrite of the same method in AbstractPyth.sol
// to be more gas efficient. // to be more gas efficient.
function updatePriceFeedsIfNecessary( function updatePriceFeedsIfNecessary(
@ -267,14 +221,168 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
require(price.publishTime != 0, "price feed for the given id is not pushed or does not exist"); require(price.publishTime != 0, "price feed for the given id is not pushed or does not exist");
} }
function parseBatchAttestationHeader(
bytes memory encoded
) internal pure returns (
uint index,
uint nAttestations,
uint attestationSize
) {
unchecked {
index = 0;
// Check header
{
uint32 magic = UnsafeBytesLib.toUint32(encoded, index);
index += 4;
require(magic == 0x50325748, "invalid magic value");
uint16 versionMajor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMajor == 3, "invalid version major, expected 3");
uint16 versionMinor = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
require(versionMinor >= 0, "invalid version minor, expected 0 or more");
uint16 hdrSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// NOTE(2022-04-19): Currently, only payloadId comes after
// hdrSize. Future extra header fields must be read using a
// separate offset to respect hdrSize, i.e.:
//
// uint hdrIndex = 0;
// bpa.header.payloadId = UnsafeBytesLib.toUint8(encoded, index + hdrIndex);
// hdrIndex += 1;
//
// bpa.header.someNewField = UnsafeBytesLib.toUint32(encoded, index + hdrIndex);
// hdrIndex += 4;
//
// // Skip remaining unknown header bytes
// index += bpa.header.hdrSize;
uint8 payloadId = UnsafeBytesLib.toUint8(encoded, index);
// Skip remaining unknown header bytes
index += hdrSize;
// Payload ID of 2 required for batch headerBa
require(payloadId == 2, "invalid payload ID, expected 2 for BatchPriceAttestation");
}
// Parse the number of attestations
nAttestations = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Parse the attestation size
attestationSize = UnsafeBytesLib.toUint16(encoded, index);
index += 2;
// Given the message is valid the arithmetic below should not overflow, and
// even if it overflows then the require would fail.
require(encoded.length == (index + (attestationSize * nAttestations)), "invalid BatchPriceAttestation size");
}
}
function parseAndVerifyBatchAttestationVM(
bytes calldata encodedVm
) internal view returns (
IWormhole.VM memory vm
) {
{
bool valid;
string memory reason;
(vm, valid, reason) = wormhole().parseAndVerifyVM(encodedVm);
require(valid, reason);
}
require(verifyPythVM(vm), "invalid data source chain/emitter ID");
}
function parsePriceFeedUpdates( function parsePriceFeedUpdates(
bytes[] calldata updateData, bytes[] calldata updateData,
bytes32[] calldata priceIds, bytes32[] calldata priceIds,
uint64 minPublishTime, uint64 minPublishTime,
uint64 maxPublishTime uint64 maxPublishTime
) external payable override returns (PythStructs.PriceFeed[] memory priceFeeds) { ) external payable override returns (PythStructs.PriceFeed[] memory priceFeeds) {
// TODO: To be implemented soon. unchecked {
revert("unimplemented"); {
uint requiredFee = getUpdateFee(updateData);
require(msg.value >= requiredFee, "insufficient paid fee amount");
}
priceFeeds = new PythStructs.PriceFeed[](priceIds.length);
for(uint i = 0; i < updateData.length; i++) {
bytes memory encoded;
{
IWormhole.VM memory vm = parseAndVerifyBatchAttestationVM(updateData[i]);
encoded = vm.payload;
}
(uint index, uint nAttestations, uint attestationSize) =
parseBatchAttestationHeader(encoded);
// Deserialize each attestation
for (uint j=0; j < nAttestations; j++) {
// NOTE: We don't advance the global index immediately.
// attestationIndex is an attestation-local offset used
// for readability and easier debugging.
uint attestationIndex = 0;
// Unused bytes32 product id
attestationIndex += 32;
bytes32 priceId = UnsafeBytesLib.toBytes32(encoded, index + attestationIndex);
// Check whether the caller requested for this data.
uint k = 0;
for(; k < priceIds.length; k++) {
if (priceIds[k] == priceId) {
break;
}
}
// If priceFeed[k].id != 0 then it means that there was a valid
// update for priceIds[k] and we don't need to process this one.
if (k == priceIds.length || priceFeeds[k].id != 0) {
index += attestationSize;
continue;
}
(PythInternalStructs.PriceInfo memory info, ) = parseSingleAttestationFromBatch(encoded, index, attestationSize);
require(info.publishTime != 0, "price feed for the given id is not pushed or does not exist");
priceFeeds[k].id = priceId;
priceFeeds[k].price.price = info.price;
priceFeeds[k].price.conf = info.conf;
priceFeeds[k].price.expo = info.expo;
priceFeeds[k].price.publishTime = uint(info.publishTime);
priceFeeds[k].emaPrice.price = info.emaPrice;
priceFeeds[k].emaPrice.conf = info.emaConf;
priceFeeds[k].emaPrice.expo = info.expo;
priceFeeds[k].emaPrice.publishTime = uint(info.publishTime);
// Check the publish time of the price is within the given range
// if it is not, then set the id to 0 to indicate that this price id
// still does not have a valid price feed. This will allow other updates
// for this price id to be processed.
if (priceFeeds[k].price.publishTime < minPublishTime ||
priceFeeds[k].price.publishTime > maxPublishTime) {
priceFeeds[k].id = 0;
}
index += attestationSize;
}
}
for (uint k = 0; k < priceIds.length; k++) {
require(priceFeeds[k].id != 0,
"1 or more price feeds are not found in the updateData or they are out of the given time range");
}
}
} }
function queryPriceFeed(bytes32 id) public view override returns (PythStructs.PriceFeed memory priceFeed){ function queryPriceFeed(bytes32 id) public view override returns (PythStructs.PriceFeed memory priceFeed){

View File

@ -86,8 +86,7 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
function generateUpdateDataAndFee(PythStructs.Price[] memory prices) internal returns (bytes[] memory updateData, uint updateFee) { function generateUpdateDataAndFee(PythStructs.Price[] memory prices) internal returns (bytes[] memory updateData, uint updateFee) {
bytes memory vaa = generatePriceFeedUpdateVAA( bytes memory vaa = generatePriceFeedUpdateVAA(
priceIds, pricesToPriceAttestations(priceIds, prices),
prices,
sequence, sequence,
NUM_GUARDIAN_SIGNERS NUM_GUARDIAN_SIGNERS
); );
@ -122,6 +121,29 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
pyth.updatePriceFeedsIfNecessary{value: cachedPricesUpdateFee}(cachedPricesUpdateData, priceIds, cachedPricesPublishTimes); pyth.updatePriceFeedsIfNecessary{value: cachedPricesUpdateFee}(cachedPricesUpdateData, priceIds, cachedPricesPublishTimes);
} }
function testBenchmarkParsePriceFeedUpdatesForOnePriceFeed() public {
bytes32[] memory ids = new bytes32[](1);
ids[0] = priceIds[0];
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}(freshPricesUpdateData, ids, 0, 50);
}
function testBenchmarkParsePriceFeedUpdatesForTwoPriceFeed() public {
bytes32[] memory ids = new bytes32[](2);
ids[0] = priceIds[0];
ids[1] = priceIds[1];
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}(freshPricesUpdateData, ids, 0, 50);
}
function testBenchmarkParsePriceFeedUpdatesForOnePriceFeedNotWithinRange() public {
bytes32[] memory ids = new bytes32[](1);
ids[0] = priceIds[0];
vm.expectRevert(bytes("1 or more price feeds are not found in the updateData or they are out of the given time range"));
pyth.parsePriceFeedUpdates{value: freshPricesUpdateFee}(freshPricesUpdateData, ids, 50, 100);
}
function testBenchmarkGetPrice() public { function testBenchmarkGetPrice() public {
// Set the block timestamp to 0. As prices have < 10 timestamp and staleness // Set the block timestamp to 0. As prices have < 10 timestamp and staleness
// is set to 60 seconds, the getPrice should work as expected. // is set to 60 seconds, the getPrice should work as expected.

View File

@ -0,0 +1,444 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "forge-std/Test.sol";
import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "./utils/WormholeTestUtils.t.sol";
import "./utils/PythTestUtils.t.sol";
import "./utils/RandTestUtils.t.sol";
contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils {
IPyth public pyth;
// -1 is equal to 0x111111 which is the biggest uint if converted back
uint64 constant MAX_UINT64 = uint64(int64(-1));
function setUp() public {
pyth = IPyth(setUpPyth(setUpWormhole(1)));
}
function generateRandomPriceAttestations(
uint length
) internal returns (
bytes32[] memory priceIds,
PriceAttestation[] memory attestations
) {
attestations = new PriceAttestation[](length);
priceIds = new bytes32[](length);
for(uint i = 0; i < length; i++) {
attestations[i].productId = getRandBytes32();
attestations[i].priceId = bytes32(i+1); // price ids should be non-zero and unique
attestations[i].price = getRandInt64();
attestations[i].conf = getRandUint64();
attestations[i].expo = getRandInt32();
attestations[i].emaPrice = getRandInt64();
attestations[i].emaConf = getRandUint64();
attestations[i].status = PriceAttestationStatus(getRandUint() % 2);
attestations[i].numPublishers = getRandUint32();
attestations[i].maxNumPublishers = getRandUint32();
attestations[i].attestationTime = getRandUint64();
attestations[i].publishTime = getRandUint64();
attestations[i].prevPublishTime = getRandUint64();
attestations[i].price = getRandInt64();
attestations[i].conf = getRandUint64();
priceIds[i] = attestations[i].priceId;
}
}
// This method divides attestations into a couple of batches and creates
// updateData for them. It returns the updateData and the updateFee
function createBatchedUpdateDataFromAttestations(
PriceAttestation[] memory attestations
) internal returns (bytes[] memory updateData, uint updateFee) {
uint batchSize = 1 + getRandUint() % attestations.length;
uint numBatches = (attestations.length + batchSize - 1) / batchSize;
updateData = new bytes[](numBatches);
for(uint i = 0; i < attestations.length; i += batchSize) {
uint len = batchSize;
if(attestations.length - i < len) {
len = attestations.length - i;
}
PriceAttestation[] memory batchAttestations = new PriceAttestation[](len);
for(uint j = i; j < i+len; j++) {
batchAttestations[j-i] = attestations[j];
}
updateData[i / batchSize] = generatePriceFeedUpdateVAA(
batchAttestations,
0,
1
);
}
updateFee = pyth.getUpdateFee(updateData);
}
/// Testing parsePriceFeedUpdates method.
function testParsePriceFeedUpdatesWorksWithTradingStatus(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
for(uint i = 0; i < numAttestations; i++) {
attestations[i].status = PriceAttestationStatus.Trading;
}
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
for(uint i = 0; i < numAttestations; i++) {
assertEq(priceFeeds[i].id, priceIds[i]);
assertEq(priceFeeds[i].price.price, attestations[i].price);
assertEq(priceFeeds[i].price.conf, attestations[i].conf);
assertEq(priceFeeds[i].price.expo, attestations[i].expo);
assertEq(priceFeeds[i].price.publishTime, attestations[i].publishTime);
assertEq(priceFeeds[i].emaPrice.price, attestations[i].emaPrice);
assertEq(priceFeeds[i].emaPrice.conf, attestations[i].emaConf);
assertEq(priceFeeds[i].emaPrice.expo, attestations[i].expo);
assertEq(priceFeeds[i].emaPrice.publishTime, attestations[i].publishTime);
}
}
function testParsePriceFeedUpdatesWorksWithUnknownStatus(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
for(uint i = 0; i < numAttestations; i++) {
attestations[i].status = PriceAttestationStatus.Unknown;
}
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
for(uint i = 0; i < numAttestations; i++) {
assertEq(priceFeeds[i].id, priceIds[i]);
assertEq(priceFeeds[i].price.price, attestations[i].prevPrice);
assertEq(priceFeeds[i].price.conf, attestations[i].prevConf);
assertEq(priceFeeds[i].price.expo, attestations[i].expo);
assertEq(priceFeeds[i].price.publishTime, attestations[i].prevPublishTime);
assertEq(priceFeeds[i].emaPrice.price, attestations[i].emaPrice);
assertEq(priceFeeds[i].emaPrice.conf, attestations[i].emaConf);
assertEq(priceFeeds[i].emaPrice.expo, attestations[i].expo);
assertEq(priceFeeds[i].emaPrice.publishTime, attestations[i].prevPublishTime);
}
}
function testParsePriceFeedUpdatesWorksWithRandomDistinctUpdatesInput(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 30;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
// Shuffle the attestations
for (uint i = 1; i < numAttestations; i++) {
uint swapWith = getRandUint() % (i+1);
(attestations[i], attestations[swapWith]) = (attestations[swapWith], attestations[i]);
(priceIds[i], priceIds[swapWith]) = (priceIds[swapWith], priceIds[i]);
}
// Select only first numSelectedAttestations. numSelectedAttestations will be in [0, numAttestations]
uint numSelectedAttestations = getRandUint() % (numAttestations + 1);
PriceAttestation[] memory selectedAttestations = new PriceAttestation[](numSelectedAttestations);
bytes32[] memory selectedPriceIds = new bytes32[](numSelectedAttestations);
for (uint i = 0; i < numSelectedAttestations; i++) {
selectedAttestations[i] = attestations[i];
selectedPriceIds[i] = priceIds[i];
}
// Only parse selected attestations
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
selectedPriceIds,
0,
MAX_UINT64
);
for(uint i = 0; i < numSelectedAttestations; i++) {
assertEq(priceFeeds[i].id, selectedPriceIds[i]);
assertEq(priceFeeds[i].price.expo, selectedAttestations[i].expo);
assertEq(priceFeeds[i].emaPrice.price, selectedAttestations[i].emaPrice);
assertEq(priceFeeds[i].emaPrice.conf, selectedAttestations[i].emaConf);
assertEq(priceFeeds[i].emaPrice.expo, selectedAttestations[i].expo);
if (selectedAttestations[i].status == PriceAttestationStatus.Trading) {
assertEq(priceFeeds[i].price.price, selectedAttestations[i].price);
assertEq(priceFeeds[i].price.conf, selectedAttestations[i].conf);
assertEq(priceFeeds[i].price.publishTime, selectedAttestations[i].publishTime);
assertEq(priceFeeds[i].emaPrice.publishTime, selectedAttestations[i].publishTime);
} else {
assertEq(priceFeeds[i].price.price, selectedAttestations[i].prevPrice);
assertEq(priceFeeds[i].price.conf, selectedAttestations[i].prevConf);
assertEq(priceFeeds[i].price.publishTime, selectedAttestations[i].prevPublishTime);
assertEq(priceFeeds[i].emaPrice.publishTime, selectedAttestations[i].prevPublishTime);
}
}
}
function testParsePriceFeedUpdatesWorksWithOverlappingWithinTimeRangeUpdates() public {
PriceAttestation[] memory attestations = new PriceAttestation[](2);
attestations[0].priceId = bytes32(uint(1));
attestations[0].status = PriceAttestationStatus.Trading;
attestations[0].price = 1000;
attestations[0].publishTime = 10;
attestations[1].priceId = bytes32(uint(1));
attestations[1].status = PriceAttestationStatus.Trading;
attestations[1].price = 2000;
attestations[1].publishTime = 20;
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint(1));
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
20
);
assertEq(priceFeeds.length, 1);
assertEq(priceFeeds[0].id, bytes32(uint(1)));
assertTrue(
(priceFeeds[0].price.price == 1000 && priceFeeds[0].price.publishTime == 10) ||
(priceFeeds[0].price.price == 2000 && priceFeeds[0].price.publishTime == 20)
);
}
function testParsePriceFeedUpdatesWorksWithOverlappingMixedTimeRangeUpdates() public {
PriceAttestation[] memory attestations = new PriceAttestation[](2);
attestations[0].priceId = bytes32(uint(1));
attestations[0].status = PriceAttestationStatus.Trading;
attestations[0].price = 1000;
attestations[0].publishTime = 10;
attestations[1].priceId = bytes32(uint(1));
attestations[1].status = PriceAttestationStatus.Trading;
attestations[1].price = 2000;
attestations[1].publishTime = 20;
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint(1));
PythStructs.PriceFeed[] memory priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
5,
15
);
assertEq(priceFeeds.length, 1);
assertEq(priceFeeds[0].id, bytes32(uint(1)));
assertEq(priceFeeds[0].price.price, 1000);
assertEq(priceFeeds[0].price.publishTime, 10);
priceFeeds =
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
15,
25
);
assertEq(priceFeeds.length, 1);
assertEq(priceFeeds[0].id, bytes32(uint(1)));
assertEq(priceFeeds[0].price.price, 2000);
assertEq(priceFeeds[0].price.publishTime, 20);
}
function testParsePriceFeedUpdatesRevertsIfUpdateFeeIsNotPaid() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
// Since attestations are not empty the fee should be at least 1
assertGe(updateFee, 1);
vm.expectRevert(bytes("insufficient paid fee amount"));
pyth.parsePriceFeedUpdates{value: updateFee-1}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfUpdateVAAIsInvalid(uint seed) public {
setRandSeed(seed);
uint numAttestations = 1 + getRandUint() % 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
uint mutPos = getRandUint() % updateData[0].length;
// mutate the random position by 1 bit
updateData[0][mutPos] = bytes1(uint8(updateData[0][mutPos])^1);
// It might revert due to different wormhole errors
vm.expectRevert();
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfUpdateSourceChainIsInvalid() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
bytes[] memory updateData = new bytes[](1);
updateData[0] = generateVaa(
uint32(block.timestamp),
SOURCE_EMITTER_CHAIN_ID + 1,
SOURCE_EMITTER_ADDRESS,
1, // Sequence
generatePriceFeedUpdatePayload(attestations),
1 // Num signers
);
uint updateFee = pyth.getUpdateFee(updateData);
vm.expectRevert(bytes("invalid data source chain/emitter ID"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfUpdateSourceAddressIsInvalid() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
bytes[] memory updateData = new bytes[](1);
updateData[0] = generateVaa(
uint32(block.timestamp),
SOURCE_EMITTER_CHAIN_ID,
0x00000000000000000000000000000000000000000000000000000000000000aa, // Random wrong source address
1, // Sequence
generatePriceFeedUpdatePayload(attestations),
1 // Num signers
);
uint updateFee = pyth.getUpdateFee(updateData);
vm.expectRevert(bytes("invalid data source chain/emitter ID"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdatesRevertsIfPriceIdNotIncluded() public {
PriceAttestation[] memory attestations = new PriceAttestation[](1);
attestations[0].priceId = bytes32(uint(1));
attestations[0].status = PriceAttestationStatus.Trading;
attestations[0].price = 1000;
attestations[0].publishTime = 10;
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint(2));
vm.expectRevert(bytes("1 or more price feeds are not found in the updateData or they are out of the given time range"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdateRevertsIfPricesOutOfTimeRange() public {
uint numAttestations = 10;
(bytes32[] memory priceIds, PriceAttestation[] memory attestations) =
generateRandomPriceAttestations(numAttestations);
for(uint i = 0; i < numAttestations; i++) {
// Set status to Trading so publishTime is used
attestations[i].status = PriceAttestationStatus.Trading;
attestations[i].publishTime = uint64(100 + getRandUint() % 101); // All between [100, 200]
}
(bytes[] memory updateData, uint updateFee) =
createBatchedUpdateDataFromAttestations(attestations);
// Request for parse within the given time range should work
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
100,
200
);
// Request for parse after the time range should revert.
vm.expectRevert(bytes("1 or more price feeds are not found in the updateData or they are out of the given time range"));
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
300,
MAX_UINT64
);
}
}

View File

@ -53,44 +53,63 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
return address(pyth); return address(pyth);
} }
// Generates byte-encoded payload for the given prices. It sets the emaPrice the same /// Utilities to help generating price attestations and VAAs for them
// as the given price. You can use this to mock wormhole call using `vm.mockCall` and
// return a VM struct with this payload. enum PriceAttestationStatus {
Unknown,
Trading
}
struct PriceAttestation {
bytes32 productId;
bytes32 priceId;
int64 price;
uint64 conf;
int32 expo;
int64 emaPrice;
uint64 emaConf;
PriceAttestationStatus status;
uint32 numPublishers;
uint32 maxNumPublishers;
uint64 attestationTime;
uint64 publishTime;
uint64 prevPublishTime;
int64 prevPrice;
uint64 prevConf;
}
// Generates byte-encoded payload for the given price attestations. You can use this to mock wormhole
// call using `vm.mockCall` and return a VM struct with this payload.
// You can use generatePriceFeedUpdateVAA to generate a VAA for a price update. // You can use generatePriceFeedUpdateVAA to generate a VAA for a price update.
function generatePriceFeedUpdatePayload( function generatePriceFeedUpdatePayload(
bytes32[] memory priceIds, PriceAttestation[] memory attestations
PythStructs.Price[] memory prices ) public pure returns (bytes memory payload) {
) public returns (bytes memory payload) { bytes memory encodedAttestations = new bytes(0);
assertEq(priceIds.length, prices.length);
bytes memory attestations = new bytes(0); for (uint i = 0; i < attestations.length; ++i) {
for (uint i = 0; i < prices.length; ++i) {
// encodePacked uses padding for arrays and we don't want it, so we manually concat them. // encodePacked uses padding for arrays and we don't want it, so we manually concat them.
attestations = abi.encodePacked( encodedAttestations = abi.encodePacked(
attestations, encodedAttestations,
priceIds[i], // Product ID, we use the same price Id. This field is not used. attestations[i].productId,
priceIds[i], // Price ID, attestations[i].priceId,
prices[i].price, // Price attestations[i].price,
prices[i].conf, // Confidence attestations[i].conf,
prices[i].expo, // Exponent attestations[i].expo,
prices[i].price, // EMA price attestations[i].emaPrice,
prices[i].conf // EMA confidence attestations[i].emaConf
); );
// Breaking this in two encodePackes because of the limited EVM stack. // Breaking this in two encodePackes because of the limited EVM stack.
attestations = abi.encodePacked( encodedAttestations = abi.encodePacked(
attestations, encodedAttestations,
uint8(1), // status = 1 = Trading uint8(attestations[i].status),
uint32(5), // Number of publishers. This field is not used. attestations[i].numPublishers,
uint32(10), // Maximum number of publishers. This field is not used. attestations[i].maxNumPublishers,
uint64(prices[i].publishTime), // Attestation time. This field is not used. attestations[i].attestationTime,
uint64(prices[i].publishTime), // Publish time. attestations[i].publishTime,
// Previous values are unused as status is trading. We use the same value attestations[i].prevPublishTime,
// to make sure the test is irrelevant of the logic of which price is chosen. attestations[i].prevPrice,
uint64(prices[i].publishTime), // Previous publish time. attestations[i].prevConf
prices[i].price, // Previous price
prices[i].conf // Previous confidence
); );
} }
@ -100,24 +119,22 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
uint16(0), // Minor version uint16(0), // Minor version
uint16(1), // Header size of 1 byte as it only contains payloadId uint16(1), // Header size of 1 byte as it only contains payloadId
uint8(2), // Payload ID 2 means it's a batch price attestation uint8(2), // Payload ID 2 means it's a batch price attestation
uint16(prices.length), // Number of attestations uint16(attestations.length), // Number of attestations
uint16(attestations.length / prices.length), // Size of a single price attestation. uint16(encodedAttestations.length / attestations.length), // Size of a single price attestation.
attestations encodedAttestations
); );
} }
// Generates a VAA for the given prices. // Generates a VAA for the given attestations.
// This method calls generatePriceFeedUpdatePayload and then creates a VAA with it. // This method calls generatePriceFeedUpdatePayload and then creates a VAA with it.
// The VAAs generated from this method use block timestamp as their timestamp. // The VAAs generated from this method use block timestamp as their timestamp.
function generatePriceFeedUpdateVAA( function generatePriceFeedUpdateVAA(
bytes32[] memory priceIds, PriceAttestation[] memory attestations,
PythStructs.Price[] memory prices,
uint64 sequence, uint64 sequence,
uint8 numSigners uint8 numSigners
) public returns (bytes memory vaa) { ) public returns (bytes memory vaa) {
bytes memory payload = generatePriceFeedUpdatePayload( bytes memory payload = generatePriceFeedUpdatePayload(
priceIds, attestations
prices
); );
vaa = generateVaa( vaa = generateVaa(
@ -129,6 +146,36 @@ abstract contract PythTestUtils is Test, WormholeTestUtils {
numSigners numSigners
); );
} }
function pricesToPriceAttestations(
bytes32[] memory priceIds,
PythStructs.Price[] memory prices
) public returns (PriceAttestation[] memory attestations) {
assertEq(priceIds.length, prices.length);
attestations = new PriceAttestation[](prices.length);
for (uint i = 0; i < prices.length; ++i) {
// Product ID, we use the same price Id. This field is not used.
attestations[i].productId = priceIds[i];
attestations[i].priceId = priceIds[i];
attestations[i].price = prices[i].price;
attestations[i].conf = prices[i].conf;
attestations[i].expo = prices[i].expo;
// Same price and conf is used for emaPrice and emaConf
attestations[i].emaPrice = prices[i].price;
attestations[i].emaConf = prices[i].conf;
attestations[i].status = PriceAttestationStatus.Trading;
attestations[i].numPublishers = 5; // This field is not used
attestations[i].maxNumPublishers = 10; // This field is not used
attestations[i].attestationTime = uint64(prices[i].publishTime); // This field is not used
attestations[i].publishTime = uint64(prices[i].publishTime);
// Fields below are not used when status is Trading. just setting them to
// the same value as the prices.
attestations[i].prevPublishTime = uint64(prices[i].publishTime);
attestations[i].prevPrice = prices[i].price;
attestations[i].prevConf = prices[i].conf;
}
}
} }
contract PythTestUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents { contract PythTestUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvents {
@ -149,8 +196,7 @@ contract PythTestUtilsTest is Test, WormholeTestUtils, PythTestUtils, IPythEvent
); );
bytes memory vaa = generatePriceFeedUpdateVAA( bytes memory vaa = generatePriceFeedUpdateVAA(
priceIds, pricesToPriceAttestations(priceIds, prices),
prices,
1, // Sequence 1, // Sequence
1 // No. Signers 1 // No. Signers
); );

View File

@ -0,0 +1,39 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "forge-std/Test.sol";
contract RandTestUtils is Test {
uint randSeed;
function setRandSeed(uint seed) internal {
randSeed = seed;
}
function getRandBytes32() internal returns (bytes32) {
unchecked { randSeed++; }
return keccak256(abi.encode(randSeed));
}
function getRandUint() internal returns (uint) {
return uint(getRandBytes32());
}
function getRandUint64() internal returns (uint64) {
return uint64(getRandUint());
}
function getRandInt64() internal returns (int64) {
return int64(getRandUint64());
}
function getRandUint32() internal returns (uint32) {
return uint32(getRandUint());
}
function getRandInt32() internal returns (int32) {
return int32(getRandUint32());
}
}

View File

@ -1,7 +1,7 @@
[profile.default] [profile.default]
solc_version = '0.8.4' solc_version = '0.8.4'
optimizer = true optimizer = true
optimizer_runs = 10000 optimizer_runs = 5000
src = 'contracts' src = 'contracts'
# We put the tests into the forge-test directory (instead of test) so that # We put the tests into the forge-test directory (instead of test) so that
# truffle doesn't try to build them # truffle doesn't try to build them

View File

@ -230,7 +230,7 @@ module.exports = {
settings: { settings: {
optimizer: { optimizer: {
enabled: true, enabled: true,
runs: 10000, runs: 5000,
}, },
}, },
}, },