diff --git a/mango/inventory.py b/mango/inventory.py index 8273389..1057c38 100644 --- a/mango/inventory.py +++ b/mango/inventory.py @@ -23,8 +23,11 @@ from .account import Account from .cache import Cache from .group import Group from .instrumentvalue import InstrumentValue -from .markets import InventorySource, Market +from .loadedmarket import LoadedMarket +from .markets import InventorySource from .openorders import OpenOrders +from .perpmarket import PerpMarket +from .spotmarket import SpotMarket from .watcher import Watcher @@ -65,7 +68,7 @@ class Inventory: class InventoryAccountWatcher: def __init__( self, - market: Market, + market: LoadedMarket, account_watcher: Watcher[Account], group_watcher: Watcher[Group], all_open_orders_watchers: typing.Sequence[Watcher[OpenOrders]], @@ -78,9 +81,19 @@ class InventoryAccountWatcher: ] = all_open_orders_watchers self.cache_watcher: Watcher[Cache] = cache_watcher account: Account = account_watcher.latest - self.spot_account_index: int = group_watcher.latest.slot_by_spot_market_address( - market.address - ).index + if SpotMarket.isa(market): + self.spot_account_index: int = ( + group_watcher.latest.slot_by_spot_market_address(market.address).index + ) + elif PerpMarket.isa(market): + self.spot_account_index = group_watcher.latest.slot_by_spot_market_address( + market.address + ).index + else: + raise Exception( + f"Cannot find slot for market {market} in group {group_watcher.latest.address}" + ) + base_value = InstrumentValue.find_by_symbol( account.net_values, market.base.symbol )