This commit is contained in:
Jayant Krishnamurthy 2023-03-10 10:40:47 -08:00
parent 41436faa50
commit d23231ff8e
13 changed files with 305 additions and 7 deletions

View File

@ -6,6 +6,7 @@ pragma solidity ^0.8.0;
import "../libraries/external/UnsafeBytesLib.sol";
import "@pythnetwork/pyth-sdk-solidity/AbstractPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "forge-std/Test.sol";
import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
import "./PythGetters.sol";
@ -91,6 +92,44 @@ abstract contract Pyth is PythGetters, PythSetters, AbstractPyth {
return singleUpdateFeeInWei() * updateData.length;
}
function requirePriceFeeds(
bytes32[] memory priceIds
) public payable override returns (bytes32) {
bytes32 requestId = keccak256(abi.encode(tx.origin, priceIds));
address payable payer = getPendingRequest(requestId);
if (payer != address(0x0)) {
console.log(payer);
console.log(address(this));
console.log(address(this).balance);
// TODO: transfer?
bool success = payer.send(msg.value);
console.log(success);
clearPendingRequest(requestId);
} else {
revert PythErrors.RequirePriceFeeds(priceIds);
}
return requestId;
}
function updatePriceFeedsOnBehalfOf(
address requester,
bytes32[] calldata priceIds,
bytes[] calldata updateData
) public payable override returns (bytes32) {
// TODO: does this need to be more differentiated??
bytes32 requestId = keccak256(abi.encode(requester, priceIds));
updatePriceFeeds(updateData);
// TODO: check that update includes all of priceIds
setPendingRequest(requestId, payable(msg.sender));
return requestId;
}
function verifyPythVM(
IWormhole.VM memory vm
) private view returns (bool valid) {

View File

@ -91,4 +91,10 @@ contract PythGetters is PythState {
function governanceDataSourceIndex() public view returns (uint32) {
return _state.governanceDataSourceIndex;
}
function getPendingRequest(
bytes32 requestId
) public view returns (address payable) {
return _state.pendingRequests[requestId];
}
}

View File

@ -38,4 +38,17 @@ contract PythSetters is PythState {
function setGovernanceDataSourceIndex(uint32 newIndex) internal {
_state.governanceDataSourceIndex = newIndex;
}
function setPendingRequest(
bytes32 requestId,
address payable payer
) internal {
// TODO: probably needs a timestamp / freshness check
_state.pendingRequests[requestId] = payer;
}
function clearPendingRequest(bytes32 requestId) internal {
// TODO: probably needs a timestamp / freshness check
delete _state.pendingRequests[requestId];
}
}

View File

@ -38,6 +38,7 @@ contract PythStorage {
// Mapping of cached price information
// priceId => PriceInfo
mapping(bytes32 => PythInternalStructs.PriceInfo) latestPriceInfo;
mapping(bytes32 => address payable) pendingRequests;
}
}

View File

@ -525,4 +525,41 @@ contract PythTest is Test, WormholeTestUtils, PythTestUtils, RandTestUtils {
MAX_UINT64
);
}
function testRequirePriceFeeds() public {
uint numAttestations = 1;
(
bytes32[] memory priceIds,
PriceAttestation[] memory attestations
) = generateRandomPriceAttestations(numAttestations);
vm.expectRevert(
abi.encodeWithSelector(
PythErrors.RequirePriceFeeds.selector,
priceIds
)
);
pyth.requirePriceFeeds(priceIds);
(
bytes[] memory updateData,
uint updateFee
) = createBatchedUpdateDataFromAttestations(attestations);
// console.log(address(this));
// console.log(address(this).balance);
// console.log(address(pyth).balance);
// payable(address(this)).transfer(100);
bytes32 requestId1 = pyth.updatePriceFeedsOnBehalfOf{value: updateFee}(
tx.origin,
priceIds,
updateData
);
console.logBytes32(requestId1);
bytes32 requestId2 = pyth.requirePriceFeeds{value: 7}(priceIds);
console.logBytes32(requestId2);
}
fallback() external payable {}
}

View File

@ -1,12 +1,14 @@
import React, { useEffect, useState } from "react";
import "./App.css";
import Web3 from "web3";
import ethers from "ethers";
import { BigNumber } from "ethers";
import { TokenConfig, numberToTokenQty, tokenQtyToNumber } from "./utils";
import IPythAbi from "@pythnetwork/pyth-sdk-solidity/abis/IPyth.json";
import OracleSwapAbi from "./abi/OracleSwapAbi.json";
import { approveToken, getApprovedQuantity } from "./erc20";
import { EvmPriceServiceConnection } from "@pythnetwork/pyth-evm-js";
import {PriceId} from "pyth_relay/lib/relay/iface";
/**
* The order entry component lets users enter a quantity of the base token to buy/sell and submit
@ -200,6 +202,8 @@ async function sendSwapTx(
pythContractAddress
);
// todo: need to craft update transaction here
const updateFee = await pythContract.methods
.getUpdateFee(priceFeedUpdateData.length)
.call();
@ -210,6 +214,56 @@ async function sendSwapTx(
);
await swapContract.methods
.swap(isBuy, qtyWei, priceFeedUpdateData)
.swapNoUpdate(isBuy, qtyWei, priceFeedUpdateData)
.send({ value: updateFee, from: sender });
}
async function sendSwapTxEthers(
web3: Web3,
priceServiceUrl: string,
baseTokenPriceFeedId: string,
quoteTokenPriceFeedId: string,
pythContractAddress: string,
swapContractAddress: string,
sender: string,
qtyWei: BigNumber,
isBuy: boolean
) {
// @ts-ignore
const provider = new ethers.providers.Web3Provider(window.ethereum);
const signer: any = undefined;
const swapContract = new ethers.Contract(
swapContractAddress,
OracleSwapAbi as any,
provider
).connect(signer);
let swapTx = await swapContract.populateTransaction.swapNoUpdate(isBuy, qtyWei);
}
async function sendTxWithPyth(tx: PopulatedTransaction, priceServiceUrl: string) {
const pythPriceService = new EvmPriceServiceConnection(priceServiceUrl);
let requiredFeeds: [PriceId] = [];
let loop = true;
let bundle = undefined;
while (loop) {
const priceFeedUpdateData = await pythPriceService.getPriceFeedsUpdateData(requiredFeeds);
let updateTx = await pythContract.populateTransaction.updatePriceFeeds(priceFeedUpdateData);
bundle = [updateTx, tx];
let maybeError = await ethers.simulateBundle([updateTx, tx]);
if (maybeError == "NoPriceFeed") {
// this error needs an ID that can get attached to requiredFeeds
// there's another similar error
} else {
loop = false;
}
}
ethers.
}

View File

@ -1,7 +1,10 @@
[profile.default]
solc = '0.8.4'
solc_version = '0.8.4'
src = 'src'
out = 'out'
libs = ['lib']
libs = [
'lib',
'../../../../../node_modules/@pythnetwork/pyth-sdk-solidity',
]
# See more config options https://github.com/foundry-rs/foundry/tree/master/config

View File

@ -1,4 +1,4 @@
ds-test/=lib/forge-std/lib/ds-test/src/
forge-std/=lib/forge-std/src/
pyth-sdk-solidity/=lib/pyth-sdk-solidity/
@pythnetwork/=../../../../../node_modules/@pythnetwork/
openzeppelin-contracts/=lib/openzeppelin-contracts/

View File

@ -1,8 +1,9 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;
import "pyth-sdk-solidity/IPyth.sol";
import "pyth-sdk-solidity/PythStructs.sol";
import "forge-std/Test.sol";
import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "openzeppelin-contracts/contracts/token/ERC20/ERC20.sol";
// Example oracle AMM powered by Pyth price feeds.
@ -59,6 +60,20 @@ contract OracleSwap {
uint updateFee = pyth.getUpdateFee(pythUpdateData);
pyth.updatePriceFeeds{value: updateFee}(pythUpdateData);
swapImpl(isBuy, size);
}
function swapNoUpdate(bool isBuy, uint size) external payable {
bytes32[] memory feedIds = new bytes32[](2);
feedIds[0] = baseTokenPriceId;
feedIds[1] = quoteTokenPriceId;
pyth.requirePriceFeeds(feedIds);
swapImpl(isBuy, size);
}
function swapImpl(bool isBuy, uint size) private {
PythStructs.Price memory currentBasePrice = pyth.getPrice(
baseTokenPriceId
);

View File

@ -3,7 +3,7 @@ pragma solidity ^0.8.0;
import "forge-std/Test.sol";
import "../src/OracleSwap.sol";
import "pyth-sdk-solidity/MockPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/MockPyth.sol";
import "openzeppelin-contracts/contracts/mocks/ERC20Mock.sol";
contract OracleSwapTest is Test {
@ -119,6 +119,80 @@ contract OracleSwapTest is Test {
assertEq(baseToken.balanceOf(address(this)), 20e18);
}
function doSwapNoUpdate(
int32 basePrice,
int32 quotePrice,
bool isBuy,
uint size
) private {
bytes[] memory updateData = new bytes[](2);
// This is a dummy update data for Eth. It shows the price as $1000 +- $10 (with -5 exponent).
updateData[0] = mockPyth.createPriceFeedUpdateData(
BASE_PRICE_ID,
basePrice * 100000,
10 * 100000,
-5,
basePrice * 100000,
10 * 100000,
uint64(block.timestamp)
);
updateData[1] = mockPyth.createPriceFeedUpdateData(
QUOTE_PRICE_ID,
quotePrice * 100000,
10 * 100000,
-5,
quotePrice * 100000,
10 * 100000,
uint64(block.timestamp)
);
// Make sure the contract has enough funds to update the pyth feeds
baseToken.approve(address(swap), MAX_INT);
quoteToken.approve(address(swap), MAX_INT);
bytes32[] memory priceIds = new bytes32[](2);
priceIds[0] = BASE_PRICE_ID;
priceIds[1] = QUOTE_PRICE_ID;
uint tip = 7;
vm.expectRevert(
abi.encodeWithSelector(
PythErrors.RequirePriceFeeds.selector,
priceIds
)
);
swap.swapNoUpdate{value: tip}(isBuy, size);
uint updateFee = mockPyth.getUpdateFee(updateData);
vm.deal(address(this), updateFee);
mockPyth.updatePriceFeedsOnBehalfOf{value: updateFee}(
tx.origin,
priceIds,
updateData
);
vm.deal(address(this), tip);
swap.swapNoUpdate{value: tip}(isBuy, size);
}
function testSwapNoUpdate() public {
setupTokens(20e18, 20e18, 20e18, 20e18);
doSwapNoUpdate(10, 1, true, 1e18);
// assertEq(quoteToken.balanceOf(address(this)), 10e18 - 1);
// assertEq(baseToken.balanceOf(address(this)), 21e18);
// doSwapNoUpdate(10, 1, false, 1e18);
// assertEq(quoteToken.balanceOf(address(this)), 20e18 - 1);
// assertEq(baseToken.balanceOf(address(this)), 20e18);
}
function testWithdraw() public {
setupTokens(10e18, 10e18, 10e18, 10e18);

View File

@ -136,4 +136,14 @@ interface IPyth is IPythEvents {
uint64 minPublishTime,
uint64 maxPublishTime
) external payable returns (PythStructs.PriceFeed[] memory priceFeeds);
function requirePriceFeeds(
bytes32[] memory priceIds
) external payable returns (bytes32);
function updatePriceFeedsOnBehalfOf(
address requester,
bytes32[] calldata priceIds,
bytes[] calldata updateData
) external payable returns (bytes32);
}

View File

@ -1,6 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
pragma solidity ^0.8.0;
import "forge-std/Test.sol";
import "./AbstractPyth.sol";
import "./PythStructs.sol";
import "./PythErrors.sol";
@ -12,6 +13,8 @@ contract MockPyth is AbstractPyth {
uint singleUpdateFeeInWei;
uint validTimePeriod;
mapping(bytes32 => address payable) pendingRequests;
constructor(uint _validTimePeriod, uint _singleUpdateFeeInWei) {
singleUpdateFeeInWei = _singleUpdateFeeInWei;
validTimePeriod = _validTimePeriod;
@ -111,6 +114,47 @@ contract MockPyth is AbstractPyth {
}
}
function requirePriceFeeds(
bytes32[] memory priceIds
) public payable override returns (bytes32) {
console.log("requirePriceFeeds");
console.log(tx.origin);
console.log(msg.value);
bytes32 requestId = keccak256(abi.encode(tx.origin, priceIds));
address payable payer = pendingRequests[requestId];
if (payer != address(0x0)) {
// TODO: transfer?
bool success = payer.send(msg.value);
delete pendingRequests[requestId];
} else {
revert PythErrors.RequirePriceFeeds(priceIds);
}
return requestId;
}
function updatePriceFeedsOnBehalfOf(
address requester,
bytes32[] calldata priceIds,
bytes[] calldata updateData
) public payable override returns (bytes32) {
console.log("updateOnBehalfOf");
console.log(requester);
// TODO: does this need to be more differentiated??
bytes32 requestId = keccak256(abi.encode(requester, priceIds));
updatePriceFeeds(updateData);
// TODO: check that update includes all of priceIds
pendingRequests[requestId] = payable(msg.sender);
return requestId;
}
function createPriceFeedUpdateData(
bytes32 id,
int64 price,

View File

@ -29,4 +29,6 @@ library PythErrors {
error InvalidGovernanceDataSource();
// Governance message is old.
error OldGovernanceMessage();
// Call requires someone to update the given price ids first.
error RequirePriceFeeds(bytes32[] priceIds);
}