From 9ddc7fdc7d16a35e36055f9db32d9e55604aa41c Mon Sep 17 00:00:00 2001 From: swimricky <86628128+swimricky@users.noreply.github.com> Date: Wed, 7 Jun 2023 12:44:47 -0700 Subject: [PATCH] Ethereum/parse price feed updates accumulators (#855) * feat(target-chains/ethereum): add accumulator support for parsePriceFeedUpdates * feat(target-chains/ethereum): working impl & test of parsePriceFeedUpdates w/ accumulator data * refactor(target-contracts/ethereum): refactor pyth accumulator * refactor: remove console logs & imports * refactor(target-chain/eth): refactor and more tests * feat(target-chains/ethereum): address PR feedback refactor, add parse revert tests * chore: fix comment * test(target-chains/ethereum): add/clean up tests * test: add another test * test: address more feedback --- .../contracts/contracts/pyth/Pyth.sol | 221 +++++--- .../contracts/pyth/PythAccumulator.sol | 293 +++++++--- .../Pyth.WormholeMerkleAccumulator.t.sol | 517 +++++++++++++++++- .../ethereum/contracts/forge-test/Pyth.t.sol | 2 +- 4 files changed, 864 insertions(+), 169 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol b/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol index 06a07009..8d5086de 100644 --- a/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol +++ b/target_chains/ethereum/contracts/contracts/pyth/Pyth.sol @@ -80,7 +80,7 @@ abstract contract Pyth is updateData[i].length > 4 && UnsafeBytesLib.toUint32(updateData[i], 0) == ACCUMULATOR_MAGIC ) { - updatePricesUsingAccumulator(updateData[i]); + updatePriceInfosFromAccumulatorUpdate(updateData[i]); } else { updatePriceBatchFromVm(updateData[i]); } @@ -432,84 +432,131 @@ abstract contract Pyth is } priceFeeds = new PythStructs.PriceFeed[](priceIds.length); - for (uint i = 0; i < updateData.length; i++) { - bytes memory encoded; - - { - IWormhole.VM memory vm = parseAndVerifyBatchAttestationVM( - updateData[i] - ); - encoded = vm.payload; - } - - ( - uint index, - uint nAttestations, - uint attestationSize - ) = parseBatchAttestationHeader(encoded); - - // Deserialize each attestation - for (uint j = 0; j < nAttestations; j++) { - // NOTE: We don't advance the global index immediately. - // attestationIndex is an attestation-local offset used - // for readability and easier debugging. - uint attestationIndex = 0; - - // Unused bytes32 product id - attestationIndex += 32; - - bytes32 priceId = UnsafeBytesLib.toBytes32( - encoded, - index + attestationIndex - ); - - // Check whether the caller requested for this data. - uint k = 0; - for (; k < priceIds.length; k++) { - if (priceIds[k] == priceId) { - break; - } - } - - // If priceFeed[k].id != 0 then it means that there was a valid - // update for priceIds[k] and we don't need to process this one. - if (k == priceIds.length || priceFeeds[k].id != 0) { - index += attestationSize; - continue; - } - + if ( + updateData[i].length > 4 && + UnsafeBytesLib.toUint32(updateData[i], 0) == + ACCUMULATOR_MAGIC + ) { ( - PythInternalStructs.PriceInfo memory info, + PythInternalStructs.PriceInfo[] + memory accumulatorPriceInfos, + bytes32[] memory accumulatorPriceIds + ) = extractPriceInfosFromAccumulatorUpdate(updateData[i]); - ) = parseSingleAttestationFromBatch( - encoded, - index, - attestationSize + for ( + uint accDataIdx = 0; + accDataIdx < accumulatorPriceIds.length; + accDataIdx++ + ) { + bytes32 accumulatorPriceId = accumulatorPriceIds[ + accDataIdx + ]; + // check whether caller requested for this data + uint k = findIndexOfPriceId( + priceIds, + accumulatorPriceId ); - priceFeeds[k].id = priceId; - priceFeeds[k].price.price = info.price; - priceFeeds[k].price.conf = info.conf; - priceFeeds[k].price.expo = info.expo; - priceFeeds[k].price.publishTime = uint(info.publishTime); - priceFeeds[k].emaPrice.price = info.emaPrice; - priceFeeds[k].emaPrice.conf = info.emaConf; - priceFeeds[k].emaPrice.expo = info.expo; - priceFeeds[k].emaPrice.publishTime = uint(info.publishTime); + // If priceFeed[k].id != 0 then it means that there was a valid + // update for priceIds[k] and we don't need to process this one. + if (k == priceIds.length || priceFeeds[k].id != 0) { + continue; + } - // Check the publish time of the price is within the given range - // if it is not, then set the id to 0 to indicate that this price id - // still does not have a valid price feed. This will allow other updates - // for this price id to be processed. - if ( - priceFeeds[k].price.publishTime < minPublishTime || - priceFeeds[k].price.publishTime > maxPublishTime - ) { - priceFeeds[k].id = 0; + PythInternalStructs.PriceInfo + memory info = accumulatorPriceInfos[accDataIdx]; + + uint publishTime = uint(info.publishTime); + // Check the publish time of the price is within the given range + // and only fill the priceFeedsInfo if it is. + // If is not, default id value of 0 will still be set and + // this will allow other updates for this price id to be processed. + if ( + publishTime >= minPublishTime && + publishTime <= maxPublishTime + ) { + fillPriceFeedFromPriceInfo( + priceFeeds, + k, + accumulatorPriceId, + info, + publishTime + ); + } + } + } else { + bytes memory encoded; + { + IWormhole.VM + memory vm = parseAndVerifyBatchAttestationVM( + updateData[i] + ); + encoded = vm.payload; } - index += attestationSize; + /** Batch price logic */ + // TODO: gas optimization + ( + uint index, + uint nAttestations, + uint attestationSize + ) = parseBatchAttestationHeader(encoded); + + // Deserialize each attestation + for (uint j = 0; j < nAttestations; j++) { + // NOTE: We don't advance the global index immediately. + // attestationIndex is an attestation-local offset used + // for readability and easier debugging. + uint attestationIndex = 0; + + // Unused bytes32 product id + attestationIndex += 32; + + bytes32 priceId = UnsafeBytesLib.toBytes32( + encoded, + index + attestationIndex + ); + + // check whether caller requested for this data + uint k = findIndexOfPriceId(priceIds, priceId); + + // If priceFeed[k].id != 0 then it means that there was a valid + // update for priceIds[k] and we don't need to process this one. + if (k == priceIds.length || priceFeeds[k].id != 0) { + index += attestationSize; + continue; + } + + ( + PythInternalStructs.PriceInfo memory info, + + ) = parseSingleAttestationFromBatch( + encoded, + index, + attestationSize + ); + + uint publishTime = uint(info.publishTime); + // Check the publish time of the price is within the given range + // and only fill the priceFeedsInfo if it is. + // If is not, default id value of 0 will still be set and + // this will allow other updates for this price id to be processed. + if ( + publishTime >= minPublishTime && + publishTime <= maxPublishTime + ) { + fillPriceFeedFromPriceInfo( + priceFeeds, + k, + priceId, + info, + publishTime + ); + } + + index += attestationSize; + } } } @@ -521,6 +568,38 @@ abstract contract Pyth is } } + function findIndexOfPriceId( + bytes32[] calldata priceIds, + bytes32 targetPriceId + ) private pure returns (uint index) { + uint k = 0; + uint len = priceIds.length; + for (; k < len; k++) { + if (priceIds[k] == targetPriceId) { + break; + } + } + return k; + } + + function fillPriceFeedFromPriceInfo( + PythStructs.PriceFeed[] memory priceFeeds, + uint k, + bytes32 priceId, + PythInternalStructs.PriceInfo memory info, + uint publishTime + ) private pure { + priceFeeds[k].id = priceId; + priceFeeds[k].price.price = info.price; + priceFeeds[k].price.conf = info.conf; + priceFeeds[k].price.expo = info.expo; + priceFeeds[k].price.publishTime = publishTime; + priceFeeds[k].emaPrice.price = info.emaPrice; + priceFeeds[k].emaPrice.conf = info.emaConf; + priceFeeds[k].emaPrice.expo = info.expo; + priceFeeds[k].emaPrice.publishTime = publishTime; + } + function queryPriceFeed( bytes32 id ) public view override returns (PythStructs.PriceFeed memory priceFeed) { diff --git a/target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol b/target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol index 14ae1bf7..16cf8c7a 100644 --- a/target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol +++ b/target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol @@ -42,11 +42,38 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth { revert PythErrors.InvalidUpdateDataSource(); } - function updatePricesUsingAccumulator( - bytes calldata accumulatorUpdate - ) internal { + function extractPriceInfosFromAccumulatorUpdate( + bytes memory accumulatorUpdate + ) + internal + view + returns ( + PythInternalStructs.PriceInfo[] memory priceInfos, + bytes32[] memory priceIds + ) + { + ( + uint offset, + UpdateType updateType + ) = extractUpdateTypeFromAccumulatorHeader(accumulatorUpdate); + + if (updateType != UpdateType.WormholeMerkle) { + revert PythErrors.InvalidUpdateData(); + } + (priceInfos, priceIds) = extractPriceInfosFromWormholeMerkle( + UnsafeBytesLib.slice( + accumulatorUpdate, + offset, + accumulatorUpdate.length - offset + ) + ); + } + + function extractUpdateTypeFromAccumulatorHeader( + bytes memory accumulatorUpdate + ) internal pure returns (uint offset, UpdateType updateType) { unchecked { - uint offset = 0; + offset = 0; { uint32 magic = UnsafeBytesLib.toUint32( @@ -97,37 +124,56 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth { offset += trailingHeaderSize; } - UpdateType updateType = UpdateType( + updateType = UpdateType( UnsafeBytesLib.toUint8(accumulatorUpdate, offset) ); offset += 1; if (accumulatorUpdate.length < offset) revert PythErrors.InvalidUpdateData(); - - if (updateType == UpdateType.WormholeMerkle) { - updatePricesUsingWormholeMerkle( - UnsafeBytesLib.slice( - accumulatorUpdate, - offset, - accumulatorUpdate.length - offset - ) - ); - } else { - revert PythErrors.InvalidUpdateData(); - } } } - function updatePricesUsingWormholeMerkle(bytes memory encoded) private { + function extractPriceInfosFromWormholeMerkle( + bytes memory encoded + ) + internal + view + returns ( + PythInternalStructs.PriceInfo[] memory priceInfos, + bytes32[] memory priceIds + ) + { unchecked { - uint offset = 0; + ( + uint offset, + bytes20 digest, + uint8 numUpdates + ) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded); + + priceInfos = new PythInternalStructs.PriceInfo[](numUpdates); + priceIds = new bytes32[](numUpdates); + for (uint i = 0; i < numUpdates; i++) { + ( + offset, + priceInfos[i], + priceIds[i] + ) = extractPriceFeedFromMerkleProof(digest, encoded, offset); + } + + if (offset != encoded.length) revert PythErrors.InvalidUpdateData(); + } + } + + function extractWormholeMerkleHeaderDigestAndNumUpdates( + bytes memory encoded + ) internal view returns (uint offset, bytes20 digest, uint8 numUpdates) { + unchecked { + offset = 0; uint16 whProofSize = UnsafeBytesLib.toUint16(encoded, offset); offset += 2; - bytes20 digest; - { IWormhole.VM memory vm = parseAndVerifyPythVM( UnsafeBytesLib.slice(encoded, offset, whProofSize) @@ -138,94 +184,135 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth { // emit AccumulatorUpdate(vm.chainId, vm.sequence); bytes memory encodedPayload = vm.payload; - uint payloadoffset = 0; + uint payloadOffset = 0; { uint32 magic = UnsafeBytesLib.toUint32( encodedPayload, - payloadoffset + payloadOffset ); - payloadoffset += 4; + payloadOffset += 4; if (magic != ACCUMULATOR_WORMHOLE_MAGIC) revert PythErrors.InvalidUpdateData(); UpdateType updateType = UpdateType( - UnsafeBytesLib.toUint8(encodedPayload, payloadoffset) + UnsafeBytesLib.toUint8(encodedPayload, payloadOffset) ); - payloadoffset += 1; + payloadOffset += 1; if (updateType != UpdateType.WormholeMerkle) revert PythErrors.InvalidUpdateData(); // This field is not used // uint64 slot = UnsafeBytesLib.toUint64(encodedPayload, payloadoffset); - payloadoffset += 8; + payloadOffset += 8; // This field is not used // uint32 ringSize = UnsafeBytesLib.toUint32(encodedPayload, payloadoffset); - payloadoffset += 4; + payloadOffset += 4; digest = bytes20( - UnsafeBytesLib.toAddress(encodedPayload, payloadoffset) + UnsafeBytesLib.toAddress(encodedPayload, payloadOffset) ); - payloadoffset += 20; + payloadOffset += 20; // We don't check equality to enable future compatibility. - if (payloadoffset > encodedPayload.length) + if (payloadOffset > encodedPayload.length) revert PythErrors.InvalidUpdateData(); } } - uint8 numUpdates = UnsafeBytesLib.toUint8(encoded, offset); + numUpdates = UnsafeBytesLib.toUint8(encoded, offset); offset += 1; - - for (uint i = 0; i < numUpdates; i++) { - offset = verifyAndUpdatePriceFeedFromMerkleProof( - digest, - encoded, - offset - ); - } - - if (offset != encoded.length) revert PythErrors.InvalidUpdateData(); } } - function verifyAndUpdatePriceFeedFromMerkleProof( + function extractPriceFeedFromMerkleProof( bytes20 digest, bytes memory encoded, uint offset - ) private returns (uint endOffset) { + ) + private + pure + returns ( + uint endOffset, + PythInternalStructs.PriceInfo memory priceInfo, + bytes32 priceId + ) + { unchecked { - uint16 messageSize = UnsafeBytesLib.toUint16(encoded, offset); + bytes memory encodedMessage; + (endOffset, encodedMessage) = extractMessageFromProof( + encoded, + offset, + digest + ); + + (priceInfo, priceId) = extractPriceFeedMessage(encodedMessage); + + return (endOffset, priceInfo, priceId); + } + } + + function extractMessageFromProof( + bytes memory encodedProof, + uint offset, + bytes20 merkleRoot + ) private pure returns (uint endOffset, bytes memory encodedMessage) { + unchecked { + uint16 messageSize = UnsafeBytesLib.toUint16(encodedProof, offset); offset += 2; - bytes memory encodedMessage = UnsafeBytesLib.slice( - encoded, + encodedMessage = UnsafeBytesLib.slice( + encodedProof, offset, messageSize ); offset += messageSize; bool valid; - (valid, offset) = MerkleTree.isProofValid( - encoded, + (valid, endOffset) = MerkleTree.isProofValid( + encodedProof, offset, - digest, + merkleRoot, encodedMessage ); - if (!valid) { revert PythErrors.InvalidUpdateData(); } - - parseAndProcessMessage(encodedMessage); - - return offset; } } + function extractPriceFeedMessage( + bytes memory encodedMessage + ) + private + pure + returns (PythInternalStructs.PriceInfo memory info, bytes32 priceId) + { + unchecked { + MessageType messageType = getMessageType(encodedMessage); + if (messageType == MessageType.PriceFeed) { + (info, priceId) = parsePriceFeedMessage( + UnsafeBytesLib.slice( + encodedMessage, + 1, + encodedMessage.length - 1 + ) + ); + } else { + revert PythErrors.InvalidUpdateData(); + } + } + } + + function getMessageType( + bytes memory encodedMessage + ) private pure returns (MessageType messageType) { + return MessageType(UnsafeBytesLib.toUint8(encodedMessage, 0)); + } + function parsePriceFeedMessage( bytes memory encodedPriceFeed ) @@ -286,38 +373,76 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth { } } - function parseAndProcessMessage(bytes memory encodedMessage) private { + function updatePriceInfosFromAccumulatorUpdate( + bytes calldata accumulatorUpdate + ) internal { + ( + uint offset, + UpdateType updateType + ) = extractUpdateTypeFromAccumulatorHeader(accumulatorUpdate); + + if (updateType != UpdateType.WormholeMerkle) { + revert PythErrors.InvalidUpdateData(); + } + updatePriceInfosFromWormholeMerkle( + UnsafeBytesLib.slice( + accumulatorUpdate, + offset, + accumulatorUpdate.length - offset + ) + ); + } + + function updatePriceInfosFromWormholeMerkle(bytes memory encoded) private { unchecked { - MessageType messageType = MessageType( - UnsafeBytesLib.toUint8(encodedMessage, 0) - ); + ( + uint offset, + bytes20 digest, + uint8 numUpdates + ) = extractWormholeMerkleHeaderDigestAndNumUpdates(encoded); - if (messageType == MessageType.PriceFeed) { - ( - PythInternalStructs.PriceInfo memory info, - bytes32 priceId - ) = parsePriceFeedMessage( - UnsafeBytesLib.slice( - encodedMessage, - 1, - encodedMessage.length - 1 - ) - ); - - uint64 latestPublishTime = latestPriceInfoPublishTime(priceId); - - if (info.publishTime > latestPublishTime) { - setLatestPriceInfo(priceId, info); - emit PriceFeedUpdate( - priceId, - info.publishTime, - info.price, - info.conf - ); - } - } else { - revert PythErrors.InvalidUpdateData(); + for (uint i = 0; i < numUpdates; i++) { + offset = verifyAndUpdatePriceFeedFromMerkleProof( + digest, + encoded, + offset + ); } + + if (offset != encoded.length) revert PythErrors.InvalidUpdateData(); + } + } + + function verifyAndUpdatePriceFeedFromMerkleProof( + bytes20 digest, + bytes memory encoded, + uint offset + ) private returns (uint endOffset) { + PythInternalStructs.PriceInfo memory priceInfo; + bytes32 priceId; + (offset, priceInfo, priceId) = extractPriceFeedFromMerkleProof( + digest, + encoded, + offset + ); + processMessage(priceInfo, priceId); + + return offset; + } + + function processMessage( + PythInternalStructs.PriceInfo memory info, + bytes32 priceId + ) private { + uint64 latestPublishTime = latestPriceInfoPublishTime(priceId); + if (info.publishTime > latestPublishTime) { + setLatestPriceInfo(priceId, info); + emit PriceFeedUpdate( + priceId, + info.publishTime, + info.price, + info.conf + ); } } } diff --git a/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol b/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol index 773f0740..875bb9f3 100644 --- a/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol @@ -22,6 +22,9 @@ contract PythWormholeMerkleAccumulatorTest is { IPyth public pyth; + // -1 is equal to 0xffffff which is the biggest uint if converted back + uint64 constant MAX_UINT64 = uint64(int64(-1)); + function setUp() public { pyth = IPyth(setUpPyth(setUpWormhole(1))); } @@ -46,6 +49,42 @@ contract PythWormholeMerkleAccumulatorTest is assertEq(emaPrice.publishTime, priceFeedMessage.publishTime); } + function assertParsedPriceFeedEqualsMessage( + PythStructs.PriceFeed memory priceFeed, + PriceFeedMessage memory priceFeedMessage, + bytes32 priceId + ) internal { + assertEq(priceFeed.id, priceId); + assertEq(priceFeed.price.price, priceFeedMessage.price); + assertEq(priceFeed.price.conf, priceFeedMessage.conf); + assertEq(priceFeed.price.expo, priceFeedMessage.expo); + assertEq(priceFeed.price.publishTime, priceFeedMessage.publishTime); + assertEq(priceFeed.emaPrice.price, priceFeedMessage.emaPrice); + assertEq(priceFeed.emaPrice.conf, priceFeedMessage.emaConf); + assertEq(priceFeed.emaPrice.expo, priceFeedMessage.expo); + assertEq(priceFeed.emaPrice.publishTime, priceFeedMessage.publishTime); + } + + function assertParsedPriceFeedStored( + PythStructs.PriceFeed memory priceFeed + ) internal { + PythStructs.Price memory aggregatePrice = pyth.getPriceUnsafe( + priceFeed.id + ); + assertEq(aggregatePrice.price, priceFeed.price.price); + assertEq(aggregatePrice.conf, priceFeed.price.conf); + assertEq(aggregatePrice.expo, priceFeed.price.expo); + assertEq(aggregatePrice.publishTime, priceFeed.price.publishTime); + + PythStructs.Price memory emaPrice = pyth.getEmaPriceUnsafe( + priceFeed.id + ); + assertEq(emaPrice.price, priceFeed.emaPrice.price); + assertEq(emaPrice.conf, priceFeed.emaPrice.conf); + assertEq(emaPrice.expo, priceFeed.emaPrice.expo); + assertEq(emaPrice.publishTime, priceFeed.emaPrice.publishTime); + } + function generateRandomPriceFeedMessage( uint numPriceFeeds ) internal returns (PriceFeedMessage[] memory priceFeedMessages) { @@ -162,11 +201,62 @@ contract PythWormholeMerkleAccumulatorTest is uint updateFee = pyth.getUpdateFee(updateData); + bytes32[] memory priceIds = new bytes32[](3); + priceIds[0] = priceFeedMessages1[0].priceId; + priceIds[1] = priceFeedMessages1[1].priceId; + priceIds[2] = priceFeedMessages2[0].priceId; + // parse price feeds before updating since parsing price feeds should be independent + // of whatever is currently stored in the contract. + PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{ + value: updateFee + }(updateData, priceIds, 0, MAX_UINT64); + + PriceFeedMessage[] + memory expectedPriceFeedMessages = new PriceFeedMessage[](3); + // Only the first occurrence of a valid priceFeedMessage for a paritcular priceFeed.id + // within an updateData will be parsed which is why we exclude priceFeedMessages2[1] + // since it has the same priceFeed.id as priceFeedMessages1[0] even though it has a later publishTime. + // This is different than how updatePriceFeed behaves which will always update using the data + // of the priceFeedMessage with the latest publishTime for a particular priceFeed.id + expectedPriceFeedMessages[0] = priceFeedMessages1[0]; + expectedPriceFeedMessages[1] = priceFeedMessages1[1]; + expectedPriceFeedMessages[2] = priceFeedMessages2[0]; + for (uint i = 0; i < expectedPriceFeedMessages.length; i++) { + assertParsedPriceFeedEqualsMessage( + priceFeeds[i], + expectedPriceFeedMessages[i], + priceIds[i] + ); + } + + // parse updateData[1] for priceFeedMessages1[0].priceId since this has the latest publishTime + // for that priceId and should be the one that is stored. + bytes32[] memory priceIds1 = new bytes32[](1); + priceIds1[0] = priceFeedMessages1[0].priceId; + bytes[] memory parseUpdateDataInput1 = new bytes[](1); + parseUpdateDataInput1[0] = updateData[1]; + + PythStructs.PriceFeed[] memory priceFeeds1 = pyth.parsePriceFeedUpdates{ + value: updateFee + }(parseUpdateDataInput1, priceIds1, 0, MAX_UINT64); + pyth.updatePriceFeeds{value: updateFee}(updateData); + // check stored price feed information matches updateData assertPriceFeedMessageStored(priceFeedMessages1[1]); assertPriceFeedMessageStored(priceFeedMessages2[0]); assertPriceFeedMessageStored(priceFeedMessages2[1]); + + PythStructs.PriceFeed[] + memory expectedPriceFeeds = new PythStructs.PriceFeed[](3); + expectedPriceFeeds[0] = priceFeeds1[0]; + expectedPriceFeeds[1] = priceFeeds[1]; + expectedPriceFeeds[2] = priceFeeds[2]; + + // check stored price feed information matches parsed price feeds + for (uint i = 0; i < expectedPriceFeeds.length; i++) { + assertParsedPriceFeedStored(expectedPriceFeeds[i]); + } } function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateSingleCall() @@ -205,6 +295,28 @@ contract PythWormholeMerkleAccumulatorTest is pyth.updatePriceFeeds{value: updateFee}(updateData); assertPriceFeedMessageStored(priceFeedMessages1[0]); + + bytes32[] memory priceIds = new bytes32[](1); + priceIds[0] = priceFeedMessages1[0].priceId; + + PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{ + value: updateFee + }(updateData, priceIds, 0, MAX_UINT64); + assertEq(priceFeeds.length, 1); + assertParsedPriceFeedStored(priceFeeds[0]); + + // parsePriceFeedUpdates should return the first priceFeed in the case + // that the updateData contains multiple feeds with the same id. + // Swap the order of updates in updateData to verify that the other priceFeed is returned + bytes[] memory updateData1 = new bytes[](2); + updateData1[0] = updateData[1]; + updateData1[1] = updateData[0]; + + PythStructs.PriceFeed[] memory priceFeeds1 = pyth.parsePriceFeedUpdates{ + value: updateFee + }(updateData1, priceIds, 0, MAX_UINT64); + assertEq(priceFeeds1.length, 1); + assertEq(priceFeeds1[0].price.publishTime, 5); } function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdateMultiCall() @@ -227,7 +339,6 @@ contract PythWormholeMerkleAccumulatorTest is uint updateFee ) = createWormholeMerkleUpdateData(priceFeedMessages1); pyth.updatePriceFeeds{value: updateFee}(updateData); - assertPriceFeedMessageStored(priceFeedMessages1[0]); (updateData, updateFee) = createWormholeMerkleUpdateData( priceFeedMessages2 @@ -237,6 +348,86 @@ contract PythWormholeMerkleAccumulatorTest is assertPriceFeedMessageStored(priceFeedMessages1[0]); } + function testParsePriceFeedUpdatesWithWormholeMerklWorksWithOurOfOrderUpdateMultiCall() + public + { + PriceFeedMessage[] + memory priceFeedMessages1 = generateRandomPriceFeedMessage(1); + PriceFeedMessage[] + memory priceFeedMessages2 = generateRandomPriceFeedMessage(1); + + // Make the price ids the same + priceFeedMessages2[0].priceId = priceFeedMessages1[0].priceId; + // Adjust the timestamps so the second timestamp is smaller than the first + // Parse should work regardless of what's stored on chain. + priceFeedMessages1[0].publishTime = 10; + priceFeedMessages2[0].publishTime = 5; + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages1); + bytes32[] memory priceIds = new bytes32[](1); + priceIds[0] = priceFeedMessages1[0].priceId; + PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{ + value: updateFee + }(updateData, priceIds, 0, MAX_UINT64); + + // Parse should always return the same value regardless of what's stored on chain. + assertEq(priceFeeds.length, 1); + assertParsedPriceFeedEqualsMessage( + priceFeeds[0], + priceFeedMessages1[0], + priceIds[0] + ); + pyth.updatePriceFeeds{value: updateFee}(updateData); + priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); + assertEq(priceFeeds.length, 1); + assertParsedPriceFeedEqualsMessage( + priceFeeds[0], + priceFeedMessages1[0], + priceIds[0] + ); + + ( + bytes[] memory updateData1, + uint updateFee1 + ) = createWormholeMerkleUpdateData(priceFeedMessages2); + pyth.updatePriceFeeds{value: updateFee1}(updateData1); + // reparse the original updateData should still return the same thing + priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); + assertEq(priceFeeds.length, 1); + assertParsedPriceFeedEqualsMessage( + priceFeeds[0], + priceFeedMessages1[0], + priceIds[0] + ); + + // parsing the second message should return the data based on the second messagef + priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee1}( + updateData1, + priceIds, + 0, + MAX_UINT64 + ); + assertEq(priceFeeds.length, 1); + assertParsedPriceFeedEqualsMessage( + priceFeeds[0], + priceFeedMessages2[0], + priceIds[0] + ); + } + function isNotMatch( bytes memory a, bytes memory b @@ -249,12 +440,23 @@ contract PythWormholeMerkleAccumulatorTest is /// expected value, that item will be forged to be invalid. function createAndForgeWormholeMerkleUpdateData( bytes memory forgeItem - ) public returns (bytes[] memory updateData, uint updateFee) { + ) + public + returns ( + bytes[] memory updateData, + uint updateFee, + bytes32[] memory priceIds + ) + { uint numPriceFeeds = 10; PriceFeedMessage[] memory priceFeedMessages = generateRandomPriceFeedMessage( numPriceFeeds ); + priceIds = new bytes32[](numPriceFeeds); + for (uint i = 0; i < numPriceFeeds; i++) { + priceIds[i] = priceFeedMessages[i].priceId; + } bytes[] memory encodedPriceFeedMessages = encodePriceFeedMessages( priceFeedMessages @@ -327,11 +529,21 @@ contract PythWormholeMerkleAccumulatorTest is // In this test the Wormhole accumulator magic is wrong and the update gets reverted. ( bytes[] memory updateData, - uint updateFee + uint updateFee, + bytes32[] memory priceIds ) = createAndForgeWormholeMerkleUpdateData("whMagic"); vm.expectRevert(PythErrors.InvalidUpdateData.selector); pyth.updatePriceFeeds{value: updateFee}(updateData); + + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); } function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAAPayloadUpdateType() @@ -342,11 +554,20 @@ contract PythWormholeMerkleAccumulatorTest is ( bytes[] memory updateData, - uint updateFee + uint updateFee, + bytes32[] memory priceIds ) = createAndForgeWormholeMerkleUpdateData("whUpdateType"); vm.expectRevert(); // Reason: Conversion into non-existent enum type. However it // was not possible to check the revert reason in the test. pyth.updatePriceFeeds{value: updateFee}(updateData); + + vm.expectRevert(); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); } function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAASource() @@ -355,15 +576,34 @@ contract PythWormholeMerkleAccumulatorTest is // In this test the Wormhole message source is wrong. ( bytes[] memory updateData, - uint updateFee + uint updateFee, + bytes32[] memory priceIds ) = createAndForgeWormholeMerkleUpdateData("whSourceAddress"); vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector); pyth.updatePriceFeeds{value: updateFee}(updateData); - (updateData, updateFee) = createAndForgeWormholeMerkleUpdateData( - "whSourceChain" + vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 ); + + ( + updateData, + updateFee, + priceIds + ) = createAndForgeWormholeMerkleUpdateData("whSourceChain"); vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector); pyth.updatePriceFeeds{value: updateFee}(updateData); + + vm.expectRevert(PythErrors.InvalidUpdateDataSource.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); } function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongRootDigest() @@ -372,10 +612,19 @@ contract PythWormholeMerkleAccumulatorTest is // In this test the Wormhole merkle proof digest is wrong ( bytes[] memory updateData, - uint updateFee + uint updateFee, + bytes32[] memory priceIds ) = createAndForgeWormholeMerkleUpdateData("rootDigest"); vm.expectRevert(PythErrors.InvalidUpdateData.selector); pyth.updatePriceFeeds{value: updateFee}(updateData); + + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); } function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongProofItem() @@ -384,10 +633,19 @@ contract PythWormholeMerkleAccumulatorTest is // In this test all Wormhole merkle proof items are the first item proof ( bytes[] memory updateData, - uint updateFee + uint updateFee, + bytes32[] memory priceIds ) = createAndForgeWormholeMerkleUpdateData("proofItem"); vm.expectRevert(PythErrors.InvalidUpdateData.selector); pyth.updatePriceFeeds{value: updateFee}(updateData); + + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); } function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongHeader() @@ -396,17 +654,35 @@ contract PythWormholeMerkleAccumulatorTest is // In this test the message headers are wrong ( bytes[] memory updateData, - uint updateFee + uint updateFee, + bytes32[] memory priceIds ) = createAndForgeWormholeMerkleUpdateData("headerMagic"); vm.expectRevert(); // The revert reason is not deterministic because when it doesn't match it goes through // the old approach. pyth.updatePriceFeeds{value: updateFee}(updateData); - - (updateData, updateFee) = createAndForgeWormholeMerkleUpdateData( - "headerMajorVersion" + vm.expectRevert(); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 ); + + ( + updateData, + updateFee, + priceIds + ) = createAndForgeWormholeMerkleUpdateData("headerMajorVersion"); vm.expectRevert(PythErrors.InvalidUpdateData.selector); pyth.updatePriceFeeds{value: updateFee}(updateData); + + vm.expectRevert(PythErrors.InvalidUpdateData.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); } function testUpdatePriceFeedWithWormholeMerkleRevertsIfUpdateFeeIsNotPaid() @@ -421,7 +697,222 @@ contract PythWormholeMerkleAccumulatorTest is priceFeedMessages ); + bytes32[] memory priceIds = new bytes32[](numPriceFeeds); + for (uint i = 0; i < numPriceFeeds; i++) { + priceIds[i] = priceFeedMessages[i].priceId; + } vm.expectRevert(PythErrors.InsufficientFee.selector); pyth.updatePriceFeeds{value: 0}(updateData); + + vm.expectRevert(PythErrors.InsufficientFee.selector); + pyth.parsePriceFeedUpdates{value: 0}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); } + + function testParsePriceFeedWithWormholeMerkleWorks(uint seed) public { + setRandSeed(seed); + + uint numPriceFeeds = (getRandUint() % 10) + 1; + PriceFeedMessage[] + memory priceFeedMessages = generateRandomPriceFeedMessage( + numPriceFeeds + ); + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + bytes32[] memory priceIds = new bytes32[](numPriceFeeds); + for (uint i = 0; i < numPriceFeeds; i++) { + priceIds[i] = priceFeedMessages[i].priceId; + } + PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{ + value: updateFee + }(updateData, priceIds, 0, MAX_UINT64); + + for (uint i = 0; i < priceFeeds.length; i++) { + assertParsedPriceFeedEqualsMessage( + priceFeeds[i], + priceFeedMessages[i], + priceIds[i] + ); + } + + // update priceFeedMessages + for (uint i = 0; i < numPriceFeeds; i++) { + priceFeedMessages[i].price = getRandInt64(); + priceFeedMessages[i].conf = getRandUint64(); + priceFeedMessages[i].expo = getRandInt32(); + priceFeedMessages[i].publishTime = getRandUint64(); + priceFeedMessages[i].emaPrice = getRandInt64(); + priceFeedMessages[i].emaConf = getRandUint64(); + } + + (updateData, updateFee) = createWormholeMerkleUpdateData( + priceFeedMessages + ); + + // reparse + priceFeeds = pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); + + for (uint i = 0; i < priceFeeds.length; i++) { + assertParsedPriceFeedEqualsMessage( + priceFeeds[i], + priceFeedMessages[i], + priceIds[i] + ); + } + } + + function testParsePriceFeedWithWormholeMerkleWorksRandomDistinctUpdatesInput( + uint seed + ) public { + setRandSeed(seed); + + uint numPriceFeeds = (getRandUint() % 10) + 1; + PriceFeedMessage[] + memory priceFeedMessages = generateRandomPriceFeedMessage( + numPriceFeeds + ); + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + bytes32[] memory priceIds = new bytes32[](numPriceFeeds); + for (uint i = 0; i < numPriceFeeds; i++) { + priceIds[i] = priceFeedMessages[i].priceId; + } + + // Shuffle the priceFeedMessages + for (uint i = 1; i < numPriceFeeds; i++) { + uint swapWith = getRandUint() % (i + 1); + (priceFeedMessages[i], priceFeedMessages[swapWith]) = ( + priceFeedMessages[swapWith], + priceFeedMessages[i] + ); + (priceIds[i], priceIds[swapWith]) = ( + priceIds[swapWith], + priceIds[i] + ); + } + + // Select only first numSelectedPriceFeeds. numSelectedPriceFeeds will be in [0, numPriceFeeds] + uint numSelectedPriceFeeds = getRandUint() % (numPriceFeeds + 1); + + PriceFeedMessage[] + memory selectedPriceFeedsMessages = new PriceFeedMessage[]( + numSelectedPriceFeeds + ); + bytes32[] memory selectedPriceIds = new bytes32[]( + numSelectedPriceFeeds + ); + + for (uint i = 0; i < numSelectedPriceFeeds; i++) { + selectedPriceFeedsMessages[i] = priceFeedMessages[i]; + selectedPriceIds[i] = priceIds[i]; + } + + PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{ + value: updateFee + }(updateData, selectedPriceIds, 0, MAX_UINT64); + for (uint i = 0; i < numSelectedPriceFeeds; i++) { + assertParsedPriceFeedEqualsMessage( + priceFeeds[i], + selectedPriceFeedsMessages[i], + selectedPriceIds[i] + ); + } + } + + function testParsePriceFeedWithWormholeMerkleRevertsIfPriceIdNotIncluded() + public + { + PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1); + priceFeedMessages[0] = PriceFeedMessage({ + priceId: bytes32(uint(1)), + price: getRandInt64(), + conf: getRandUint64(), + expo: getRandInt32(), + publishTime: getRandUint64(), + prevPublishTime: getRandUint64(), + emaPrice: getRandInt64(), + emaConf: getRandUint64() + }); + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + bytes32[] memory priceIds = new bytes32[](1); + priceIds[0] = bytes32(uint(2)); + + vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + MAX_UINT64 + ); + } + + function testParsePriceFeedUpdateRevertsIfPricesOutOfTimeRange() public { + uint numPriceFeeds = (getRandUint() % 10) + 1; + PriceFeedMessage[] + memory priceFeedMessages = generateRandomPriceFeedMessage( + numPriceFeeds + ); + for (uint i = 0; i < numPriceFeeds; i++) { + priceFeedMessages[i].publishTime = uint64( + 100 + (getRandUint() % 101) + ); // All between [100, 200] + } + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + + bytes32[] memory priceIds = new bytes32[](numPriceFeeds); + for (uint i = 0; i < numPriceFeeds; i++) { + priceIds[i] = priceFeedMessages[i].priceId; + } + + // Request for parse within the given time range should work + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 100, + 200 + ); + + // Request for parse after the time range should revert. + vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 300, + MAX_UINT64 + ); + + // Request for parse before the time range should revert. + vm.expectRevert(PythErrors.PriceFeedNotFoundWithinRange.selector); + pyth.parsePriceFeedUpdates{value: updateFee}( + updateData, + priceIds, + 0, + 99 + ); + } + + //TODO: add some tests of forward compatibility. + // I.e., create a message where each part that can be expanded in size is expanded and make sure that parsing still works } diff --git a/target_chains/ethereum/contracts/forge-test/Pyth.t.sol b/target_chains/ethereum/contracts/forge-test/Pyth.t.sol index 0fcaf7bb..42e5b1e6 100644 --- a/target_chains/ethereum/contracts/forge-test/Pyth.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pyth.t.sol @@ -15,7 +15,7 @@ import "./utils/RandTestUtils.t.sol"; contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils { IPyth public pyth; - // -1 is equal to 0x111111 which is the biggest uint if converted back + // -1 is equal to 0xffffff which is the biggest uint if converted back uint64 constant MAX_UINT64 = uint64(int64(-1)); function setUp() public {