From ac97b4d35d2b352bf0bbd0d8455511202c860fbe Mon Sep 17 00:00:00 2001 From: swimricky <86628128+swimricky@users.noreply.github.com> Date: Thu, 6 Jul 2023 08:29:08 -0400 Subject: [PATCH] [eth] - Aave FallbackOracle Integration (#924) * feat(eth): aave integration Add IPriceOracleGetter and PythAssetRegistry mapping * feat(eth): remove IPriceOracleGetter from PythAssetRegistryGetter * refactor(eth): flatten PythAssetRegistySetter/Getter into PythAssetRegistry * feat(eth): address feedback move aave related contracts into separate directory, add explicit exponent/decimal handling, add staleness check * refactor(eth): minor rename to avoid shadowing * fix(eth): handle exponent conversion and add tests * chore(eth): remove unused console import * feat(eth): address PR feedback add more checks, tests & minor refactoring * feat(eth): add more tests and address feedback --- .../contracts/aave/PythAssetRegistry.sol | 79 ++++ .../contracts/aave/PythPriceOracleGetter.sol | 103 +++++ .../aave/interfaces/IPriceOracleGetter.sol | 31 ++ .../contracts/forge-test/GasBenchmark.t.sol | 6 +- .../contracts/forge-test/Pyth.Aave.t.sol | 368 ++++++++++++++++++ .../Pyth.WormholeMerkleAccumulator.t.sol | 24 +- .../forge-test/utils/PythTestUtils.t.sol | 1 + .../forge-test/utils/RandTestUtils.t.sol | 4 + 8 files changed, 612 insertions(+), 4 deletions(-) create mode 100644 target_chains/ethereum/contracts/contracts/aave/PythAssetRegistry.sol create mode 100644 target_chains/ethereum/contracts/contracts/aave/PythPriceOracleGetter.sol create mode 100644 target_chains/ethereum/contracts/contracts/aave/interfaces/IPriceOracleGetter.sol create mode 100644 target_chains/ethereum/contracts/forge-test/Pyth.Aave.t.sol diff --git a/target_chains/ethereum/contracts/contracts/aave/PythAssetRegistry.sol b/target_chains/ethereum/contracts/contracts/aave/PythAssetRegistry.sol new file mode 100644 index 00000000..3160ec5d --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/aave/PythAssetRegistry.sol @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity ^0.8.0; + +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; + +error InconsistentParamsLength(); + +contract PythAssetRegistryStorage { + struct State { + address pyth; + address BASE_CURRENCY; + uint256 BASE_CURRENCY_UNIT; + // Map of asset priceIds (asset => priceId) + mapping(address => bytes32) assetsPriceIds; + /// Maximum acceptable time period before price is considered to be stale. + /// This includes attestation delay, block time, and potential clock drift + /// between the source/target chains. + uint validTimePeriodSeconds; + } +} + +contract PythAssetRegistry { + PythAssetRegistryStorage.State _registryState; + + /** + * @dev Emitted after the base currency is set + * @param baseCurrency The base currency of used for price quotes + * @param baseCurrencyUnit The unit of the base currency + */ + event BaseCurrencySet( + address indexed baseCurrency, + uint256 baseCurrencyUnit + ); + + /** + * @dev Emitted after the price source of an asset is updated + * @param asset The address of the asset + * @param source The priceId of the asset + */ + event AssetSourceUpdated(address indexed asset, bytes32 indexed source); + + function pyth() public view returns (IPyth) { + return IPyth(_registryState.pyth); + } + + function setPyth(address pythAddress) internal { + _registryState.pyth = payable(pythAddress); + } + + function setAssetsSources( + address[] memory assets, + bytes32[] memory priceIds + ) internal { + if (assets.length != priceIds.length) { + revert InconsistentParamsLength(); + } + for (uint256 i = 0; i < assets.length; i++) { + _registryState.assetsPriceIds[assets[i]] = priceIds[i]; + emit AssetSourceUpdated(assets[i], priceIds[i]); + } + } + + function setBaseCurrency( + address baseCurrency, + uint256 baseCurrencyUnit + ) internal { + _registryState.BASE_CURRENCY = baseCurrency; + _registryState.BASE_CURRENCY_UNIT = baseCurrencyUnit; + emit BaseCurrencySet(baseCurrency, baseCurrencyUnit); + } + + function setValidTimePeriodSeconds(uint validTimePeriodInSeconds) internal { + _registryState.validTimePeriodSeconds = validTimePeriodInSeconds; + } + + function validTimePeriodSeconds() public view returns (uint) { + return _registryState.validTimePeriodSeconds; + } +} diff --git a/target_chains/ethereum/contracts/contracts/aave/PythPriceOracleGetter.sol b/target_chains/ethereum/contracts/contracts/aave/PythPriceOracleGetter.sol new file mode 100644 index 00000000..cc0f800a --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/aave/PythPriceOracleGetter.sol @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity ^0.8.0; + +import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; + +import "./interfaces/IPriceOracleGetter.sol"; +import "./PythAssetRegistry.sol"; + +/// Invalid non-positive price +error InvalidNonPositivePrice(); +/// Normalization overflow +error NormalizationOverflow(); +/// Invalid Base Currency Unit value. Must be power of 10. +error InvalidBaseCurrencyUnit(); + +contract PythPriceOracleGetter is PythAssetRegistry, IPriceOracleGetter { + /// @inheritdoc IPriceOracleGetter + address public immutable override BASE_CURRENCY; + /** + * @notice Returns the base currency unit + * @dev 1 ether for ETH, 1e8 for USD. + * @return Returns the base currency unit. + */ + uint256 public immutable override BASE_CURRENCY_UNIT; + /// BASE_CURRENCY_UNIT as a power of 10 + uint8 public immutable BASE_NUM_DECIMALS; + + constructor( + address pyth, + address[] memory assets, + bytes32[] memory priceIds, + address baseCurrency, + uint256 baseCurrencyUnit, + uint validTimePeriodSeconds + ) { + if (baseCurrencyUnit == 0) { + revert InvalidBaseCurrencyUnit(); + } + PythAssetRegistry.setPyth(pyth); + PythAssetRegistry.setAssetsSources(assets, priceIds); + PythAssetRegistry.setBaseCurrency(baseCurrency, baseCurrencyUnit); + BASE_CURRENCY = _registryState.BASE_CURRENCY; + BASE_CURRENCY_UNIT = _registryState.BASE_CURRENCY_UNIT; + if ((10 ** baseNumDecimals(baseCurrencyUnit)) != baseCurrencyUnit) { + revert InvalidBaseCurrencyUnit(); + } + BASE_NUM_DECIMALS = baseNumDecimals(baseCurrencyUnit); + PythAssetRegistry.setValidTimePeriodSeconds(validTimePeriodSeconds); + } + + /// @inheritdoc IPriceOracleGetter + function getAssetPrice( + address asset + ) external view override returns (uint256) { + bytes32 priceId = _registryState.assetsPriceIds[asset]; + if (asset == BASE_CURRENCY) { + return BASE_CURRENCY_UNIT; + } + if (priceId == 0) { + revert PythErrors.PriceFeedNotFound(); + } + PythStructs.Price memory price = pyth().getPriceNoOlderThan( + priceId, + PythAssetRegistry.validTimePeriodSeconds() + ); + + // Aave is not using any price feeds < 0 for now. + if (price.price <= 0) { + revert InvalidNonPositivePrice(); + } + uint256 normalizedPrice = uint64(price.price); + int32 normalizerExpo = price.expo + int8(BASE_NUM_DECIMALS); + bool isNormalizerExpoNeg = normalizerExpo < 0; + uint256 normalizer = isNormalizerExpoNeg + ? 10 ** uint32(-normalizerExpo) + : 10 ** uint32(normalizerExpo); + + // this check prevents overflow in normalized price + if (!isNormalizerExpoNeg && normalizer > type(uint192).max) { + revert NormalizationOverflow(); + } + + normalizedPrice = isNormalizerExpoNeg + ? normalizedPrice / normalizer + : normalizedPrice * normalizer; + + if (normalizedPrice <= 0) { + revert InvalidNonPositivePrice(); + } + + return normalizedPrice; + } + + function baseNumDecimals(uint number) private pure returns (uint8) { + uint8 digits = 0; + while (number != 0) { + number /= 10; + digits++; + } + return digits - 1; + } +} diff --git a/target_chains/ethereum/contracts/contracts/aave/interfaces/IPriceOracleGetter.sol b/target_chains/ethereum/contracts/contracts/aave/interfaces/IPriceOracleGetter.sol new file mode 100644 index 00000000..248e89e5 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/aave/interfaces/IPriceOracleGetter.sol @@ -0,0 +1,31 @@ +// contracts/pyth/aave/PythPriceOracleGetter.sol +// SPDX-License-Identifier: AGPL-3.0 +pragma solidity ^0.8.0; + +/** + * @title IPriceOracleGetter + * @author Aave + * @notice Interface for the Aave price oracle. + */ +interface IPriceOracleGetter { + /** + * @notice Returns the base currency address + * @dev Address 0x0 is reserved for USD as base currency. + * @return Returns the base currency address. + */ + function BASE_CURRENCY() external view returns (address); + + /** + * @notice Returns the base currency unit + * @dev 1 ether for ETH, 1e8 for USD. + * @return Returns the base currency unit. + */ + function BASE_CURRENCY_UNIT() external view returns (uint256); + + /** + * @notice Returns the asset price in the base currency + * @param asset The address of the asset + * @return The price of the asset + */ + function getAssetPrice(address asset) external view returns (uint256); +} diff --git a/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol b/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol index 44d1da85..7a554b73 100644 --- a/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol +++ b/target_chains/ethereum/contracts/forge-test/GasBenchmark.t.sol @@ -49,7 +49,7 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { uint[] freshPricesWhMerkleUpdateFee; // i th element contains the update fee for the first i prices uint64 sequence; - uint randSeed; + uint randomSeed; function setUp() public { address wormholeAddr = setUpWormholeReceiver(NUM_GUARDIANS); @@ -120,8 +120,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils { } function getRand() internal returns (uint val) { - ++randSeed; - val = uint(keccak256(abi.encode(randSeed))); + ++randomSeed; + val = uint(keccak256(abi.encode(randomSeed))); } function generateWhBatchUpdateDataAndFee( diff --git a/target_chains/ethereum/contracts/forge-test/Pyth.Aave.t.sol b/target_chains/ethereum/contracts/forge-test/Pyth.Aave.t.sol new file mode 100644 index 00000000..e718c705 --- /dev/null +++ b/target_chains/ethereum/contracts/forge-test/Pyth.Aave.t.sol @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; + +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; +import "./utils/WormholeTestUtils.t.sol"; +import "./utils/PythTestUtils.t.sol"; +import "./utils/RandTestUtils.t.sol"; + +import "../contracts/aave/interfaces/IPriceOracleGetter.sol"; +import "../contracts/aave/PythPriceOracleGetter.sol"; +import "./Pyth.WormholeMerkleAccumulator.t.sol"; + +contract PythAaveTest is PythWormholeMerkleAccumulatorTest { + IPriceOracleGetter public pythOracleGetter; + address[] assets; + bytes32[] priceIds; + uint constant NUM_PRICE_FEEDS = 5; + uint256 constant BASE_CURRENCY_UNIT = 1e8; + uint constant VALID_TIME_PERIOD_SECS = 60; + + function setUp() public override { + pyth = IPyth(setUpPyth(setUpWormholeReceiver(1))); + assets = new address[](NUM_PRICE_FEEDS); + PriceFeedMessage[] + memory priceFeedMessages = generateRandomBoundedPriceFeedMessage( + NUM_PRICE_FEEDS + ); + priceIds = new bytes32[](NUM_PRICE_FEEDS); + + for (uint i = 0; i < NUM_PRICE_FEEDS; i++) { + assets[i] = address( + uint160(uint(keccak256(abi.encodePacked(i + NUM_PRICE_FEEDS)))) + ); + priceIds[i] = priceFeedMessages[i].priceId; + } + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + pyth.updatePriceFeeds{value: updateFee}(updateData); + + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + BASE_CURRENCY_UNIT, + VALID_TIME_PERIOD_SECS + ); + } + + function testConversion( + int64 pythPrice, + int32 pythExpo, + uint256 aavePrice, + uint256 baseCurrencyUnit + ) private { + PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1); + PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({ + priceId: getRandBytes32(), + price: pythPrice, + conf: getRandUint64(), + expo: pythExpo, + publishTime: uint64(1), + prevPublishTime: getRandUint64(), + emaPrice: getRandInt64(), + emaConf: getRandUint64() + }); + priceFeedMessages[0] = priceFeedMessage; + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + pyth.updatePriceFeeds{value: updateFee}(updateData); + + priceIds = new bytes32[](1); + priceIds[0] = priceFeedMessage.priceId; + assets = new address[](1); + assets[0] = address( + uint160(uint(keccak256(abi.encodePacked(uint(100))))) + ); + + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + baseCurrencyUnit, + VALID_TIME_PERIOD_SECS + ); + + assertEq(pythOracleGetter.getAssetPrice(assets[0]), aavePrice); + } + + function testGetAssetPriceWorks() public { + // "display" price is 529.30903 + testConversion(52_930_903, -5, 52_930_903_000, BASE_CURRENCY_UNIT); + } + + function testGetAssetPriceWorksWithPositiveExponent() public { + // "display" price is 5_293_000 + testConversion(5_293, 3, 529_300_000_000_000, BASE_CURRENCY_UNIT); + } + + function testGetAssetPriceWorksWithZeroExponent() public { + // "display" price is 5_293 + testConversion(5_293, 0, 529_300_000_000, BASE_CURRENCY_UNIT); + } + + function testGetAssetPriceWorksWithNegativeNormalizerExponent() public { + // "display" price is 5_293 + testConversion( + 5_293_000_000_000_000, + -12, + 529_300_000_000, + BASE_CURRENCY_UNIT + ); + } + + function testGetAssetPriceWorksWithBaseCurrencyUnitOfOne() public { + // "display" price is 529.30903 + testConversion(52_930_903, -5, 529, 1); + } + + function testGetAssetPriceWorksWithBoundedRandomValues(uint seed) public { + setRandSeed(seed); + + for (uint i = 0; i < assets.length; i++) { + address asset = assets[i]; + uint256 assetPrice = pythOracleGetter.getAssetPrice(asset); + uint256 aavePrice = assetPrice / BASE_CURRENCY_UNIT; + + bytes32 priceId = priceIds[i]; + PythStructs.Price memory price = pyth.getPrice(priceId); + int64 pythRawPrice = price.price; + uint pythNormalizer; + uint pythPrice; + if (price.expo < 0) { + pythNormalizer = 10 ** uint32(-price.expo); + pythPrice = uint64(pythRawPrice) / pythNormalizer; + } else { + pythNormalizer = 10 ** uint32(price.expo); + pythPrice = uint64(pythRawPrice) * pythNormalizer; + } + assertEq(aavePrice, pythPrice); + } + } + + function testGetAssetPriceWorksIfGivenBaseCurrencyAddress() public { + address usdAddress = address(0x0); + uint256 assetPrice = pythOracleGetter.getAssetPrice(usdAddress); + assertEq(assetPrice, BASE_CURRENCY_UNIT); + } + + function testGetAssetRevertsIfPriceNotRecentEnough() public { + uint timestamp = block.timestamp; + vm.warp(timestamp + VALID_TIME_PERIOD_SECS); + for (uint i = 0; i < assets.length; i++) { + pythOracleGetter.getAssetPrice(assets[i]); + } + vm.warp(timestamp + VALID_TIME_PERIOD_SECS + 1); + for (uint i = 0; i < assets.length; i++) { + vm.expectRevert(PythErrors.StalePrice.selector); + pythOracleGetter.getAssetPrice(assets[i]); + } + } + + function testGetAssetRevertsIfPriceFeedNotFound() public { + address addr = address( + uint160(uint(keccak256(abi.encodePacked(uint(100))))) + ); + vm.expectRevert(PythErrors.PriceFeedNotFound.selector); + pythOracleGetter.getAssetPrice(addr); + } + + function testGetAssetPriceRevertsIfPriceIsNegative() public { + PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1); + PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({ + priceId: getRandBytes32(), + price: int64(-5), + conf: getRandUint64(), + expo: getRandInt32(), + publishTime: uint64(1), + prevPublishTime: getRandUint64(), + emaPrice: getRandInt64(), + emaConf: getRandUint64() + }); + + priceFeedMessages[0] = priceFeedMessage; + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + pyth.updatePriceFeeds{value: updateFee}(updateData); + + priceIds = new bytes32[](1); + priceIds[0] = priceFeedMessage.priceId; + assets = new address[](1); + assets[0] = address( + uint160(uint(keccak256(abi.encodePacked(uint(100))))) + ); + + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + BASE_CURRENCY_UNIT, + VALID_TIME_PERIOD_SECS + ); + + vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()")); + pythOracleGetter.getAssetPrice(assets[0]); + } + + function testGetAssetPriceRevertsIfNormalizerOverflows() public { + PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1); + PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({ + priceId: getRandBytes32(), + price: int64(1), + conf: getRandUint64(), + expo: int32(59), // type(uint192).max = ~6.27e58 + publishTime: uint64(1), + prevPublishTime: getRandUint64(), + emaPrice: getRandInt64(), + emaConf: getRandUint64() + }); + + priceFeedMessages[0] = priceFeedMessage; + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + pyth.updatePriceFeeds{value: updateFee}(updateData); + + priceIds = new bytes32[](1); + priceIds[0] = priceFeedMessage.priceId; + assets = new address[](1); + assets[0] = address( + uint160(uint(keccak256(abi.encodePacked(uint(100))))) + ); + + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + BASE_CURRENCY_UNIT, + VALID_TIME_PERIOD_SECS + ); + + vm.expectRevert(abi.encodeWithSignature("NormalizationOverflow()")); + pythOracleGetter.getAssetPrice(assets[0]); + } + + function testGetAssetPriceRevertsIfNormalizedToZero() public { + PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1); + PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({ + priceId: getRandBytes32(), + price: int64(1), + conf: getRandUint64(), + expo: int32(-75), + publishTime: uint64(1), + prevPublishTime: getRandUint64(), + emaPrice: getRandInt64(), + emaConf: getRandUint64() + }); + + priceFeedMessages[0] = priceFeedMessage; + + ( + bytes[] memory updateData, + uint updateFee + ) = createWormholeMerkleUpdateData(priceFeedMessages); + pyth.updatePriceFeeds{value: updateFee}(updateData); + + priceIds = new bytes32[](1); + priceIds[0] = priceFeedMessage.priceId; + assets = new address[](1); + assets[0] = address( + uint160(uint(keccak256(abi.encodePacked(uint(100))))) + ); + + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + BASE_CURRENCY_UNIT, + VALID_TIME_PERIOD_SECS + ); + + vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()")); + pythOracleGetter.getAssetPrice(assets[0]); + } + + function testPythPriceOracleGetterConstructorRevertsIfAssetsAndPriceIdsLengthAreDifferent() + public + { + priceIds = new bytes32[](2); + priceIds[0] = getRandBytes32(); + priceIds[1] = getRandBytes32(); + assets = new address[](1); + assets[0] = address( + uint160(uint(keccak256(abi.encodePacked(uint(100))))) + ); + + vm.expectRevert(abi.encodeWithSignature("InconsistentParamsLength()")); + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + BASE_CURRENCY_UNIT, + VALID_TIME_PERIOD_SECS + ); + } + + function testPythPriceOracleGetterConstructorRevertsIfInvalidBaseCurrencyUnit() + public + { + priceIds = new bytes32[](1); + priceIds[0] = getRandBytes32(); + assets = new address[](1); + assets[0] = address( + uint160(uint(keccak256(abi.encodePacked(uint(100))))) + ); + + vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()")); + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + 0, + VALID_TIME_PERIOD_SECS + ); + + vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()")); + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + 11, + VALID_TIME_PERIOD_SECS + ); + + vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()")); + pythOracleGetter = new PythPriceOracleGetter( + address(pyth), + assets, + priceIds, + address(0x0), + 20, + VALID_TIME_PERIOD_SECS + ); + } +} diff --git a/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol b/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol index 4e00a3e8..8da11270 100644 --- a/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pyth.WormholeMerkleAccumulator.t.sol @@ -25,7 +25,7 @@ contract PythWormholeMerkleAccumulatorTest is // -1 is equal to 0xffffff which is the biggest uint if converted back uint64 constant MAX_UINT64 = uint64(int64(-1)); - function setUp() public { + function setUp() public virtual { pyth = IPyth(setUpPyth(setUpWormholeReceiver(1))); } @@ -121,6 +121,28 @@ contract PythWormholeMerkleAccumulatorTest is } } + /** + * @notice Returns `numPriceFeeds` random price feed messages with price & expo bounded + * to realistic values and publishTime set to 1. + */ + function generateRandomBoundedPriceFeedMessage( + uint numPriceFeeds + ) internal returns (PriceFeedMessage[] memory priceFeedMessages) { + priceFeedMessages = new PriceFeedMessage[](numPriceFeeds); + for (uint i = 0; i < numPriceFeeds; i++) { + priceFeedMessages[i] = PriceFeedMessage({ + priceId: getRandBytes32(), + price: int64(getRandUint64() / 10), // assuming price should always be positive + conf: getRandUint64(), + expo: int32(getRandInt8() % 13), // pyth contract guarantees that expo between [-12, 12] + publishTime: uint64(1), + prevPublishTime: getRandUint64(), + emaPrice: getRandInt64(), + emaConf: getRandUint64() + }); + } + } + function createWormholeMerkleUpdateData( PriceFeedMessage[] memory priceFeedMessages ) internal returns (bytes[] memory updateData, uint updateFee) { diff --git a/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol b/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol index 58d9065a..db237f7d 100644 --- a/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol +++ b/target_chains/ethereum/contracts/forge-test/utils/PythTestUtils.t.sol @@ -15,6 +15,7 @@ import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "forge-std/Test.sol"; import "./WormholeTestUtils.t.sol"; +import "./RandTestUtils.t.sol"; abstract contract PythTestUtils is Test, WormholeTestUtils { uint16 constant SOURCE_EMITTER_CHAIN_ID = 0x1; diff --git a/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol b/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol index eec2894d..dd42e97f 100644 --- a/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol +++ b/target_chains/ethereum/contracts/forge-test/utils/RandTestUtils.t.sol @@ -42,4 +42,8 @@ contract RandTestUtils is Test { function getRandUint8() internal returns (uint8) { return uint8(getRandUint()); } + + function getRandInt8() internal returns (int8) { + return int8(getRandUint8()); + } }