[eth] - Aave FallbackOracle Integration (#924)

* feat(eth): aave integration

Add IPriceOracleGetter and PythAssetRegistry mapping

* feat(eth): remove IPriceOracleGetter from PythAssetRegistryGetter

* refactor(eth): flatten PythAssetRegistySetter/Getter into PythAssetRegistry

* feat(eth): address feedback

move aave related contracts into separate directory, add explicit exponent/decimal handling, add
staleness check

* refactor(eth): minor rename to avoid shadowing

* fix(eth): handle exponent conversion and add tests

* chore(eth): remove unused console import

* feat(eth): address PR feedback

add more checks, tests & minor refactoring

* feat(eth): add more tests and address feedback
This commit is contained in:
swimricky 2023-07-06 08:29:08 -04:00 committed by GitHub
parent aa0e6fdf22
commit ac97b4d35d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 612 additions and 4 deletions

View File

@ -0,0 +1,79 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
error InconsistentParamsLength();
contract PythAssetRegistryStorage {
struct State {
address pyth;
address BASE_CURRENCY;
uint256 BASE_CURRENCY_UNIT;
// Map of asset priceIds (asset => priceId)
mapping(address => bytes32) assetsPriceIds;
/// Maximum acceptable time period before price is considered to be stale.
/// This includes attestation delay, block time, and potential clock drift
/// between the source/target chains.
uint validTimePeriodSeconds;
}
}
contract PythAssetRegistry {
PythAssetRegistryStorage.State _registryState;
/**
* @dev Emitted after the base currency is set
* @param baseCurrency The base currency of used for price quotes
* @param baseCurrencyUnit The unit of the base currency
*/
event BaseCurrencySet(
address indexed baseCurrency,
uint256 baseCurrencyUnit
);
/**
* @dev Emitted after the price source of an asset is updated
* @param asset The address of the asset
* @param source The priceId of the asset
*/
event AssetSourceUpdated(address indexed asset, bytes32 indexed source);
function pyth() public view returns (IPyth) {
return IPyth(_registryState.pyth);
}
function setPyth(address pythAddress) internal {
_registryState.pyth = payable(pythAddress);
}
function setAssetsSources(
address[] memory assets,
bytes32[] memory priceIds
) internal {
if (assets.length != priceIds.length) {
revert InconsistentParamsLength();
}
for (uint256 i = 0; i < assets.length; i++) {
_registryState.assetsPriceIds[assets[i]] = priceIds[i];
emit AssetSourceUpdated(assets[i], priceIds[i]);
}
}
function setBaseCurrency(
address baseCurrency,
uint256 baseCurrencyUnit
) internal {
_registryState.BASE_CURRENCY = baseCurrency;
_registryState.BASE_CURRENCY_UNIT = baseCurrencyUnit;
emit BaseCurrencySet(baseCurrency, baseCurrencyUnit);
}
function setValidTimePeriodSeconds(uint validTimePeriodInSeconds) internal {
_registryState.validTimePeriodSeconds = validTimePeriodInSeconds;
}
function validTimePeriodSeconds() public view returns (uint) {
return _registryState.validTimePeriodSeconds;
}
}

View File

@ -0,0 +1,103 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "./interfaces/IPriceOracleGetter.sol";
import "./PythAssetRegistry.sol";
/// Invalid non-positive price
error InvalidNonPositivePrice();
/// Normalization overflow
error NormalizationOverflow();
/// Invalid Base Currency Unit value. Must be power of 10.
error InvalidBaseCurrencyUnit();
contract PythPriceOracleGetter is PythAssetRegistry, IPriceOracleGetter {
/// @inheritdoc IPriceOracleGetter
address public immutable override BASE_CURRENCY;
/**
* @notice Returns the base currency unit
* @dev 1 ether for ETH, 1e8 for USD.
* @return Returns the base currency unit.
*/
uint256 public immutable override BASE_CURRENCY_UNIT;
/// BASE_CURRENCY_UNIT as a power of 10
uint8 public immutable BASE_NUM_DECIMALS;
constructor(
address pyth,
address[] memory assets,
bytes32[] memory priceIds,
address baseCurrency,
uint256 baseCurrencyUnit,
uint validTimePeriodSeconds
) {
if (baseCurrencyUnit == 0) {
revert InvalidBaseCurrencyUnit();
}
PythAssetRegistry.setPyth(pyth);
PythAssetRegistry.setAssetsSources(assets, priceIds);
PythAssetRegistry.setBaseCurrency(baseCurrency, baseCurrencyUnit);
BASE_CURRENCY = _registryState.BASE_CURRENCY;
BASE_CURRENCY_UNIT = _registryState.BASE_CURRENCY_UNIT;
if ((10 ** baseNumDecimals(baseCurrencyUnit)) != baseCurrencyUnit) {
revert InvalidBaseCurrencyUnit();
}
BASE_NUM_DECIMALS = baseNumDecimals(baseCurrencyUnit);
PythAssetRegistry.setValidTimePeriodSeconds(validTimePeriodSeconds);
}
/// @inheritdoc IPriceOracleGetter
function getAssetPrice(
address asset
) external view override returns (uint256) {
bytes32 priceId = _registryState.assetsPriceIds[asset];
if (asset == BASE_CURRENCY) {
return BASE_CURRENCY_UNIT;
}
if (priceId == 0) {
revert PythErrors.PriceFeedNotFound();
}
PythStructs.Price memory price = pyth().getPriceNoOlderThan(
priceId,
PythAssetRegistry.validTimePeriodSeconds()
);
// Aave is not using any price feeds < 0 for now.
if (price.price <= 0) {
revert InvalidNonPositivePrice();
}
uint256 normalizedPrice = uint64(price.price);
int32 normalizerExpo = price.expo + int8(BASE_NUM_DECIMALS);
bool isNormalizerExpoNeg = normalizerExpo < 0;
uint256 normalizer = isNormalizerExpoNeg
? 10 ** uint32(-normalizerExpo)
: 10 ** uint32(normalizerExpo);
// this check prevents overflow in normalized price
if (!isNormalizerExpoNeg && normalizer > type(uint192).max) {
revert NormalizationOverflow();
}
normalizedPrice = isNormalizerExpoNeg
? normalizedPrice / normalizer
: normalizedPrice * normalizer;
if (normalizedPrice <= 0) {
revert InvalidNonPositivePrice();
}
return normalizedPrice;
}
function baseNumDecimals(uint number) private pure returns (uint8) {
uint8 digits = 0;
while (number != 0) {
number /= 10;
digits++;
}
return digits - 1;
}
}

View File

@ -0,0 +1,31 @@
// contracts/pyth/aave/PythPriceOracleGetter.sol
// SPDX-License-Identifier: AGPL-3.0
pragma solidity ^0.8.0;
/**
* @title IPriceOracleGetter
* @author Aave
* @notice Interface for the Aave price oracle.
*/
interface IPriceOracleGetter {
/**
* @notice Returns the base currency address
* @dev Address 0x0 is reserved for USD as base currency.
* @return Returns the base currency address.
*/
function BASE_CURRENCY() external view returns (address);
/**
* @notice Returns the base currency unit
* @dev 1 ether for ETH, 1e8 for USD.
* @return Returns the base currency unit.
*/
function BASE_CURRENCY_UNIT() external view returns (uint256);
/**
* @notice Returns the asset price in the base currency
* @param asset The address of the asset
* @return The price of the asset
*/
function getAssetPrice(address asset) external view returns (uint256);
}

View File

@ -49,7 +49,7 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
uint[] freshPricesWhMerkleUpdateFee; // i th element contains the update fee for the first i prices
uint64 sequence;
uint randSeed;
uint randomSeed;
function setUp() public {
address wormholeAddr = setUpWormholeReceiver(NUM_GUARDIANS);
@ -120,8 +120,8 @@ contract GasBenchmark is Test, WormholeTestUtils, PythTestUtils {
}
function getRand() internal returns (uint val) {
++randSeed;
val = uint(keccak256(abi.encode(randSeed)));
++randomSeed;
val = uint(keccak256(abi.encode(randomSeed)));
}
function generateWhBatchUpdateDataAndFee(

View File

@ -0,0 +1,368 @@
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;
import "forge-std/Test.sol";
import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "./utils/WormholeTestUtils.t.sol";
import "./utils/PythTestUtils.t.sol";
import "./utils/RandTestUtils.t.sol";
import "../contracts/aave/interfaces/IPriceOracleGetter.sol";
import "../contracts/aave/PythPriceOracleGetter.sol";
import "./Pyth.WormholeMerkleAccumulator.t.sol";
contract PythAaveTest is PythWormholeMerkleAccumulatorTest {
IPriceOracleGetter public pythOracleGetter;
address[] assets;
bytes32[] priceIds;
uint constant NUM_PRICE_FEEDS = 5;
uint256 constant BASE_CURRENCY_UNIT = 1e8;
uint constant VALID_TIME_PERIOD_SECS = 60;
function setUp() public override {
pyth = IPyth(setUpPyth(setUpWormholeReceiver(1)));
assets = new address[](NUM_PRICE_FEEDS);
PriceFeedMessage[]
memory priceFeedMessages = generateRandomBoundedPriceFeedMessage(
NUM_PRICE_FEEDS
);
priceIds = new bytes32[](NUM_PRICE_FEEDS);
for (uint i = 0; i < NUM_PRICE_FEEDS; i++) {
assets[i] = address(
uint160(uint(keccak256(abi.encodePacked(i + NUM_PRICE_FEEDS))))
);
priceIds[i] = priceFeedMessages[i].priceId;
}
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
pyth.updatePriceFeeds{value: updateFee}(updateData);
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
BASE_CURRENCY_UNIT,
VALID_TIME_PERIOD_SECS
);
}
function testConversion(
int64 pythPrice,
int32 pythExpo,
uint256 aavePrice,
uint256 baseCurrencyUnit
) private {
PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
priceId: getRandBytes32(),
price: pythPrice,
conf: getRandUint64(),
expo: pythExpo,
publishTime: uint64(1),
prevPublishTime: getRandUint64(),
emaPrice: getRandInt64(),
emaConf: getRandUint64()
});
priceFeedMessages[0] = priceFeedMessage;
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
pyth.updatePriceFeeds{value: updateFee}(updateData);
priceIds = new bytes32[](1);
priceIds[0] = priceFeedMessage.priceId;
assets = new address[](1);
assets[0] = address(
uint160(uint(keccak256(abi.encodePacked(uint(100)))))
);
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
baseCurrencyUnit,
VALID_TIME_PERIOD_SECS
);
assertEq(pythOracleGetter.getAssetPrice(assets[0]), aavePrice);
}
function testGetAssetPriceWorks() public {
// "display" price is 529.30903
testConversion(52_930_903, -5, 52_930_903_000, BASE_CURRENCY_UNIT);
}
function testGetAssetPriceWorksWithPositiveExponent() public {
// "display" price is 5_293_000
testConversion(5_293, 3, 529_300_000_000_000, BASE_CURRENCY_UNIT);
}
function testGetAssetPriceWorksWithZeroExponent() public {
// "display" price is 5_293
testConversion(5_293, 0, 529_300_000_000, BASE_CURRENCY_UNIT);
}
function testGetAssetPriceWorksWithNegativeNormalizerExponent() public {
// "display" price is 5_293
testConversion(
5_293_000_000_000_000,
-12,
529_300_000_000,
BASE_CURRENCY_UNIT
);
}
function testGetAssetPriceWorksWithBaseCurrencyUnitOfOne() public {
// "display" price is 529.30903
testConversion(52_930_903, -5, 529, 1);
}
function testGetAssetPriceWorksWithBoundedRandomValues(uint seed) public {
setRandSeed(seed);
for (uint i = 0; i < assets.length; i++) {
address asset = assets[i];
uint256 assetPrice = pythOracleGetter.getAssetPrice(asset);
uint256 aavePrice = assetPrice / BASE_CURRENCY_UNIT;
bytes32 priceId = priceIds[i];
PythStructs.Price memory price = pyth.getPrice(priceId);
int64 pythRawPrice = price.price;
uint pythNormalizer;
uint pythPrice;
if (price.expo < 0) {
pythNormalizer = 10 ** uint32(-price.expo);
pythPrice = uint64(pythRawPrice) / pythNormalizer;
} else {
pythNormalizer = 10 ** uint32(price.expo);
pythPrice = uint64(pythRawPrice) * pythNormalizer;
}
assertEq(aavePrice, pythPrice);
}
}
function testGetAssetPriceWorksIfGivenBaseCurrencyAddress() public {
address usdAddress = address(0x0);
uint256 assetPrice = pythOracleGetter.getAssetPrice(usdAddress);
assertEq(assetPrice, BASE_CURRENCY_UNIT);
}
function testGetAssetRevertsIfPriceNotRecentEnough() public {
uint timestamp = block.timestamp;
vm.warp(timestamp + VALID_TIME_PERIOD_SECS);
for (uint i = 0; i < assets.length; i++) {
pythOracleGetter.getAssetPrice(assets[i]);
}
vm.warp(timestamp + VALID_TIME_PERIOD_SECS + 1);
for (uint i = 0; i < assets.length; i++) {
vm.expectRevert(PythErrors.StalePrice.selector);
pythOracleGetter.getAssetPrice(assets[i]);
}
}
function testGetAssetRevertsIfPriceFeedNotFound() public {
address addr = address(
uint160(uint(keccak256(abi.encodePacked(uint(100)))))
);
vm.expectRevert(PythErrors.PriceFeedNotFound.selector);
pythOracleGetter.getAssetPrice(addr);
}
function testGetAssetPriceRevertsIfPriceIsNegative() public {
PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
priceId: getRandBytes32(),
price: int64(-5),
conf: getRandUint64(),
expo: getRandInt32(),
publishTime: uint64(1),
prevPublishTime: getRandUint64(),
emaPrice: getRandInt64(),
emaConf: getRandUint64()
});
priceFeedMessages[0] = priceFeedMessage;
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
pyth.updatePriceFeeds{value: updateFee}(updateData);
priceIds = new bytes32[](1);
priceIds[0] = priceFeedMessage.priceId;
assets = new address[](1);
assets[0] = address(
uint160(uint(keccak256(abi.encodePacked(uint(100)))))
);
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
BASE_CURRENCY_UNIT,
VALID_TIME_PERIOD_SECS
);
vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()"));
pythOracleGetter.getAssetPrice(assets[0]);
}
function testGetAssetPriceRevertsIfNormalizerOverflows() public {
PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
priceId: getRandBytes32(),
price: int64(1),
conf: getRandUint64(),
expo: int32(59), // type(uint192).max = ~6.27e58
publishTime: uint64(1),
prevPublishTime: getRandUint64(),
emaPrice: getRandInt64(),
emaConf: getRandUint64()
});
priceFeedMessages[0] = priceFeedMessage;
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
pyth.updatePriceFeeds{value: updateFee}(updateData);
priceIds = new bytes32[](1);
priceIds[0] = priceFeedMessage.priceId;
assets = new address[](1);
assets[0] = address(
uint160(uint(keccak256(abi.encodePacked(uint(100)))))
);
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
BASE_CURRENCY_UNIT,
VALID_TIME_PERIOD_SECS
);
vm.expectRevert(abi.encodeWithSignature("NormalizationOverflow()"));
pythOracleGetter.getAssetPrice(assets[0]);
}
function testGetAssetPriceRevertsIfNormalizedToZero() public {
PriceFeedMessage[] memory priceFeedMessages = new PriceFeedMessage[](1);
PriceFeedMessage memory priceFeedMessage = PriceFeedMessage({
priceId: getRandBytes32(),
price: int64(1),
conf: getRandUint64(),
expo: int32(-75),
publishTime: uint64(1),
prevPublishTime: getRandUint64(),
emaPrice: getRandInt64(),
emaConf: getRandUint64()
});
priceFeedMessages[0] = priceFeedMessage;
(
bytes[] memory updateData,
uint updateFee
) = createWormholeMerkleUpdateData(priceFeedMessages);
pyth.updatePriceFeeds{value: updateFee}(updateData);
priceIds = new bytes32[](1);
priceIds[0] = priceFeedMessage.priceId;
assets = new address[](1);
assets[0] = address(
uint160(uint(keccak256(abi.encodePacked(uint(100)))))
);
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
BASE_CURRENCY_UNIT,
VALID_TIME_PERIOD_SECS
);
vm.expectRevert(abi.encodeWithSignature("InvalidNonPositivePrice()"));
pythOracleGetter.getAssetPrice(assets[0]);
}
function testPythPriceOracleGetterConstructorRevertsIfAssetsAndPriceIdsLengthAreDifferent()
public
{
priceIds = new bytes32[](2);
priceIds[0] = getRandBytes32();
priceIds[1] = getRandBytes32();
assets = new address[](1);
assets[0] = address(
uint160(uint(keccak256(abi.encodePacked(uint(100)))))
);
vm.expectRevert(abi.encodeWithSignature("InconsistentParamsLength()"));
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
BASE_CURRENCY_UNIT,
VALID_TIME_PERIOD_SECS
);
}
function testPythPriceOracleGetterConstructorRevertsIfInvalidBaseCurrencyUnit()
public
{
priceIds = new bytes32[](1);
priceIds[0] = getRandBytes32();
assets = new address[](1);
assets[0] = address(
uint160(uint(keccak256(abi.encodePacked(uint(100)))))
);
vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()"));
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
0,
VALID_TIME_PERIOD_SECS
);
vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()"));
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
11,
VALID_TIME_PERIOD_SECS
);
vm.expectRevert(abi.encodeWithSignature("InvalidBaseCurrencyUnit()"));
pythOracleGetter = new PythPriceOracleGetter(
address(pyth),
assets,
priceIds,
address(0x0),
20,
VALID_TIME_PERIOD_SECS
);
}
}

View File

@ -25,7 +25,7 @@ contract PythWormholeMerkleAccumulatorTest is
// -1 is equal to 0xffffff which is the biggest uint if converted back
uint64 constant MAX_UINT64 = uint64(int64(-1));
function setUp() public {
function setUp() public virtual {
pyth = IPyth(setUpPyth(setUpWormholeReceiver(1)));
}
@ -121,6 +121,28 @@ contract PythWormholeMerkleAccumulatorTest is
}
}
/**
* @notice Returns `numPriceFeeds` random price feed messages with price & expo bounded
* to realistic values and publishTime set to 1.
*/
function generateRandomBoundedPriceFeedMessage(
uint numPriceFeeds
) internal returns (PriceFeedMessage[] memory priceFeedMessages) {
priceFeedMessages = new PriceFeedMessage[](numPriceFeeds);
for (uint i = 0; i < numPriceFeeds; i++) {
priceFeedMessages[i] = PriceFeedMessage({
priceId: getRandBytes32(),
price: int64(getRandUint64() / 10), // assuming price should always be positive
conf: getRandUint64(),
expo: int32(getRandInt8() % 13), // pyth contract guarantees that expo between [-12, 12]
publishTime: uint64(1),
prevPublishTime: getRandUint64(),
emaPrice: getRandInt64(),
emaConf: getRandUint64()
});
}
}
function createWormholeMerkleUpdateData(
PriceFeedMessage[] memory priceFeedMessages
) internal returns (bytes[] memory updateData, uint updateFee) {

View File

@ -15,6 +15,7 @@ import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "forge-std/Test.sol";
import "./WormholeTestUtils.t.sol";
import "./RandTestUtils.t.sol";
abstract contract PythTestUtils is Test, WormholeTestUtils {
uint16 constant SOURCE_EMITTER_CHAIN_ID = 0x1;

View File

@ -42,4 +42,8 @@ contract RandTestUtils is Test {
function getRandUint8() internal returns (uint8) {
return uint8(getRandUint());
}
function getRandInt8() internal returns (int8) {
return int8(getRandUint8());
}
}