fix submission race condition

This commit is contained in:
De Facto 2021-02-20 19:28:39 +08:00
parent 2c3bfe1efd
commit b482cf00eb
6 changed files with 97 additions and 35 deletions

View File

@ -10,28 +10,28 @@ use thiserror::Error;
#[derive(Clone, Debug, Eq, Error, FromPrimitive, PartialEq)] #[derive(Clone, Debug, Eq, Error, FromPrimitive, PartialEq)]
pub enum Error { pub enum Error {
/// Owner mismatch /// Owner mismatch
#[error("Owner mismatch")] #[error("Owner mismatch")] // 0
OwnerMismatch, OwnerMismatch,
#[error("Insufficient withdrawable")] #[error("Insufficient withdrawable")] // 1
InsufficientWithdrawable, InsufficientWithdrawable,
#[error("Aggregator key not match")] #[error("Aggregator key not match")] // 2
AggregatorMismatch, AggregatorMismatch,
#[error("Invalid round id")] #[error("Invalid round id")] // 3
InvalidRoundID, InvalidRoundID,
#[error("Cannot start new round until cooldown")] #[error("Cannot start new round until cooldown")] // 4
OracleNewRoundCooldown, OracleNewRoundCooldown,
#[error("Max number of submissions reached for this round")] #[error("Max number of submissions reached for this round")] // 5
MaxSubmissionsReached, MaxSubmissionsReached,
#[error("Each oracle may only submit once per round")] #[error("Each oracle may only submit once per round")] // 6
OracleAlreadySubmitted, OracleAlreadySubmitted,
#[error("Rewards overflow")] #[error("Rewards overflow")] // 7
RewardsOverflow, RewardsOverflow,
#[error("No resolve answer")] #[error("No resolve answer")]

View File

@ -31,7 +31,7 @@ export class PriceFeeder {
) )
if (oracleInfo == null) { if (oracleInfo == null) {
log.debug("Is not an oracle for:", name) log.debug("Is not an oracle", { name })
continue continue
} }

View File

@ -29,6 +29,8 @@ export class Submitter {
public logger!: Logger public logger!: Logger
public currentValue: BN public currentValue: BN
public reportedRound: BN
constructor( constructor(
programID: PublicKey, programID: PublicKey,
public aggregatorPK: PublicKey, public aggregatorPK: PublicKey,
@ -40,19 +42,14 @@ export class Submitter {
this.program = new FluxAggregator(this.oracleOwnerWallet, programID) this.program = new FluxAggregator(this.oracleOwnerWallet, programID)
this.currentValue = new BN(0) this.currentValue = new BN(0)
this.reportedRound = new BN(0)
} }
// TODO: harvest rewards if > n // TODO: harvest rewards if > n
public async start() { public async start() {
// make sure the states are initialized // make sure the states are initialized
this.aggregator = await Aggregator.load(this.aggregatorPK) await this.reloadState()
this.roundSubmissions = await Submissions.load(
this.aggregator.roundSubmissions
)
this.answerSubmissions = await Submissions.load(
this.aggregator.answerSubmissions
)
this.logger = log.child({ this.logger = log.child({
aggregator: this.aggregator.config.description, aggregator: this.aggregator.config.description,
@ -61,20 +58,28 @@ export class Submitter {
await Promise.all([this.observeAggregatorState(), this.observePriceFeed()]) await Promise.all([this.observeAggregatorState(), this.observePriceFeed()])
} }
public async withdrawRewards() { public async withdrawRewards() {}
private async reloadState(loadAggregator = true) {
if (loadAggregator) {
this.aggregator = await Aggregator.load(this.aggregatorPK)
}
this.roundSubmissions = await Submissions.load(
this.aggregator.roundSubmissions
)
this.answerSubmissions = await Submissions.load(
this.aggregator.answerSubmissions
)
this.oracle = await Oracle.load(this.oraclePK)
} }
private async observeAggregatorState() { private async observeAggregatorState() {
conn.onAccountChange(this.aggregatorPK, async (info) => { conn.onAccountChange(this.aggregatorPK, async (info) => {
this.aggregator = Aggregator.deserialize(info.data) this.aggregator = Aggregator.deserialize(info.data)
this.roundSubmissions = await Submissions.load( await this.reloadState(false)
this.aggregator.roundSubmissions
)
this.answerSubmissions = await Submissions.load(
this.aggregator.answerSubmissions
)
// TODO: load answer
this.logger.debug("state updated", { this.logger.debug("state updated", {
aggregator: this.aggregator, aggregator: this.aggregator,
submissions: this.roundSubmissions, submissions: this.roundSubmissions,
@ -138,8 +143,11 @@ export class Submitter {
// oracle to start // oracle to start
const oracle = await Oracle.load(this.oraclePK) const oracle = await Oracle.load(this.oraclePK)
if (oracle.canStartNewRound(round.id)) { if (oracle.canStartNewRound(round.id)) {
this.logger.info("Starting a new round") let newRoundID = round.id.addn(1)
return this.submitCurrentValue(round.id.addn(1)) this.logger.info("Starting a new round", {
round: newRoundID.toString(),
})
return this.submitCurrentValue(newRoundID)
} }
} }
@ -148,7 +156,9 @@ export class Submitter {
return return
} }
this.logger.info("Another oracle started a new round") this.logger.info("Another oracle started a new round", {
round: this.aggregator.round.id.toString(),
})
await this.trySubmit() await this.trySubmit()
} }
@ -159,20 +169,27 @@ export class Submitter {
) )
} }
private async submitCurrentValue(round: BN) { private async submitCurrentValue(roundID: BN) {
// guard zero value // guard zero value
const value = this.currentValue const value = this.currentValue
if (value.isZero()) { if (value.isZero()) {
this.logger.warn("current value is zero. skip submit.") this.logger.warn("current value is zero. skip submit")
return
}
if (!roundID.isZero() && roundID.lte(this.reportedRound)) {
this.logger.debug("don't report to the same round twice")
return return
} }
this.logger.info("Submit value", { this.logger.info("Submit value", {
round: round.toString(), round: roundID.toString(),
value: value.toString(), value: value.toString(),
}) })
try { try {
// prevent async race condition where submit could be called twice on the same round
this.reportedRound = roundID
await this.program.submit({ await this.program.submit({
accounts: { accounts: {
aggregator: { write: this.aggregatorPK }, aggregator: { write: this.aggregatorPK },
@ -182,14 +199,23 @@ export class Submitter {
oracle_owner: this.oracleOwnerWallet.account, oracle_owner: this.oracleOwnerWallet.account,
}, },
round_id: round, round_id: roundID,
value, value,
}) })
await this.reloadState()
this.logger.info("Submit OK", {
withdrawable: this.oracle.withdrawable.toString(),
rewardToken: this.aggregator.config.rewardTokenAccount.toString(),
})
} catch (err) { } catch (err) {
console.log(err) console.log(err)
this.logger.error("Submit error", { this.logger.error("Submit error", {
err: err.toString(), err: err.toString(),
}) })
} }
} }
} }

View File

@ -2,17 +2,50 @@ import dotenv from "dotenv"
dotenv.config() dotenv.config()
import { Command, option } from "commander" import { Command, option } from "commander"
import { jsonReplacer, loadJSONFile } from "./json" import { jsonReplacer, loadJSONFile } from "./json"
import { AggregatorDeployFile } from "./Deployer" import { AggregatorDeployFile, Deployer } from "./Deployer"
import { conn, network } from "./context" import { conn, network } from "./context"
import { AggregatorObserver } from "./AggregatorObserver" import { AggregatorObserver } from "./AggregatorObserver"
import { Aggregator, Answer } from "./schema" import { Aggregator, Answer } from "./schema"
import { PriceFeeder } from "./PriceFeeder" import { PriceFeeder } from "./PriceFeeder"
import { walletFromEnv } from "./utils" import { sleep, walletFromEnv } from "./utils"
import { PublicKey, Wallet } from "solray"
import { log } from "./log"
const cli = new Command() const cli = new Command()
async function maybeRequestAirdrop(pubkey: PublicKey) {
if (network != "mainnet") {
log.info("airdrop 10 SOL", { address: pubkey.toBase58() })
await conn.requestAirdrop(pubkey, 10 * 1e9)
await sleep(500)
}
}
function deployFile(): AggregatorDeployFile {
return loadJSONFile<AggregatorDeployFile>(process.env.DEPLOY_FILE!)
}
cli.command("new-wallet").action(async (name) => {
const mnemonic = Wallet.generateMnemonic()
const wallet = await Wallet.fromMnemonic(mnemonic, conn)
log.info(`address: ${wallet.address}`)
log.info(`mnemonic: ${mnemonic}`)
await maybeRequestAirdrop(wallet.pubkey)
})
cli.command("setup <setup-file>").action(async (setupFile) => {
const wallet = await walletFromEnv("ADMIN_MNEMONIC", conn)
await maybeRequestAirdrop(wallet.pubkey)
const deployer = new Deployer(process.env.DEPLOY_FILE!, setupFile, wallet)
await deployer.runAll()
})
cli.command("oracle").action(async (name) => { cli.command("oracle").action(async (name) => {
const wallet = await walletFromEnv("ORACLE_MNEMONIC", conn) const wallet = await walletFromEnv("ORACLE_MNEMONIC", conn)
await maybeRequestAirdrop(wallet.pubkey)
let deploy = loadJSONFile<AggregatorDeployFile>(process.env.DEPLOY_FILE!) let deploy = loadJSONFile<AggregatorDeployFile>(process.env.DEPLOY_FILE!)
const feeder = new PriceFeeder(deploy, wallet) const feeder = new PriceFeeder(deploy, wallet)
feeder.start() feeder.start()

View File

@ -351,6 +351,7 @@ function boolToInt(t: boolean) {
export class Oracle extends Serialization { export class Oracle extends Serialization {
public static size = 113 public static size = 113
public allowStartRound!: BN public allowStartRound!: BN
public withdrawable!: BN
public static schema = { public static schema = {
kind: "struct", kind: "struct",

View File

@ -2,8 +2,9 @@ import dotenv from "dotenv"
dotenv.config() dotenv.config()
import { AppContext, conn, network } from "./src/context" import { AppContext, conn, network } from "./src/context"
import { Deployer } from "./src/Deployer" import { AggregatorDeployFile, Deployer } from "./src/Deployer"
import { coinbase } from "./src/feeds" import { coinbase } from "./src/feeds"
import { loadJSONFile } from "./src/json"
import { log } from "./src/log" import { log } from "./src/log"
import { PriceFeeder } from "./src/PriceFeeder" import { PriceFeeder } from "./src/PriceFeeder"
@ -22,7 +23,8 @@ async function main() {
await deployer.runAll() await deployer.runAll()
const feeder = new PriceFeeder(deployFile, feederConfigFile, oracleWallet) const deploy = loadJSONFile<AggregatorDeployFile>(deployFile)
const feeder = new PriceFeeder(deploy, oracleWallet)
feeder.start() feeder.start()
return return