diff --git a/bin/fetch-price b/bin/fetch-price index 8df1a60..ea1a832 100755 --- a/bin/fetch-price +++ b/bin/fetch-price @@ -27,7 +27,7 @@ context = mango.ContextBuilder.from_command_line_parameters(args) logging.info(str(context)) -oracle_provider: mango.OracleProvider = mango.create_oracle_provider(args.provider) +oracle_provider: mango.OracleProvider = mango.create_oracle_provider(context, args.provider) market = context.market_lookup.find_by_symbol(args.symbol.upper()) if market is None: diff --git a/bin/marketmaker b/bin/marketmaker index 6f54672..52ef4db 100755 --- a/bin/marketmaker +++ b/bin/marketmaker @@ -153,7 +153,7 @@ def build_latest_perp_market_observer(manager: mango.WebSocketSubscriptionManage def build_latest_price_observer(context: mango.Context, disposer: mango.DisposePropagator, provider_name: str, market: mango.Market) -> mango.LatestItemObserverSubscriber[mango.Price]: - oracle_provider: mango.OracleProvider = mango.create_oracle_provider(provider_name) + oracle_provider: mango.OracleProvider = mango.create_oracle_provider(context, provider_name) oracle = oracle_provider.oracle_for_market(context, market) if oracle is None: raise Exception(f"Could not find oracle for market {market.symbol} from provider {provider_name}.") diff --git a/bin/simple-marketmaker b/bin/simple-marketmaker index 66a42d8..fd8159e 100755 --- a/bin/simple-marketmaker +++ b/bin/simple-marketmaker @@ -49,7 +49,7 @@ try: market_operations: mango.MarketOperations = mango.create_market_operations(context, wallet, args.dry_run, market) - oracle_provider: mango.OracleProvider = mango.create_oracle_provider(args.oracle_provider) + oracle_provider: mango.OracleProvider = mango.create_oracle_provider(context, args.oracle_provider) oracle = oracle_provider.oracle_for_market(context, market) if oracle is None: raise Exception(f"Could not find oracle for spot market {market_symbol}") diff --git a/mango/context.py b/mango/context.py index cc564a8..21d4c05 100644 --- a/mango/context.py +++ b/mango/context.py @@ -159,6 +159,16 @@ class Context: return Context(self.cluster, self.cluster_url, program_id, dex_program_id, group_name, group_id, self.token_lookup, self.market_lookup) raise Exception(f"Could not find group with ID '{group_id}' in cluster '{self.cluster}'.") + def new_forced_to_devnet(self) -> "Context": + cluster: str = "devnet" + cluster_url: str = MangoConstants["cluster_urls"][cluster] + return Context(cluster, cluster_url, self.program_id, self.dex_program_id, self.group_name, self.group_id, self.token_lookup, self.market_lookup) + + def new_forced_to_mainnet_beta(self) -> "Context": + cluster: str = "mainnet-beta" + cluster_url: str = MangoConstants["cluster_urls"][cluster] + return Context(cluster, cluster_url, self.program_id, self.dex_program_id, self.group_name, self.group_id, self.token_lookup, self.market_lookup) + def __str__(self) -> str: return f"""« 𝙲𝚘𝚗𝚝𝚎𝚡𝚝: Cluster: {self.cluster} diff --git a/mango/oraclefactory.py b/mango/oraclefactory.py index 604ed9a..1a626fb 100644 --- a/mango/oraclefactory.py +++ b/mango/oraclefactory.py @@ -13,6 +13,7 @@ # [Github](https://github.com/blockworks-foundation) # [Email](mailto:hello@blockworks.foundation) +from .context import Context from .oracle import OracleProvider from .oracles.ftx import ftx from .oracles.pythnetwork import pythnetwork @@ -24,11 +25,17 @@ from .oracles.serum import serum # This file allows you to create a concreate OracleProvider for a specified provider name. # -def create_oracle_provider(provider_name: str) -> OracleProvider: +def create_oracle_provider(context: Context, provider_name: str) -> OracleProvider: if provider_name == "serum": return serum.SerumOracleProvider() elif provider_name == "ftx": return ftx.FtxOracleProvider() elif provider_name == "pyth": - return pythnetwork.PythOracleProvider() + return pythnetwork.PythOracleProvider(context) + elif provider_name == "pyth-mainnet-beta": + mainnet_beta_pyth_context: Context = context.new_forced_to_mainnet_beta() + return pythnetwork.PythOracleProvider(mainnet_beta_pyth_context) + elif provider_name == "pyth-devnet": + devnet_pyth_context: Context = context.new_forced_to_devnet() + return pythnetwork.PythOracleProvider(devnet_pyth_context) raise Exception(f"Unknown oracle provider '{provider_name}'.") diff --git a/mango/oracles/pythnetwork/layouts.py b/mango/oracles/pythnetwork/layouts.py index bf17492..d717ccf 100644 --- a/mango/oracles/pythnetwork/layouts.py +++ b/mango/oracles/pythnetwork/layouts.py @@ -39,7 +39,8 @@ PROD_ACCT_SIZE = 512 PROD_HDR_SIZE = 48 PROD_ATTR_SIZE = PROD_ACCT_SIZE - PROD_HDR_SIZE -PYTH_MAPPING_ROOT = PublicKey("BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2") +PYTH_DEVNET_MAPPING_ROOT = PublicKey("BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2") +PYTH_MAINNET_MAPPING_ROOT = PublicKey("AHtgzX45WTKfkPG53L6WYhGEXwQkN1BVknET3sVsLL8J") # # 🥭 ACCOUNT_TYPE enum diff --git a/mango/oracles/pythnetwork/pythnetwork.py b/mango/oracles/pythnetwork/pythnetwork.py index 34d63f4..bd1fb33 100644 --- a/mango/oracles/pythnetwork/pythnetwork.py +++ b/mango/oracles/pythnetwork/pythnetwork.py @@ -29,7 +29,7 @@ from ...market import Market from ...observables import observable_pipeline_error_reporter from ...oracle import Oracle, OracleProvider, OracleSource, Price, SupportedOracleFeature -from .layouts import MAGIC, MAPPING, PRICE, PRODUCT, PYTH_MAPPING_ROOT +from .layouts import MAGIC, MAPPING, PRICE, PRODUCT, PYTH_DEVNET_MAPPING_ROOT, PYTH_MAINNET_MAPPING_ROOT # # 🥭 Pyth @@ -60,19 +60,18 @@ from .layouts import MAGIC, MAPPING, PRICE, PRODUCT, PYTH_MAPPING_ROOT # class PythOracle(Oracle): - def __init__(self, market: Market, product_data: PRODUCT): + def __init__(self, context: Context, market: Market, product_data: PRODUCT): name = f"Pyth Oracle for {market.symbol}" super().__init__(name, market) + self.context: Context = context self.market: Market = market self.product_data: PRODUCT = product_data self.address: PublicKey = product_data.address features: SupportedOracleFeature = SupportedOracleFeature.MID_PRICE | SupportedOracleFeature.CONFIDENCE self.source: OracleSource = OracleSource("Pyth", name, features, market) - def fetch_price(self, context: Context) -> Price: - pyth_context = context.new_from_cluster("devnet") - - price_account_info = AccountInfo.load(pyth_context, self.product_data.px_acc) + def fetch_price(self, _: Context) -> Price: + price_account_info = AccountInfo.load(self.context, self.product_data.px_acc) if price_account_info is None: raise Exception(f"Price account {self.product_data.px_acc} not found.") @@ -105,24 +104,26 @@ class PythOracle(Oracle): # # Implements the `OracleProvider` abstract base class specialised to the Pyth Network. # +# In order to allow it to vary its cluster without affecting other programs, this takes a `Context` in its +# constructor and uses that to access the data. It ignores the context passed as a parameter to its methods. +# This allows the context-fudging to only happen on construction. class PythOracleProvider(OracleProvider): - def __init__(self, address: PublicKey = PYTH_MAPPING_ROOT) -> None: - super().__init__(f"Pyth Oracle Factory [{address}]") - self.address = address + def __init__(self, context: Context) -> None: + self.address: PublicKey = PYTH_MAINNET_MAPPING_ROOT if context.cluster == "mainnet-beta" else PYTH_DEVNET_MAPPING_ROOT + super().__init__(f"Pyth Oracle Factory [{self.address}]") + self.context: Context = context - def oracle_for_market(self, context: Context, market: Market) -> typing.Optional[Oracle]: - pyth_context = context.new_from_cluster("devnet") + def oracle_for_market(self, _: Context, market: Market) -> typing.Optional[Oracle]: pyth_symbol = self._market_symbol_to_pyth_symbol(market.symbol) - products = self._fetch_all_pyth_products(pyth_context, self.address) + products = self._fetch_all_pyth_products(self.context, self.address) for product in products: if product.attr["symbol"] == pyth_symbol: - return PythOracle(market, product) + return PythOracle(self.context, market, product) return None - def all_available_symbols(self, context: Context) -> typing.Sequence[str]: - pyth_context = context.new_from_cluster("devnet") - products = self._fetch_all_pyth_products(pyth_context, self.address) + def all_available_symbols(self, _: Context) -> typing.Sequence[str]: + products = self._fetch_all_pyth_products(self.context, self.address) symbols: typing.List[str] = [] for product in products: symbol = product.attr["symbol"]