Ethereum/parse price feed updates accumulators (#855)

* feat(target-chains/ethereum): add accumulator support for parsePriceFeedUpdates

* feat(target-chains/ethereum): working impl & test of parsePriceFeedUpdates w/ accumulator data

* refactor(target-contracts/ethereum): refactor pyth accumulator

* refactor: remove console logs & imports

* refactor(target-chain/eth): refactor and more tests

* feat(target-chains/ethereum): address PR feedback

refactor, add parse revert tests

* chore: fix comment

* test(target-chains/ethereum): add/clean up tests

* test: add another test

* test: address more feedback
This commit is contained in:
swimricky 2023-06-07 12:44:47 -07:00 committed by GitHub
parent bdc3fede24
commit 9ddc7fdc7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 864 additions and 169 deletions

View File

@ -80,7 +80,7 @@ abstract contract Pyth is
updateData[i].length > 4 &&
UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC
) {
updatePricesUsingAccumulator(updateData[i]);
updatePriceInfosFromAccumulatorUpdate(updateData[i]);
} else {
updatePriceBatchFromVm(updateData[i]);
}
@ -432,84 +432,131 @@ abstract contract Pyth is
}
priceFeeds = new PythStructs.PriceFeed[](priceIds.length);
for (uint i = 0; i < updateData.length; i++) {
bytes memory encoded;
{
IWormhole.VM memory vm = parseAndVerifyBatchAttestationVM(
updateData[i]
);
encoded = vm.payload;
}
(
uint index,
uint nAttestations,
uint attestationSize
) = parseBatchAttestationHeader(encoded);
// Deserialize each attestation
for (uint j = 0; j < nAttestations; j++) {
// NOTE: We don't advance the global index immediately.
// attestationIndex is an attestation-local offset used
// for readability and easier debugging.
uint attestationIndex = 0;
// Unused bytes32 product id
attestationIndex += 32;
bytes32 priceId = UnsafeBytesLib.toBytes32(
encoded,
index + attestationIndex
);
// Check whether the caller requested for this data.
uint k = 0;
for (; k < priceIds.length; k++) {
if (priceIds[k] == priceId) {
break;
}
}
// If priceFeed[k].id != 0 then it means that there was a valid
// update for priceIds[k] and we don't need to process this one.
if (k == priceIds.length || priceFeeds[k].id != 0) {
index += attestationSize;
continue;
}
if (
updateData[i].length > 4 &&
UnsafeBytesLib.toUint32(updateData[i], 0) ==
ACCUMULATOR_MAGIC
) {
(
PythInternalStructs.PriceInfo memory info,
PythInternalStructs.PriceInfo[]
memory accumulatorPriceInfos,
bytes32[] memory accumulatorPriceIds
) = extractPriceInfosFromAccumulatorUpdate(updateData[i]);
) = parseSingleAttestationFromBatch(
encoded,
index,
attestationSize
for (
uint accDataIdx = 0;
accDataIdx < accumulatorPriceIds.length;
accDataIdx++
) {
bytes32 accumulatorPriceId = accumulatorPriceIds[
accDataIdx
];
// check whether caller requested for this data
uint k = findIndexOfPriceId(
priceIds,
accumulatorPriceId
);
priceFeeds[k].id = priceId;
priceFeeds[k].price.price = info.price;
priceFeeds[k].price.conf = info.conf;
priceFeeds[k].price.expo = info.expo;
priceFeeds[k].price.publishTime = uint(info.publishTime);
priceFeeds[k].emaPrice.price = info.emaPrice;
priceFeeds[k].emaPrice.conf = info.emaConf;
priceFeeds[k].emaPrice.expo = info.expo;
priceFeeds[k].emaPrice.publishTime = uint(info.publishTime);
// If priceFeed[k].id != 0 then it means that there was a valid
// update for priceIds[k] and we don't need to process this one.
if (k == priceIds.length || priceFeeds[k].id != 0) {
continue;
}
// Check the publish time of the price is within the given range
// if it is not, then set the id to 0 to indicate that this price id
// still does not have a valid price feed. This will allow other updates
// for this price id to be processed.
if (
priceFeeds[k].price.publishTime < minPublishTime ||
priceFeeds[k].price.publishTime > maxPublishTime
) {
priceFeeds[k].id = 0;
PythInternalStructs.PriceInfo
memory info = accumulatorPriceInfos[accDataIdx];
uint publishTime = uint(info.publishTime);
// Check the publish time of the price is within the given range
// and only fill the priceFeedsInfo if it is.
// If is not, default id value of 0 will still be set and
// this will allow other updates for this price id to be processed.
if (
publishTime >= minPublishTime &&
publishTime <= maxPublishTime
) {
fillPriceFeedFromPriceInfo(
priceFeeds,
k,
accumulatorPriceId,
info,
publishTime
);
}
}
} else {
bytes memory encoded;
{
IWormhole.VM
memory vm = parseAndVerifyBatchAttestationVM(
updateData[i]
);
encoded = vm.payload;
}
index += attestationSize;
/** Batch price logic */
// TODO: gas optimization
(
uint index,
uint nAttestations,
uint attestationSize
) = parseBatchAttestationHeader(encoded);
// Deserialize each attestation
for (uint j = 0; j < nAttestations; j++) {
// NOTE: We don't advance the global index immediately.
// attestationIndex is an attestation-local offset used
// for readability and easier debugging.
uint attestationIndex = 0;
// Unused bytes32 product id
attestationIndex += 32;
bytes32 priceId = UnsafeBytesLib.toBytes32(
encoded,
index + attestationIndex
);
// check whether caller requested for this data
uint k = findIndexOfPriceId(priceIds, priceId);
// If priceFeed[k].id != 0 then it means that there was a valid
// update for priceIds[k] and we don't need to process this one.
if (k == priceIds.length || priceFeeds[k].id != 0) {
index += attestationSize;
continue;
}
(
PythInternalStructs.PriceInfo memory info,
) = parseSingleAttestationFromBatch(
encoded,
index,
attestationSize
);
uint publishTime = uint(info.publishTime);
// Check the publish time of the price is within the given range
// and only fill the priceFeedsInfo if it is.
// If is not, default id value of 0 will still be set and
// this will allow other updates for this price id to be processed.
if (
publishTime >= minPublishTime &&
publishTime <= maxPublishTime
) {
fillPriceFeedFromPriceInfo(
priceFeeds,
k,
priceId,
info,
publishTime
);
}
index += attestationSize;
}
}
}
@ -521,6 +568,38 @@ abstract contract Pyth is
}
}
function findIndexOfPriceId(
bytes32[] calldata priceIds,
bytes32 targetPriceId
) private pure returns (uint index) {
uint k = 0;
uint len = priceIds.length;
for (; k < len; k++) {
if (priceIds[k] == targetPriceId) {
break;
}
}
return k;
}
function fillPriceFeedFromPriceInfo(
PythStructs.PriceFeed[] memory priceFeeds,
uint k,
bytes32 priceId,
PythInternalStructs.PriceInfo memory info,
uint publishTime
) private pure {
priceFeeds[k].id = priceId;
priceFeeds[k].price.price = info.price;
priceFeeds[k].price.conf = info.conf;
priceFeeds[k].price.expo = info.expo;
priceFeeds[k].price.publishTime = publishTime;
priceFeeds[k].emaPrice.price = info.emaPrice;
priceFeeds[k].emaPrice.conf = info.emaConf;
priceFeeds[k].emaPrice.expo = info.expo;
priceFeeds[k].emaPrice.publishTime = publishTime;
}
function queryPriceFeed(
bytes32 id
) public view override returns (PythStructs.PriceFeed memory priceFeed) {

View File

@ -42,11 +42,38 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
revert PythErrors.InvalidUpdateDataSource();
}
function updatePricesUsingAccumulator(
bytes calldata accumulatorUpdate
) internal {
function extractPriceInfosFromAccumulatorUpdate(
bytes memory accumulatorUpdate
)
internal
view
returns (
PythInternalStructs.PriceInfo[] memory priceInfos,
bytes32[] memory priceIds
)
{
(
uint offset,
UpdateType updateType
) = extractUpdateTypeFromAccumulatorHeader(accumulatorUpdate);
if (updateType != UpdateType.WormholeMerkle) {
revert PythErrors.InvalidUpdateData();
}
(priceInfos, priceIds) = extractPriceInfosFromWormholeMerkle(
UnsafeBytesLib.slice(
accumulatorUpdate,
offset,
accumulatorUpdate.length - offset
)
);
}
function extractUpdateTypeFromAccumulatorHeader(
bytes memory accumulatorUpdate
) internal pure returns (uint offset, UpdateType updateType) {
unchecked {
uint offset = 0;
offset = 0;
{
uint32 magic = UnsafeBytesLib.toUint32(
@ -97,37 +124,56 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
offset += trailingHeaderSize;
}
UpdateType updateType = UpdateType(
updateType = UpdateType(
UnsafeBytesLib.toUint8(accumulatorUpdate, offset)
);
offset += 1;
if (accumulatorUpdate.length < offset)
revert PythErrors.InvalidUpdateData();
if (updateType == UpdateType.WormholeMerkle) {
updatePricesUsingWormholeMerkle(
UnsafeBytesLib.slice(
accumulatorUpdate,
offset,
accumulatorUpdate.length - offset
)
);
} else {
revert PythErrors.InvalidUpdateData();
}
}
}
function updatePricesUsingWormholeMerkle(bytes memory encoded) private {
function extractPriceInfosFromWormholeMerkle(
bytes memory encoded
)
internal
view
returns (
PythInternalStructs.PriceInfo[] memory priceInfos,
bytes32[] memory priceIds
)
{
unchecked {
uint offset = 0;
(
uint offset,
bytes20 digest,
uint8 numUpdates
) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded);
priceInfos = new PythInternalStructs.PriceInfo[](numUpdates);
priceIds = new bytes32[](numUpdates);
for (uint i = 0; i < numUpdates; i++) {
(
offset,
priceInfos[i],
priceIds[i]
) = extractPriceFeedFromMerkleProof(digest, encoded, offset);
}
if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
}
}
function extractWormholeMerkleHeaderDigestAndNumUpdates(
bytes memory encoded
) internal view returns (uint offset, bytes20 digest, uint8 numUpdates) {
unchecked {
offset = 0;
uint16 whProofSize = UnsafeBytesLib.toUint16(encoded, offset);
offset += 2;
bytes20 digest;
{
IWormhole.VM memory vm = parseAndVerifyPythVM(
UnsafeBytesLib.slice(encoded, offset, whProofSize)
@ -138,94 +184,135 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
// emit AccumulatorUpdate(vm.chainId, vm.sequence);
bytes memory encodedPayload = vm.payload;
uint payloadoffset = 0;
uint payloadOffset = 0;
{
uint32 magic = UnsafeBytesLib.toUint32(
encodedPayload,
payloadoffset
payloadOffset
);
payloadoffset += 4;
payloadOffset += 4;
if (magic != ACCUMULATOR_WORMHOLE_MAGIC)
revert PythErrors.InvalidUpdateData();
UpdateType updateType = UpdateType(
UnsafeBytesLib.toUint8(encodedPayload, payloadoffset)
UnsafeBytesLib.toUint8(encodedPayload, payloadOffset)
);
payloadoffset += 1;
payloadOffset += 1;
if (updateType != UpdateType.WormholeMerkle)
revert PythErrors.InvalidUpdateData();
// This field is not used
// uint64 slot = UnsafeBytesLib.toUint64(encodedPayload, payloadoffset);
payloadoffset += 8;
payloadOffset += 8;
// This field is not used
// uint32 ringSize = UnsafeBytesLib.toUint32(encodedPayload, payloadoffset);
payloadoffset += 4;
payloadOffset += 4;
digest = bytes20(
UnsafeBytesLib.toAddress(encodedPayload, payloadoffset)
UnsafeBytesLib.toAddress(encodedPayload, payloadOffset)
);
payloadoffset += 20;
payloadOffset += 20;
// We don't check equality to enable future compatibility.
if (payloadoffset > encodedPayload.length)
if (payloadOffset > encodedPayload.length)
revert PythErrors.InvalidUpdateData();
}
}
uint8 numUpdates = UnsafeBytesLib.toUint8(encoded, offset);
numUpdates = UnsafeBytesLib.toUint8(encoded, offset);
offset += 1;
for (uint i = 0; i < numUpdates; i++) {
offset = verifyAndUpdatePriceFeedFromMerkleProof(
digest,
encoded,
offset
);
}
if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
}
}
function verifyAndUpdatePriceFeedFromMerkleProof(
function extractPriceFeedFromMerkleProof(
bytes20 digest,
bytes memory encoded,
uint offset
) private returns (uint endOffset) {
)
private
pure
returns (
uint endOffset,
PythInternalStructs.PriceInfo memory priceInfo,
bytes32 priceId
)
{
unchecked {
uint16 messageSize = UnsafeBytesLib.toUint16(encoded, offset);
bytes memory encodedMessage;
(endOffset, encodedMessage) = extractMessageFromProof(
encoded,
offset,
digest
);
(priceInfo, priceId) = extractPriceFeedMessage(encodedMessage);
return (endOffset, priceInfo, priceId);
}
}
function extractMessageFromProof(
bytes memory encodedProof,
uint offset,
bytes20 merkleRoot
) private pure returns (uint endOffset, bytes memory encodedMessage) {
unchecked {
uint16 messageSize = UnsafeBytesLib.toUint16(encodedProof, offset);
offset += 2;
bytes memory encodedMessage = UnsafeBytesLib.slice(
encoded,
encodedMessage = UnsafeBytesLib.slice(
encodedProof,
offset,
messageSize
);
offset += messageSize;
bool valid;
(valid, offset) = MerkleTree.isProofValid(
encoded,
(valid, endOffset) = MerkleTree.isProofValid(
encodedProof,
offset,
digest,
merkleRoot,
encodedMessage
);
if (!valid) {
revert PythErrors.InvalidUpdateData();
}
parseAndProcessMessage(encodedMessage);
return offset;
}
}
function extractPriceFeedMessage(
bytes memory encodedMessage
)
private
pure
returns (PythInternalStructs.PriceInfo memory info, bytes32 priceId)
{
unchecked {
MessageType messageType = getMessageType(encodedMessage);
if (messageType == MessageType.PriceFeed) {
(info, priceId) = parsePriceFeedMessage(
UnsafeBytesLib.slice(
encodedMessage,
1,
encodedMessage.length - 1
)
);
} else {
revert PythErrors.InvalidUpdateData();
}
}
}
function getMessageType(
bytes memory encodedMessage
) private pure returns (MessageType messageType) {
return MessageType(UnsafeBytesLib.toUint8(encodedMessage, 0));
}
function parsePriceFeedMessage(
bytes memory encodedPriceFeed
)
@ -286,38 +373,76 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
}
}
function parseAndProcessMessage(bytes memory encodedMessage) private {
function updatePriceInfosFromAccumulatorUpdate(
bytes calldata accumulatorUpdate
) internal {
(
uint offset,
UpdateType updateType
) = extractUpdateTypeFromAccumulatorHeader(accumulatorUpdate);
if (updateType != UpdateType.WormholeMerkle) {
revert PythErrors.InvalidUpdateData();
}
updatePriceInfosFromWormholeMerkle(
UnsafeBytesLib.slice(
accumulatorUpdate,
offset,
accumulatorUpdate.length - offset
)
);
}
function updatePriceInfosFromWormholeMerkle(bytes memory encoded) private {
unchecked {
MessageType messageType = MessageType(
UnsafeBytesLib.toUint8(encodedMessage, 0)
);
(
uint offset,
bytes20 digest,
uint8 numUpdates
) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded);
if (messageType == MessageType.PriceFeed) {
(
PythInternalStructs.PriceInfo memory info,
bytes32 priceId
) = parsePriceFeedMessage(
UnsafeBytesLib.slice(
encodedMessage,
1,
encodedMessage.length - 1
)
);
uint64 latestPublishTime = latestPriceInfoPublishTime(priceId);
if (info.publishTime > latestPublishTime) {
setLatestPriceInfo(priceId, info);
emit PriceFeedUpdate(
priceId,
info.publishTime,
info.price,
info.conf
);
}
} else {
revert PythErrors.InvalidUpdateData();
for (uint i = 0; i < numUpdates; i++) {
offset = verifyAndUpdatePriceFeedFromMerkleProof(
digest,
encoded,
offset
);
}
if (offset != encoded.length) revert PythErrors.InvalidUpdateData();
}
}
function verifyAndUpdatePriceFeedFromMerkleProof(
bytes20 digest,
bytes memory encoded,
uint offset
) private returns (uint endOffset) {
PythInternalStructs.PriceInfo memory priceInfo;
bytes32 priceId;
(offset, priceInfo, priceId) = extractPriceFeedFromMerkleProof(
digest,
encoded,
offset
);
processMessage(priceInfo, priceId);
return offset;
}
function processMessage(
PythInternalStructs.PriceInfo memory info,
bytes32 priceId
) private {
uint64 latestPublishTime = latestPriceInfoPublishTime(priceId);
if (info.publishTime > latestPublishTime) {
setLatestPriceInfo(priceId, info);
emit PriceFeedUpdate(
priceId,
info.publishTime,
info.price,
info.conf
);
}
}
}

View File

@ -22,6 +22,9 @@ contract PythWormholeMerkleAccumulatorTest is
{
IPyth public pyth;
// -1 is equal to 0xffffff which is the biggest uint if converted back
uint64 constant MAX_UINT64 = uint64(int64(-1));
function setUp() public {
pyth = IPyth(setUpPyth(setUpWormhole(1)));
}
@ -46,6 +49,42 @@ contract PythWormholeMerkleAccumulatorTest is
assertEq(emaPrice.publishTime, priceFeedMessage.publishTime);
}
function assertParsedPriceFeedEqualsMessage(
PythStructs.PriceFeed memory priceFeed,
PriceFeedMessage memory priceFeedMessage,
bytes32 priceId
) internal {
assertEq(priceFeed.id, priceId);
assertEq(priceFeed.price.price, priceFeedMessage.price);
assertEq(priceFeed.price.conf, priceFeedMessage.conf);
assertEq(priceFeed.price.expo, priceFeedMessage.expo);
assertEq(priceFeed.price.publishTime, priceFeedMessage.publishTime);
assertEq(priceFeed.emaPrice.price, priceFeedMessage.emaPrice);
assertEq(priceFeed.emaPrice.conf, priceFeedMessage.emaConf);
assertEq(priceFeed.emaPrice.expo, priceFeedMessage.expo);
assertEq(priceFeed.emaPrice.publishTime, priceFeedMessage.publishTime);
}
function assertParsedPriceFeedStored(
PythStructs.PriceFeed memory priceFeed
) internal {
PythStructs.Price memory aggregatePrice = pyth.getPriceUnsafe(
priceFeed.id
);
assertEq(aggregatePrice.price, priceFeed.price.price);
assertEq(aggregatePrice.conf, priceFeed.price.conf);
assertEq(aggregatePrice.expo, priceFeed.price.expo);
assertEq(aggregatePrice.publishTime, priceFeed.price.publishTime);
PythStructs.Price memory emaPrice = pyth.getEmaPriceUnsafe(
priceFeed.id
);
assertEq(emaPrice.price, priceFeed.emaPrice.price);
assertEq(emaPrice.conf, priceFeed.emaPrice.conf);
assertEq(emaPrice.expo, priceFeed.emaPrice.expo);
assertEq(emaPrice.publishTime, priceFeed.emaPrice.publishTime);
}
function generateRandomPriceFeedMessage(
uint numPriceFeeds
) internal returns (PriceFeedMessage[] memory priceFeedMessages) {
@ -162,11 +201,62 @@ contract PythWormholeMerkleAccumulatorTest is
uint updateFee = pyth.getUpdateFee(updateData);
bytes32[] memory priceIds = new bytes32[](3);
priceIds[0] = priceFeedMessages1[0].priceId;
priceIds[1] = priceFeedMessages1[1].priceId;
priceIds[2] = priceFeedMessages2[0].priceId;
// parse price feeds before updating since parsing price feeds should be independent
// of whatever is currently stored in the contract.
PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
value: updateFee
}(updateData, priceIds, 0, MAX_UINT64);
PriceFeedMessage[]
memory expectedPriceFeedMessages = new PriceFeedMessage[](3);
// Only the first occurrence of a valid priceFeedMessage for a paritcular priceFeed.id
// within an updateData will be parsed which is why we exclude priceFeedMessages2[1]
// since it has the same priceFeed.id as priceFeedMessages1[0] even though it has a later publishTime.
// This is different than how updatePriceFeed behaves which will always update using the data
// of the priceFeedMessage with the latest publishTime for a particular priceFeed.id
expectedPriceFeedMessages[0] = priceFeedMessages1[0];
expectedPriceFeedMessages[1] = priceFeedMessages1[1];
expectedPriceFeedMessages[2] = priceFeedMessages2[0];
for (uint i = 0; i < expectedPriceFeedMessages.length; i++) {
assertParsedPriceFeedEqualsMessage(
priceFeeds[i],
expectedPriceFeedMessages[i],
priceIds[i]
);
}
// parse updateData[1] for priceFeedMessages1[0].priceId since this has the latest publishTime
// for that priceId and should be the one that is stored.
bytes32[] memory priceIds1 = new bytes32[](1);
priceIds1[0] = priceFeedMessages1[0].priceId;
bytes[] memory parseUpdateDataInput1 = new bytes[](1);
parseUpdateDataInput1[0] = updateData[1];
PythStructs.PriceFeed[] memory priceFeeds1 = pyth.parsePriceFeedUpdates{
value: updateFee
}(parseUpdateDataInput1, priceIds1, 0, MAX_UINT64);
pyth.updatePriceFeeds{value: updateFee}(updateData);
// check stored price feed information matches updateData
assertPriceFeedMessageStored(priceFeedMessages1[1]);
assertPriceFeedMessageStored(priceFeedMessages2[0]);
assertPriceFeedMessageStored(priceFeedMessages2[1]);
PythStructs.PriceFeed[]
memory expectedPriceFeeds = new PythStructs.PriceFeed[](3);
expectedPriceFeeds[0] = priceFeeds1[0];
expectedPriceFeeds[1] = priceFeeds[1];
expectedPriceFeeds[2] = priceFeeds[2];
// check stored price feed information matches parsed price feeds
for (uint i = 0; i < expectedPriceFeeds.length; i++) {
assertParsedPriceFeedStored(expectedPriceFeeds[i]);
}
}
function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateSingleCall()
@ -205,6 +295,28 @@ contract PythWormholeMerkleAccumulatorTest is
pyth.updatePriceFeeds{value: updateFee}(updateData);
assertPriceFeedMessageStored(priceFeedMessages1[0]);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = priceFeedMessages1[0].priceId;
PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
value: updateFee
}(updateData, priceIds, 0, MAX_UINT64);
assertEq(priceFeeds.length, 1);
assertParsedPriceFeedStored(priceFeeds[0]);
// parsePriceFeedUpdates should return the first priceFeed in the case
// that the updateData contains multiple feeds with the same id.
// Swap the order of updates in updateData to verify that the other priceFeed is returned
bytes[] memory updateData1 = new bytes[](2);
updateData1[0] = updateData[1];
updateData1[1] = updateData[0];
PythStructs.PriceFeed[] memory priceFeeds1 = pyth.parsePriceFeedUpdates{
value: updateFee
}(updateData1, priceIds, 0, MAX_UINT64);
assertEq(priceFeeds1.length, 1);
assertEq(priceFeeds1[0].price.publishTime, 5);
}
function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateMultiCall()
@ -227,7 +339,6 @@ contract PythWormholeMerkleAccumulatorTest is
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages1);
pyth.updatePriceFeeds{value: updateFee}(updateData);
assertPriceFeedMessageStored(priceFeedMessages1[0]);
(updateData, updateFee) = createWormholeMerkleUpdateData(
priceFeedMessages2
@ -237,6 +348,86 @@ contract PythWormholeMerkleAccumulatorTest is
assertPriceFeedMessageStored(priceFeedMessages1[0]);
}
function testParsePriceFeedUpdatesWithWormholeMerklWorksWithOurOfOrderUpdateMultiCall()
public
{
PriceFeedMessage[]
memory priceFeedMessages1 = generateRandomPriceFeedMessage(1);
PriceFeedMessage[]
memory priceFeedMessages2 = generateRandomPriceFeedMessage(1);
// Make the price ids the same
priceFeedMessages2[0].priceId = priceFeedMessages1[0].priceId;
// Adjust the timestamps so the second timestamp is smaller than the first
// Parse should work regardless of what's stored on chain.
priceFeedMessages1[0].publishTime = 10;
priceFeedMessages2[0].publishTime = 5;
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages1);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = priceFeedMessages1[0].priceId;
PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
value: updateFee
}(updateData, priceIds, 0, MAX_UINT64);
// Parse should always return the same value regardless of what's stored on chain.
assertEq(priceFeeds.length, 1);
assertParsedPriceFeedEqualsMessage(
priceFeeds[0],
priceFeedMessages1[0],
priceIds[0]
);
pyth.updatePriceFeeds{value: updateFee}(updateData);
priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
assertEq(priceFeeds.length, 1);
assertParsedPriceFeedEqualsMessage(
priceFeeds[0],
priceFeedMessages1[0],
priceIds[0]
);
(
bytes[] memory updateData1,
uint updateFee1
) = createWormholeMerkleUpdateData(priceFeedMessages2);
pyth.updatePriceFeeds{value: updateFee1}(updateData1);
// reparse the original updateData should still return the same thing
priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
assertEq(priceFeeds.length, 1);
assertParsedPriceFeedEqualsMessage(
priceFeeds[0],
priceFeedMessages1[0],
priceIds[0]
);
// parsing the second message should return the data based on the second messagef
priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee1}(
updateData1,
priceIds,
0,
MAX_UINT64
);
assertEq(priceFeeds.length, 1);
assertParsedPriceFeedEqualsMessage(
priceFeeds[0],
priceFeedMessages2[0],
priceIds[0]
);
}
function isNotMatch(
bytes memory a,
bytes memory b
@ -249,12 +440,23 @@ contract PythWormholeMerkleAccumulatorTest is
/// expected value, that item will be forged to be invalid.
function createAndForgeWormholeMerkleUpdateData(
bytes memory forgeItem
) public returns (bytes[] memory updateData, uint updateFee) {
)
public
returns (
bytes[] memory updateData,
uint updateFee,
bytes32[] memory priceIds
)
{
uint numPriceFeeds = 10;
PriceFeedMessage[]
memory priceFeedMessages = generateRandomPriceFeedMessage(
numPriceFeeds
);
priceIds = new bytes32[](numPriceFeeds);
for (uint i = 0; i < numPriceFeeds; i++) {
priceIds[i] = priceFeedMessages[i].priceId;
}
bytes[] memory encodedPriceFeedMessages = encodePriceFeedMessages(
priceFeedMessages
@ -327,11 +529,21 @@ contract PythWormholeMerkleAccumulatorTest is
// In this test the Wormhole accumulator magic is wrong and the update gets reverted.
(
bytes[] memory updateData,
uint updateFee
uint updateFee,
bytes32[] memory priceIds
) = createAndForgeWormholeMerkleUpdateData("whMagic");
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAAPayloadUpdateType()
@ -342,11 +554,20 @@ contract PythWormholeMerkleAccumulatorTest is
(
bytes[] memory updateData,
uint updateFee
uint updateFee,
bytes32[] memory priceIds
) = createAndForgeWormholeMerkleUpdateData("whUpdateType");
vm.expectRevert(); // Reason: Conversion into non-existent enum type. However it
// was not possible to check the revert reason in the test.
pyth.updatePriceFeeds{value: updateFee}(updateData);
vm.expectRevert();
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAASource()
@ -355,15 +576,34 @@ contract PythWormholeMerkleAccumulatorTest is
// In this test the Wormhole message source is wrong.
(
bytes[] memory updateData,
uint updateFee
uint updateFee,
bytes32[] memory priceIds
) = createAndForgeWormholeMerkleUpdateData("whSourceAddress");
vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
(updateData, updateFee) = createAndForgeWormholeMerkleUpdateData(
"whSourceChain"
vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
(
updateData,
updateFee,
priceIds
) = createAndForgeWormholeMerkleUpdateData("whSourceChain");
vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongRootDigest()
@ -372,10 +612,19 @@ contract PythWormholeMerkleAccumulatorTest is
// In this test the Wormhole merkle proof digest is wrong
(
bytes[] memory updateData,
uint updateFee
uint updateFee,
bytes32[] memory priceIds
) = createAndForgeWormholeMerkleUpdateData("rootDigest");
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongProofItem()
@ -384,10 +633,19 @@ contract PythWormholeMerkleAccumulatorTest is
// In this test all Wormhole merkle proof items are the first item proof
(
bytes[] memory updateData,
uint updateFee
uint updateFee,
bytes32[] memory priceIds
) = createAndForgeWormholeMerkleUpdateData("proofItem");
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongHeader()
@ -396,17 +654,35 @@ contract PythWormholeMerkleAccumulatorTest is
// In this test the message headers are wrong
(
bytes[] memory updateData,
uint updateFee
uint updateFee,
bytes32[] memory priceIds
) = createAndForgeWormholeMerkleUpdateData("headerMagic");
vm.expectRevert(); // The revert reason is not deterministic because when it doesn't match it goes through
// the old approach.
pyth.updatePriceFeeds{value: updateFee}(updateData);
(updateData, updateFee) = createAndForgeWormholeMerkleUpdateData(
"headerMajorVersion"
vm.expectRevert();
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
(
updateData,
updateFee,
priceIds
) = createAndForgeWormholeMerkleUpdateData("headerMajorVersion");
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.updatePriceFeeds{value: updateFee}(updateData);
vm.expectRevert(PythErrors.InvalidUpdateData.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testUpdatePriceFeedWithWormholeMerkleRevertsIfUpdateFeeIsNotPaid()
@ -421,7 +697,222 @@ contract PythWormholeMerkleAccumulatorTest is
priceFeedMessages
);
bytes32[] memory priceIds = new bytes32[](numPriceFeeds);
for (uint i = 0; i < numPriceFeeds; i++) {
priceIds[i] = priceFeedMessages[i].priceId;
}
vm.expectRevert(PythErrors.InsufficientFee.selector);
pyth.updatePriceFeeds{value: 0}(updateData);
vm.expectRevert(PythErrors.InsufficientFee.selector);
pyth.parsePriceFeedUpdates{value: 0}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedWithWormholeMerkleWorks(uint seed) public {
setRandSeed(seed);
uint numPriceFeeds = (getRandUint() % 10) + 1;
PriceFeedMessage[]
memory priceFeedMessages = generateRandomPriceFeedMessage(
numPriceFeeds
);
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
bytes32[] memory priceIds = new bytes32[](numPriceFeeds);
for (uint i = 0; i < numPriceFeeds; i++) {
priceIds[i] = priceFeedMessages[i].priceId;
}
PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
value: updateFee
}(updateData, priceIds, 0, MAX_UINT64);
for (uint i = 0; i < priceFeeds.length; i++) {
assertParsedPriceFeedEqualsMessage(
priceFeeds[i],
priceFeedMessages[i],
priceIds[i]
);
}
// update priceFeedMessages
for (uint i = 0; i < numPriceFeeds; i++) {
priceFeedMessages[i].price = getRandInt64();
priceFeedMessages[i].conf = getRandUint64();
priceFeedMessages[i].expo = getRandInt32();
priceFeedMessages[i].publishTime = getRandUint64();
priceFeedMessages[i].emaPrice = getRandInt64();
priceFeedMessages[i].emaConf = getRandUint64();
}
(updateData, updateFee) = createWormholeMerkleUpdateData(
priceFeedMessages
);
// reparse
priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
for (uint i = 0; i < priceFeeds.length; i++) {
assertParsedPriceFeedEqualsMessage(
priceFeeds[i],
priceFeedMessages[i],
priceIds[i]
);
}
}
function testParsePriceFeedWithWormholeMerkleWorksRandomDistinctUpdatesInput(
uint seed
) public {
setRandSeed(seed);
uint numPriceFeeds = (getRandUint() % 10) + 1;
PriceFeedMessage[]
memory priceFeedMessages = generateRandomPriceFeedMessage(
numPriceFeeds
);
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
bytes32[] memory priceIds = new bytes32[](numPriceFeeds);
for (uint i = 0; i < numPriceFeeds; i++) {
priceIds[i] = priceFeedMessages[i].priceId;
}
// Shuffle the priceFeedMessages
for (uint i = 1; i < numPriceFeeds; i++) {
uint swapWith = getRandUint() % (i + 1);
(priceFeedMessages[i], priceFeedMessages[swapWith]) = (
priceFeedMessages[swapWith],
priceFeedMessages[i]
);
(priceIds[i], priceIds[swapWith]) = (
priceIds[swapWith],
priceIds[i]
);
}
// Select only first numSelectedPriceFeeds. numSelectedPriceFeeds will be in [0, numPriceFeeds]
uint numSelectedPriceFeeds = getRandUint() % (numPriceFeeds + 1);
PriceFeedMessage[]
memory selectedPriceFeedsMessages = new PriceFeedMessage[](
numSelectedPriceFeeds
);
bytes32[] memory selectedPriceIds = new bytes32[](
numSelectedPriceFeeds
);
for (uint i = 0; i < numSelectedPriceFeeds; i++) {
selectedPriceFeedsMessages[i] = priceFeedMessages[i];
selectedPriceIds[i] = priceIds[i];
}
PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{
value: updateFee
}(updateData, selectedPriceIds, 0, MAX_UINT64);
for (uint i = 0; i < numSelectedPriceFeeds; i++) {
assertParsedPriceFeedEqualsMessage(
priceFeeds[i],
selectedPriceFeedsMessages[i],
selectedPriceIds[i]
);
}
}
function testParsePriceFeedWithWormholeMerkleRevertsIfPriceIdNotIncluded()
public
{
PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
priceFeedMessages[0] = PriceFeedMessage({
priceId: bytes32(uint(1)),
price: getRandInt64(),
conf: getRandUint64(),
expo: getRandInt32(),
publishTime: getRandUint64(),
prevPublishTime: getRandUint64(),
emaPrice: getRandInt64(),
emaConf: getRandUint64()
});
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
bytes32[] memory priceIds = new bytes32[](1);
priceIds[0] = bytes32(uint(2));
vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
MAX_UINT64
);
}
function testParsePriceFeedUpdateRevertsIfPricesOutOfTimeRange() public {
uint numPriceFeeds = (getRandUint() % 10) + 1;
PriceFeedMessage[]
memory priceFeedMessages = generateRandomPriceFeedMessage(
numPriceFeeds
);
for (uint i = 0; i < numPriceFeeds; i++) {
priceFeedMessages[i].publishTime = uint64(
100 + (getRandUint() % 101)
); // All between [100, 200]
}
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
bytes32[] memory priceIds = new bytes32[](numPriceFeeds);
for (uint i = 0; i < numPriceFeeds; i++) {
priceIds[i] = priceFeedMessages[i].priceId;
}
// Request for parse within the given time range should work
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
100,
200
);
// Request for parse after the time range should revert.
vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
300,
MAX_UINT64
);
// Request for parse before the time range should revert.
vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector);
pyth.parsePriceFeedUpdates{value: updateFee}(
updateData,
priceIds,
0,
99
);
}
//TODO: add some tests of forward compatibility.
// I.e., create a message where each part that can be expanded in size is expanded and make sure that parsing still works
}

View File

@ -15,7 +15,7 @@ import "./utils/RandTestUtils.t.sol";
contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils {
IPyth public pyth;
// -1 is equal to 0x111111 which is the biggest uint if converted back
// -1 is equal to 0xffffff which is the biggest uint if converted back
uint64 constant MAX_UINT64 = uint64(int64(-1));
function setUp() public {