From 954732976da12ebda1d01baa6dd00f522ffb33f9 Mon Sep 17 00:00:00 2001 From: Kevin Heavey Date: Thu, 12 Aug 2021 20:36:28 +0100 Subject: [PATCH] Add async support (#82) * add async utils and connection * refactor market.py before adding async * add async open orders account * add async_market * add type hint * replace pytest-tornasync with pytest-asyncio * add async tests * add async tests * linting * add async exmplae to README * fix unit test selection * bump minor version number * chmod * fix keygen error when key already exists * use --cov-append * fix coverage for multi test * fix typo Co-authored-by: kevinheavey --- .github/workflows/integration.yml | 3 + .github/workflows/main.yml | 2 +- .pylintrc | 3 +- Makefile | 5 +- Pipfile | 3 +- Pipfile.lock | 51 ++- README.md | 33 +- pyserum/async_connection.py | 16 + pyserum/async_open_orders_account.py | 33 ++ pyserum/async_utils.py | 19 + pyserum/connection.py | 28 +- pyserum/market/__init__.py | 1 + pyserum/market/async_market.py | 169 ++++++++ pyserum/market/core.py | 474 +++++++++++++++++++++ pyserum/market/market.py | 449 +++---------------- pyserum/market/state.py | 32 +- pyserum/open_orders_account.py | 81 ++-- pyserum/utils.py | 15 +- pytest.ini | 2 + scripts/bootstrap_dex.sh | 8 +- scripts/clean_up.sh | 3 + scripts/run_async_int_tests.sh | 19 + scripts/run_coverage.sh | 13 +- setup.py | 2 +- tests/conftest.py | 24 ++ tests/integration/test_async_connection.py | 24 ++ tests/integration/test_async_market.py | 204 +++++++++ tests/integration/test_connection.py | 4 + 28 files changed, 1238 insertions(+), 482 deletions(-) create mode 100644 pyserum/async_connection.py create mode 100644 pyserum/async_open_orders_account.py create mode 100644 pyserum/async_utils.py create mode 100644 pyserum/market/async_market.py create mode 100644 pyserum/market/core.py create mode 100755 scripts/run_async_int_tests.sh create mode 100644 tests/integration/test_async_connection.py create mode 100644 tests/integration/test_async_market.py diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 2a6a147..6905fb9 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -35,3 +35,6 @@ jobs: - name: Run integration tests run: scripts/run_int_tests.sh + + - name: Run async integration tests + run: scripts/run_async_int_tests.sh diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8de392e..06909d5 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -42,7 +42,7 @@ jobs: - name: Run unit tests run: | - pipenv run pytest -v -m "not integration" + pipenv run pytest -v -m "not integration and not async_integration" coverage: # The type of runner that the job will run on diff --git a/.pylintrc b/.pylintrc index 849b685..aca9dd2 100644 --- a/.pylintrc +++ b/.pylintrc @@ -141,7 +141,8 @@ disable=missing-class-docstring, xreadlines-attribute, deprecated-sys-function, exception-escape, - comprehension-escape + comprehension-escape, + duplicate-code # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/Makefile b/Makefile index f4ce457..0714a77 100644 --- a/Makefile +++ b/Makefile @@ -25,11 +25,14 @@ test-publish: pipenv run twine upload -r testpypi -u serum-community dist/* unit-tests: - pipenv run pytest -v -m "not integration" + pipenv run pytest -v -m "not integration and not async_integration" int-tests: bash scripts/run_int_tests.sh +async-int-tests: + bash scripts/run_async_int_tests.sh + # Minimal makefile for Sphinx documentation # diff --git a/Pipfile b/Pipfile index 41bc09c..38e9320 100644 --- a/Pipfile +++ b/Pipfile @@ -14,7 +14,6 @@ jupyterlab = "*" black = "*" pytest = "*" pylint = "*" -pytest-tornasync = "*" mypy = "*" pydocstyle = "*" flake8 = "*" @@ -25,6 +24,8 @@ twine = "*" setuptools = "*" sphinx = "*" sphinxemoji = "*" +pytest-asyncio = "*" +types-requests = "*" [packages] solana = {version = ">=0.11.3"} diff --git a/Pipfile.lock b/Pipfile.lock index a3c00a7..f7c3201 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "2604eac564a5636b2243ed21a37f66996d0aec5fee5792057041c3331fb4cec8" + "sha256": "6e3769e9efcc90fa5c6a46dcd89465e6457f6cb2c1395adc91ae3886a900ab17" }, "pipfile-spec": 6, "requires": { @@ -790,11 +790,11 @@ }, "jupyterlab-server": { "hashes": [ - "sha256:6dc6e7d26600d110b862acbfaa4d1a2c5e86781008d139213896d96178c3accd", - "sha256:ab568da1dcef2ffdfc9161128dc00b931aae94d6a94978b16f55330dcd1cb043" + "sha256:244c815578c2fdcd341f01635e77d9f112efcbc92ba299e8c6243f870c84c609", + "sha256:31457ef564febc42043bc539356c804f6f9144f602e2852150bf0820ed6d7e18" ], "markers": "python_version >= '3.6'", - "version": "==2.6.2" + "version": "==2.7.0" }, "keyring": { "hashes": [ @@ -972,11 +972,11 @@ }, "notebook": { "hashes": [ - "sha256:5ae23d7f831a5788e8bd51a0ba65c486db3bfd43e9db97a62330b6273e3175e3", - "sha256:ba9db5e5a9bd2d272b67e3de9143cca2be5125578f1c4f2902d7178ce2f0b4ff" + "sha256:b50eafa8208d5db966efd1caa4076b4dfc51815e02a805b32ecd717e9e6cc071", + "sha256:e6b6dfed36b00cf950f63c0d42e947c101d4258aec21624de62b9e0c11ed5c0d" ], "markers": "python_version >= '3.6'", - "version": "==6.4.2" + "version": "==6.4.3" }, "packaging": { "hashes": [ @@ -1160,6 +1160,14 @@ "index": "pypi", "version": "==6.2.4" }, + "pytest-asyncio": { + "hashes": [ + "sha256:2564ceb9612bbd560d19ca4b41347b54e7835c2f792c504f698e05395ed63f6f", + "sha256:3042bcdf1c5d978f6b74d96a151c4cfb9dcece65006198389ccd7e6c60eb1eea" + ], + "index": "pypi", + "version": "==0.15.1" + }, "pytest-cov": { "hashes": [ "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a", @@ -1168,14 +1176,6 @@ "index": "pypi", "version": "==2.12.1" }, - "pytest-tornasync": { - "hashes": [ - "sha256:4b165b6ba76b5b228933598f456b71ba233f127991a52889788db0a950ad04ba", - "sha256:d781b6d951a2e7c08843141d3ff583610b4ea86bfa847714c76edefb576bbe5d" - ], - "index": "pypi", - "version": "==0.6.0.post2" - }, "python-dateutil": { "hashes": [ "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86", @@ -1194,6 +1194,7 @@ "pyzmq": { "hashes": [ "sha256:021e22a8c58ab294bd4b96448a2ca4e716e1d76600192ff84c33d71edb1fbd37", + "sha256:0471d634c7fe48ff7d3849798da6c16afc71676dd890b5ae08eb1efe735c6fec", "sha256:0d17bac19e934e9f547a8811b7c2a32651a7840f38086b924e2e3dcb2fae5c3a", "sha256:200ac096cee5499964c90687306a7244b79ef891f773ed4cf15019fd1f3df330", "sha256:240b83b3a8175b2f616f80092cbb019fcd5c18598f78ffc6aa0ae9034b300f14", @@ -1201,9 +1202,12 @@ "sha256:2534a036b777f957bd6b89b55fb2136775ca2659fb0f1c85036ba78d17d86fd5", "sha256:262f470e7acde18b7217aac78d19d2e29ced91a5afbeb7d98521ebf26461aa7e", "sha256:2dd3896b3c952cf6c8013deda53c1df16bf962f355b5503d23521e0f6403ae3d", + "sha256:31c5dfb6df5148789835128768c01bf6402eb753d06f524f12f6786caf96fb44", + "sha256:4842a8263cbaba6fce401bbe4e2b125321c401a01714e42624dabc554bfc2629", "sha256:50d007d5702171bc810c1e74498fa2c7bc5b50f9750697f7fd2a3e71a25aad91", "sha256:5933d1f4087de6e52906f72d92e1e4dcc630d371860b92c55d7f7a4b815a664c", "sha256:620b0abb813958cb3ecb5144c177e26cde92fee6f43c4b9de6b329515532bf27", + "sha256:631f932fb1fa4b76f31adf976f8056519bc6208a3c24c184581c3dd5be15066e", "sha256:66375a6094af72a6098ed4403b15b4db6bf00013c6febc1baa832e7abda827f4", "sha256:6a5b4566f66d953601d0d47d4071897f550a265bafd52ebcad5ac7aad3838cbb", "sha256:6d18c76676771fd891ca8e0e68da0bbfb88e30129835c0ade748016adb3b6242", @@ -1217,10 +1221,13 @@ "sha256:b4428302c389fffc0c9c07a78cad5376636b9d096f332acfe66b321ae9ff2c63", "sha256:b4a51c7d906dc263a0cc5590761e53e0a68f2c2fefe549cbef21c9ee5d2d98a4", "sha256:b921758f8b5098faa85f341bbdd5e36d5339de5e9032ca2b07d8c8e7bec5069b", + "sha256:c1b6619ceb33a8907f1cb82ff8afc8a133e7a5f16df29528e919734718600426", "sha256:c9cb0bd3a3cb7ccad3caa1d7b0d18ba71ed3a4a3610028e506a4084371d4d223", + "sha256:d60a407663b7c2af781ab7f49d94a3d379dd148bb69ea8d9dd5bc69adf18097c", "sha256:da7f7f3bb08bcf59a6b60b4e53dd8f08bb00c9e61045319d825a906dbb3c8fb7", "sha256:e66025b64c4724ba683d6d4a4e5ee23de12fe9ae683908f0c7f0f91b4a2fd94e", "sha256:ed67df4eaa99a20d162d76655bda23160abdf8abf82a17f41dfd3962e608dbcc", + "sha256:f520e9fee5d7a2e09b051d924f85b977c6b4e224e56c0551c3c241bbeeb0ad8d", "sha256:f5c84c5de9a773bbf8b22c51e28380999ea72e5e85b4db8edf5e69a7a0d4d9f9", "sha256:ff345d48940c834168f81fa1d4724675099f148f1ab6369748c4d712ed71bf7c" ], @@ -1551,6 +1558,14 @@ "markers": "python_version < '3.8' and implementation_name == 'cpython'", "version": "==1.4.3" }, + "types-requests": { + "hashes": [ + "sha256:a5a305b43ea57bf64d6731f89816946a405b591eff6de28d4c0fd58422cee779", + "sha256:e21541c0f55c066c491a639309159556dd8c5833e49fcde929c4c47bdb0002ee" + ], + "index": "pypi", + "version": "==2.25.6" + }, "typing-extensions": { "hashes": [ "sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497", @@ -1584,11 +1599,11 @@ }, "websocket-client": { "hashes": [ - "sha256:4cf754af7e3b3ba76589d49f9e09fd9a6c0aae9b799a89124d656009c01a261d", - "sha256:8d07f155f8ed14ae3ced97bd7582b08f280bb1bfd27945f023ba2aceff05ab52" + "sha256:7665ba6c645989b28b61670874ab753e6929179e9fc90565ace6ac090f59c559", + "sha256:d82a975bdd02216f7884cd18106a0d9896a9a9e8cc90f23fe8c81dc48da2f142" ], "markers": "python_version >= '3.6'", - "version": "==1.1.1" + "version": "==1.2.0" }, "wrapt": { "hashes": [ diff --git a/README.md b/README.md index 13feb65..3863feb 100644 --- a/README.md +++ b/README.md @@ -50,8 +50,37 @@ print("\n") print("Bid Orders:") bids = market.load_bids() for bid in bids: - print("Order id: %d, price: %f, size: %f." % ( - bid.order_id, bid.info.price, bid.info.size)) + print(f"Order id: {bid.order_id}, price: {bid.info.price}, size: {bid.info.size}.") +``` + +### Get Orderbook (Async) + +```python +import asyncio +from pyserum.async_connection import async_conn +from pyserum.market import AsyncMarket + + +async def main(): + market_address = "5LgJphS6D5zXwUVPU7eCryDBkyta3AidrJ5vjNU6BcGW" # Address for BTC/USDC + async with async_conn("https://api.mainnet-beta.solana.com/") as cc: + # Load the given market + market = await AsyncMarket.load(cc, market_address) + asks = await market.load_asks() + # Show all current ask order + print("Ask Orders:") + for ask in asks: + print(f"Order id: {ask.order_id}, price: {ask.info.price}, size: {ask.info.size}.") + print("\n") + # Show all current bid order + print("Bid Orders:") + bids = await market.load_bids() + for bid in bids: + print(f"Order id: {bid.order_id}, price: {bid.info.price}, size: {bid.info.size}.") + + +asyncio.run(main()) + ``` ### Support diff --git a/pyserum/async_connection.py b/pyserum/async_connection.py new file mode 100644 index 0000000..7311e31 --- /dev/null +++ b/pyserum/async_connection.py @@ -0,0 +1,16 @@ +from typing import List +import httpx +from solana.rpc.async_api import AsyncClient as async_conn # pylint: disable=unused-import # noqa:F401 + +from .market.types import MarketInfo, TokenInfo +from .connection import LIVE_MARKETS_URL, TOKEN_MINTS_URL, parse_live_markets, parse_token_mints + + +async def get_live_markets(httpx_client: httpx.AsyncClient) -> List[MarketInfo]: + resp = await httpx_client.get(LIVE_MARKETS_URL) + return parse_live_markets(resp.json()) + + +async def get_token_mints(httpx_client: httpx.AsyncClient) -> List[TokenInfo]: + resp = await httpx_client.get(TOKEN_MINTS_URL) + return parse_token_mints(resp.json()) diff --git a/pyserum/async_open_orders_account.py b/pyserum/async_open_orders_account.py new file mode 100644 index 0000000..0b6917c --- /dev/null +++ b/pyserum/async_open_orders_account.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import List +from solana.rpc.async_api import AsyncClient +from solana.publickey import PublicKey +from solana.rpc.types import Commitment +from solana.rpc.commitment import Recent + +from .async_utils import load_bytes_data +from .open_orders_account import _OpenOrdersAccountCore + + +class AsyncOpenOrdersAccount(_OpenOrdersAccountCore): + @classmethod + async def find_for_market_and_owner( # pylint: disable=too-many-arguments + cls, + conn: AsyncClient, + market: PublicKey, + owner: PublicKey, + program_id: PublicKey, + commitment: Commitment = Recent, + ) -> List[AsyncOpenOrdersAccount]: + args = cls._build_get_program_accounts_args( + market=market, program_id=program_id, owner=owner, commitment=commitment + ) + resp = await conn.get_program_accounts(*args) + return cls._process_get_program_accounts_resp(resp) + + @classmethod + async def load(cls, conn: AsyncClient, address: str) -> AsyncOpenOrdersAccount: + addr_pub_key = PublicKey(address) + bytes_data = await load_bytes_data(addr_pub_key, conn) + return cls.from_bytes(addr_pub_key, bytes_data) diff --git a/pyserum/async_utils.py b/pyserum/async_utils.py new file mode 100644 index 0000000..a8a259e --- /dev/null +++ b/pyserum/async_utils.py @@ -0,0 +1,19 @@ +from solana.publickey import PublicKey +from solana.rpc.async_api import AsyncClient +from spl.token.constants import WRAPPED_SOL_MINT + +from pyserum.utils import parse_bytes_data, parse_mint_decimals + + +async def load_bytes_data(addr: PublicKey, conn: AsyncClient) -> bytes: + res = await conn.get_account_info(addr) + return parse_bytes_data(res) + + +async def get_mint_decimals(conn: AsyncClient, mint_pub_key: PublicKey) -> int: + """Get the mint decimals for a token mint""" + if mint_pub_key == WRAPPED_SOL_MINT: + return 9 + + bytes_data = await load_bytes_data(mint_pub_key, conn) + return parse_mint_decimals(bytes_data) diff --git a/pyserum/connection.py b/pyserum/connection.py index 9dd9aaa..5e99c3d 100644 --- a/pyserum/connection.py +++ b/pyserum/connection.py @@ -1,20 +1,28 @@ -from typing import List +from typing import List, Dict, Any + +import requests from solana.rpc.api import Client as conn # pylint: disable=unused-import # noqa:F401 -from solana.rpc.providers.http import requests - +from solana.publickey import PublicKey from .market.types import MarketInfo, TokenInfo +LIVE_MARKETS_URL = "https://raw.githubusercontent.com/project-serum/serum-ts/master/packages/serum/src/markets.json" +TOKEN_MINTS_URL = "https://raw.githubusercontent.com/project-serum/serum-ts/master/packages/serum/src/token-mints.json" -def get_live_markets() -> List[MarketInfo]: - url = "https://raw.githubusercontent.com/project-serum/serum-ts/master/packages/serum/src/markets.json" + +def parse_live_markets(data: List[Dict[str, Any]]) -> List[MarketInfo]: return [ - MarketInfo(name=m["name"], address=m["address"], program_id=m["programId"]) - for m in requests.get(url).json() - if not m["deprecated"] + MarketInfo(name=m["name"], address=m["address"], program_id=m["programId"]) for m in data if not m["deprecated"] ] +def parse_token_mints(data: List[Dict[str, str]]) -> List[TokenInfo]: + return [TokenInfo(name=t["name"], address=PublicKey(t["address"])) for t in data] + + +def get_live_markets() -> List[MarketInfo]: + return parse_live_markets(requests.get(LIVE_MARKETS_URL).json()) + + def get_token_mints() -> List[TokenInfo]: - url = "https://raw.githubusercontent.com/project-serum/serum-ts/master/packages/serum/src/token-mints.json" - return [TokenInfo(**t) for t in requests.get(url).json()] + return parse_token_mints(requests.get(TOKEN_MINTS_URL).json()) diff --git a/pyserum/market/__init__.py b/pyserum/market/__init__.py index 1801fe6..9a2aadd 100644 --- a/pyserum/market/__init__.py +++ b/pyserum/market/__init__.py @@ -1,3 +1,4 @@ from .market import Market # noqa: F401 +from .async_market import AsyncMarket # noqa: F401 from .orderbook import OrderBook # noqa: F401 from .state import MarketState as State # noqa: F401 diff --git a/pyserum/market/async_market.py b/pyserum/market/async_market.py new file mode 100644 index 0000000..0318ce5 --- /dev/null +++ b/pyserum/market/async_market.py @@ -0,0 +1,169 @@ +"""Market module to interact with Serum DEX.""" +from __future__ import annotations + +from typing import List + +from solana.account import Account +from solana.publickey import PublicKey +from solana.rpc.async_api import AsyncClient +from solana.rpc.types import RPCResponse, TxOpts +from solana.transaction import Transaction + +from pyserum import instructions +import pyserum.market.types as t + +from .._layouts.open_orders import OPEN_ORDERS_LAYOUT +from ..enums import OrderType, Side +from ..async_open_orders_account import AsyncOpenOrdersAccount +from ..async_utils import load_bytes_data +from ._internal.queue import decode_event_queue, decode_request_queue +from .orderbook import OrderBook +from .state import MarketState +from .core import MarketCore + +LAMPORTS_PER_SOL = 1000000000 + + +# pylint: disable=too-many-public-methods,abstract-method +class AsyncMarket(MarketCore): + """Represents a Serum Market.""" + + def __init__(self, conn: AsyncClient, market_state: MarketState, force_use_request_queue: bool = False) -> None: + super().__init__(market_state=market_state, force_use_request_queue=force_use_request_queue) + self._conn = conn + + @classmethod + # pylint: disable=unused-argument + async def load( + cls, + conn: AsyncClient, + market_address: PublicKey, + program_id: PublicKey = instructions.DEFAULT_DEX_PROGRAM_ID, + force_use_request_queue: bool = False, + ) -> AsyncMarket: + """Factory method to create a Market. + + :param conn: The connection that we use to load the data, created from `solana.rpc.api`. + :param market_address: The market address that you want to connect to. + :param program_id: The program id of the given market, it will use the default value if not provided. + """ + market_state = await MarketState.async_load(conn, market_address, program_id) + return cls(conn, market_state, force_use_request_queue) + + async def find_open_orders_accounts_for_owner(self, owner_address: PublicKey) -> List[AsyncOpenOrdersAccount]: + return await AsyncOpenOrdersAccount.find_for_market_and_owner( + self._conn, self.state.public_key(), owner_address, self.state.program_id() + ) + + async def load_bids(self) -> OrderBook: + """Load the bid order book""" + bytes_data = await load_bytes_data(self.state.bids(), self._conn) + return self._parse_bids_or_asks(bytes_data) + + async def load_asks(self) -> OrderBook: + """Load the ask order book.""" + bytes_data = await load_bytes_data(self.state.asks(), self._conn) + return self._parse_bids_or_asks(bytes_data) + + async def load_orders_for_owner(self, owner_address: PublicKey) -> List[t.Order]: + """Load orders for owner.""" + bids = await self.load_bids() + asks = await self.load_asks() + open_orders_accounts = await self.find_open_orders_accounts_for_owner(owner_address) + return self._parse_orders_for_owner(bids, asks, open_orders_accounts) + + async def load_event_queue(self) -> List[t.Event]: + """Load the event queue which includes the fill item and out item. For any trades two fill items are added to + the event queue. And in case of a trade, cancel or IOC order that missed, out items are added to the event + queue. + """ + bytes_data = await load_bytes_data(self.state.event_queue(), self._conn) + return decode_event_queue(bytes_data) + + async def load_request_queue(self) -> List[t.Request]: + bytes_data = await load_bytes_data(self.state.request_queue(), self._conn) + return decode_request_queue(bytes_data) + + async def load_fills(self, limit=100) -> List[t.FilledOrder]: + bytes_data = await load_bytes_data(self.state.event_queue(), self._conn) + return self._parse_fills(bytes_data, limit) + + async def place_order( # pylint: disable=too-many-arguments,too-many-locals + self, + payer: PublicKey, + owner: Account, + order_type: OrderType, + side: Side, + limit_price: float, + max_quantity: float, + client_id: int = 0, + opts: TxOpts = TxOpts(), + ) -> RPCResponse: # TODO: Add open_orders_address_key param and fee_discount_pubkey + transaction = Transaction() + signers: List[Account] = [owner] + open_order_accounts = await self.find_open_orders_accounts_for_owner(owner.public_key()) + if open_order_accounts: + place_order_open_order_account = open_order_accounts[0].address + else: + mbfre_resp = await self._conn.get_minimum_balance_for_rent_exemption(OPEN_ORDERS_LAYOUT.sizeof()) + place_order_open_order_account = self._after_oo_mbfre_resp( + mbfre_resp=mbfre_resp, owner=owner, signers=signers, transaction=transaction + ) + # TODO: Cache new_open_orders_account + # TODO: Handle fee_discount_pubkey + + self._prepare_order_transaction( + transaction=transaction, + payer=payer, + owner=owner, + order_type=order_type, + side=side, + signers=signers, + limit_price=limit_price, + max_quantity=max_quantity, + client_id=client_id, + open_order_accounts=open_order_accounts, + place_order_open_order_account=place_order_open_order_account, + ) + return await self._conn.send_transaction(transaction, *signers, opts=opts) + + async def cancel_order_by_client_id( + self, owner: Account, open_orders_account: PublicKey, client_id: int, opts: TxOpts = TxOpts() + ) -> RPCResponse: + txs = self._build_cancel_order_by_client_id_tx( + owner=owner, open_orders_account=open_orders_account, client_id=client_id + ) + return await self._conn.send_transaction(txs, owner, opts=opts) + + async def cancel_order(self, owner: Account, order: t.Order, opts: TxOpts = TxOpts()) -> RPCResponse: + txn = self._build_cancel_order_tx(owner=owner, order=order) + return await self._conn.send_transaction(txn, owner, opts=opts) + + async def match_orders(self, fee_payer: Account, limit: int, opts: TxOpts = TxOpts()) -> RPCResponse: + txn = self._build_match_orders_tx(limit) + return await self._conn.send_transaction(txn, fee_payer, opts=opts) + + async def settle_funds( # pylint: disable=too-many-arguments + self, + owner: Account, + open_orders: AsyncOpenOrdersAccount, + base_wallet: PublicKey, + quote_wallet: PublicKey, # TODO: add referrer_quote_wallet. + opts: TxOpts = TxOpts(), + ) -> RPCResponse: + # TODO: Handle wrapped sol accounts + should_wrap_sol = self._settle_funds_should_wrap_sol() + if should_wrap_sol: + mbfre_resp = await self._conn.get_minimum_balance_for_rent_exemption(165) + min_bal_for_rent_exemption = mbfre_resp["result"] + else: + min_bal_for_rent_exemption = 0 # value only matters if should_wrap_sol + transaction = self._build_settle_funds_tx( + owner=owner, + open_orders=open_orders, + base_wallet=base_wallet, + quote_wallet=quote_wallet, + min_bal_for_rent_exemption=min_bal_for_rent_exemption, + should_wrap_sol=should_wrap_sol, + ) + return await self._conn.send_transaction(transaction, owner, opts=opts) diff --git a/pyserum/market/core.py b/pyserum/market/core.py new file mode 100644 index 0000000..da301ae --- /dev/null +++ b/pyserum/market/core.py @@ -0,0 +1,474 @@ +"""Market module to interact with Serum DEX.""" +from __future__ import annotations + +import itertools +import logging +from typing import List, Union + +from solana.account import Account +from solana.publickey import PublicKey +from solana.rpc.types import RPCResponse +from solana.system_program import CreateAccountParams, create_account +from solana.transaction import Transaction, TransactionInstruction +from spl.token.constants import ACCOUNT_LEN, TOKEN_PROGRAM_ID, WRAPPED_SOL_MINT +from spl.token.instructions import CloseAccountParams +from spl.token.instructions import InitializeAccountParams, close_account, initialize_account + +from pyserum import instructions +import pyserum.market.types as t + +from ..enums import OrderType, SelfTradeBehavior, Side +from ..open_orders_account import OpenOrdersAccount, make_create_account_instruction +from ..async_open_orders_account import AsyncOpenOrdersAccount +from ._internal.queue import decode_event_queue +from .orderbook import OrderBook +from .state import MarketState + +LAMPORTS_PER_SOL = 1000000000 + + +# pylint: disable=too-many-public-methods +class MarketCore: + """Represents a Serum Market.""" + + logger = logging.getLogger("pyserum.market.Market") + + def __init__(self, market_state: MarketState, force_use_request_queue: bool = False) -> None: + self.state = market_state + self.force_use_request_queue = force_use_request_queue + + def _use_request_queue(self) -> bool: + return ( + # DEX Version 1 + self.state.program_id == PublicKey("4ckmDgGdxQoPDLUkDT3vHgSAkzA3QRdNq5ywwY4sUSJn") + or + # DEX Version 1 + self.state.program_id == PublicKey("BJ3jrUzddfuSrZHXSCxMUUQsjKEyLmuuyZebkcaFp2fg") + or + # DEX Version 2 + self.state.program_id == PublicKey("EUqojwWA2rd19FZrzeBncJsm38Jm1hEhE3zsmX3bRc2o") + or self.force_use_request_queue + ) + + def support_srm_fee_discounts(self) -> bool: + raise NotImplementedError("support_srm_fee_discounts not implemented") + + def find_fee_discount_keys(self, owner: PublicKey, cache_duration: int): + raise NotImplementedError("find_fee_discount_keys not implemented") + + def find_best_fee_discount_key(self, owner: PublicKey, cache_duration: int): + raise NotImplementedError("find_best_fee_discount_key not implemented") + + def find_quote_token_accounts_for_owner(self, owner_address: PublicKey, include_unwrapped_sol: bool = False): + raise NotImplementedError("find_quote_token_accounts_for_owner not implemented") + + def _parse_bids_or_asks(self, bytes_data: bytes) -> OrderBook: + return OrderBook.from_bytes(self.state, bytes_data) + + @staticmethod + def _parse_orders_for_owner(bids, asks, open_orders_accounts) -> List[t.Order]: + if not open_orders_accounts: + return [] + + all_orders = itertools.chain(bids.orders(), asks.orders()) + open_orders_addresses = {str(o.address) for o in open_orders_accounts} + orders = [o for o in all_orders if str(o.open_order_address) in open_orders_addresses] + return orders + + def load_base_token_for_owner(self): + raise NotImplementedError("load_base_token_for_owner not implemented") + + def _parse_fills(self, bytes_data: bytes, limit: int) -> List[t.FilledOrder]: + events = decode_event_queue(bytes_data, limit) + return [ + self.parse_fill_event(event) + for event in events + if event.event_flags.fill and event.native_quantity_paid > 0 + ] + + def parse_fill_event(self, event: t.Event) -> t.FilledOrder: + if event.event_flags.bid: + side = Side.BUY + price_before_fees = ( + event.native_quantity_released + event.native_fee_or_rebate + if event.event_flags.maker + else event.native_quantity_released - event.native_fee_or_rebate + ) + else: + side = Side.SELL + price_before_fees = ( + event.native_quantity_released - event.native_fee_or_rebate + if event.event_flags.maker + else event.native_quantity_released + event.native_fee_or_rebate + ) + + price = (price_before_fees * self.state.base_spl_token_multiplier()) / ( + self.state.quote_spl_token_multiplier() * event.native_quantity_paid + ) + size = event.native_quantity_paid / self.state.base_spl_token_multiplier() + return t.FilledOrder( + order_id=event.order_id, + side=side, + price=price, + size=size, + fee_cost=event.native_fee_or_rebate * (1 if event.event_flags.maker else -1), + ) + + def _prepare_new_oo_account( + self, owner: Account, balance_needed: int, signers: List[Account], transaction: Transaction + ) -> PublicKey: + new_open_orders_account = Account() + place_order_open_order_account = new_open_orders_account.public_key() + transaction.add( + make_create_account_instruction( + owner_address=owner.public_key(), + new_account_address=new_open_orders_account.public_key(), + lamports=balance_needed, + program_id=self.state.program_id(), + ) + ) + signers.append(new_open_orders_account) + return place_order_open_order_account + + def _prepare_order_transaction( # pylint: disable=too-many-arguments,too-many-locals + self, + transaction: Transaction, + payer: PublicKey, + owner: Account, + order_type: OrderType, + side: Side, + signers: List[Account], + limit_price: float, + max_quantity: float, + client_id: int, + open_order_accounts: Union[List[OpenOrdersAccount], List[AsyncOpenOrdersAccount]], + place_order_open_order_account: PublicKey, + ) -> None: + # unwrapped SOL cannot be used for payment + if payer == owner.public_key(): + raise ValueError("Invalid payer account. Cannot use unwrapped SOL.") + + # TODO: add integration test for SOL wrapping. + should_wrap_sol = (side == Side.BUY and self.state.quote_mint() == WRAPPED_SOL_MINT) or ( + side == Side.SELL and self.state.base_mint() == WRAPPED_SOL_MINT + ) + + if should_wrap_sol: + wrapped_sol_account = Account() + payer = wrapped_sol_account.public_key() + signers.append(wrapped_sol_account) + transaction.add( + create_account( + CreateAccountParams( + from_pubkey=owner.public_key(), + new_account_pubkey=wrapped_sol_account.public_key(), + lamports=self._get_lamport_need_for_sol_wrapping( + limit_price, max_quantity, side, open_order_accounts + ), + space=ACCOUNT_LEN, + program_id=TOKEN_PROGRAM_ID, + ) + ) + ) + transaction.add( + initialize_account( + InitializeAccountParams( + account=wrapped_sol_account.public_key(), + mint=WRAPPED_SOL_MINT, + owner=owner.public_key(), + program_id=TOKEN_PROGRAM_ID, + ) + ) + ) + + transaction.add( + self.make_place_order_instruction( + payer=payer, + owner=owner, + order_type=order_type, + side=side, + limit_price=limit_price, + max_quantity=max_quantity, + client_id=client_id, + open_order_account=place_order_open_order_account, + ) + ) + + if should_wrap_sol: + transaction.add( + close_account( + CloseAccountParams( + account=wrapped_sol_account.public_key(), + owner=owner.public_key(), + dest=owner.public_key(), + program_id=TOKEN_PROGRAM_ID, + ) + ) + ) + + def _after_oo_mbfre_resp( + self, mbfre_resp: RPCResponse, owner: Account, signers: List[Account], transaction: Transaction + ) -> PublicKey: + balance_needed = mbfre_resp["result"] + place_order_open_order_account = self._prepare_new_oo_account(owner, balance_needed, signers, transaction) + return place_order_open_order_account + + @staticmethod + def _get_lamport_need_for_sol_wrapping( + price: float, + size: float, + side: Side, + open_orders_accounts: Union[List[OpenOrdersAccount], List[AsyncOpenOrdersAccount]], + ) -> int: + lamports = 0 + if side == Side.BUY: + lamports = round(price * size * 1.01 * LAMPORTS_PER_SOL) + if open_orders_accounts: + lamports -= open_orders_accounts[0].quote_token_free + else: + lamports = round(size * LAMPORTS_PER_SOL) + if open_orders_accounts: + lamports -= open_orders_accounts[0].base_token_free + + return max(lamports, 0) + 10000000 + + def make_place_order_instruction( # pylint: disable=too-many-arguments + self, + payer: PublicKey, + owner: Account, + order_type: OrderType, + side: Side, + limit_price: float, + max_quantity: float, + client_id: int, + open_order_account: PublicKey, + fee_discount_pubkey: PublicKey = None, + ) -> TransactionInstruction: + if self.state.base_size_number_to_lots(max_quantity) < 0: + raise Exception("Size lot %d is too small" % max_quantity) + if self.state.price_number_to_lots(limit_price) < 0: + raise Exception("Price lot %d is too small" % limit_price) + if self._use_request_queue(): + return instructions.new_order( + instructions.NewOrderParams( + market=self.state.public_key(), + open_orders=open_order_account, + payer=payer, + owner=owner.public_key(), + request_queue=self.state.request_queue(), + base_vault=self.state.base_vault(), + quote_vault=self.state.quote_vault(), + side=side, + limit_price=self.state.price_number_to_lots(limit_price), + max_quantity=self.state.base_size_number_to_lots(max_quantity), + order_type=order_type, + client_id=client_id, + program_id=self.state.program_id(), + ) + ) + return instructions.new_order_v3( + instructions.NewOrderV3Params( + market=self.state.public_key(), + open_orders=open_order_account, + payer=payer, + owner=owner.public_key(), + request_queue=self.state.request_queue(), + event_queue=self.state.event_queue(), + bids=self.state.bids(), + asks=self.state.asks(), + base_vault=self.state.base_vault(), + quote_vault=self.state.quote_vault(), + side=side, + limit_price=self.state.price_number_to_lots(limit_price), + max_base_quantity=self.state.base_size_number_to_lots(max_quantity), + max_quote_quantity=self.state.base_size_number_to_lots(max_quantity) + * self.state.quote_lot_size() + * self.state.price_number_to_lots(limit_price), + order_type=order_type, + client_id=client_id, + program_id=self.state.program_id(), + self_trade_behavior=SelfTradeBehavior.DECREMENT_TAKE, + fee_discount_pubkey=fee_discount_pubkey, + limit=65535, + ) + ) + + def _build_cancel_order_by_client_id_tx( + self, owner: Account, open_orders_account: PublicKey, client_id: int + ) -> Transaction: + return Transaction().add(self.make_cancel_order_by_client_id_instruction(owner, open_orders_account, client_id)) + + def make_cancel_order_by_client_id_instruction( + self, owner: Account, open_orders_account: PublicKey, client_id: int + ) -> TransactionInstruction: + if self._use_request_queue(): + return instructions.cancel_order_by_client_id( + instructions.CancelOrderByClientIDParams( + market=self.state.public_key(), + owner=owner.public_key(), + open_orders=open_orders_account, + request_queue=self.state.request_queue(), + client_id=client_id, + program_id=self.state.program_id(), + ) + ) + return instructions.cancel_order_by_client_id_v2( + instructions.CancelOrderByClientIDV2Params( + market=self.state.public_key(), + owner=owner.public_key(), + open_orders=open_orders_account, + bids=self.state.bids(), + asks=self.state.asks(), + event_queue=self.state.event_queue(), + client_id=client_id, + program_id=self.state.program_id(), + ) + ) + + def _build_cancel_order_tx(self, owner: Account, order: t.Order) -> Transaction: + return Transaction().add(self.make_cancel_order_instruction(owner.public_key(), order)) + + def make_cancel_order_instruction(self, owner: PublicKey, order: t.Order) -> TransactionInstruction: + if self._use_request_queue(): + return instructions.cancel_order( + instructions.CancelOrderParams( + market=self.state.public_key(), + owner=owner, + open_orders=order.open_order_address, + request_queue=self.state.request_queue(), + side=order.side, + order_id=order.order_id, + open_orders_slot=order.open_order_slot, + program_id=self.state.program_id(), + ) + ) + return instructions.cancel_order_v2( + instructions.CancelOrderV2Params( + market=self.state.public_key(), + owner=owner, + open_orders=order.open_order_address, + bids=self.state.bids(), + asks=self.state.asks(), + event_queue=self.state.event_queue(), + side=order.side, + order_id=order.order_id, + open_orders_slot=order.open_order_slot, + program_id=self.state.program_id(), + ) + ) + + def _build_match_orders_tx(self, limit: int) -> Transaction: + return Transaction().add(self.make_match_orders_instruction(limit)) + + def make_match_orders_instruction(self, limit: int) -> TransactionInstruction: + params = instructions.MatchOrdersParams( + market=self.state.public_key(), + request_queue=self.state.request_queue(), + event_queue=self.state.event_queue(), + bids=self.state.bids(), + asks=self.state.asks(), + base_vault=self.state.base_vault(), + quote_vault=self.state.quote_vault(), + limit=limit, + program_id=self.state.program_id(), + ) + return instructions.match_orders(params) + + def _build_settle_funds_tx( # pylint: disable=too-many-arguments + self, + owner: Account, + open_orders: Union[OpenOrdersAccount, AsyncOpenOrdersAccount], + base_wallet: PublicKey, + quote_wallet: PublicKey, # TODO: add referrer_quote_wallet. + min_bal_for_rent_exemption: int, + should_wrap_sol: bool, + ) -> Transaction: + # TODO: Handle wrapped sol accounts + if open_orders.owner != owner.public_key(): + raise Exception("Invalid open orders account") + vault_signer = PublicKey.create_program_address( + [bytes(self.state.public_key()), self.state.vault_signer_nonce().to_bytes(8, byteorder="little")], + self.state.program_id(), + ) + transaction = Transaction() + signers: List[Account] = [owner] + + if should_wrap_sol: + wrapped_sol_account = Account() + signers.append(wrapped_sol_account) + # make a wrapped SOL account with enough balance to + # fund the trade, run the program, then send itself back home + transaction.add( + create_account( + CreateAccountParams( + from_pubkey=owner.public_key(), + new_account_pubkey=wrapped_sol_account.public_key(), + lamports=min_bal_for_rent_exemption, + space=ACCOUNT_LEN, + program_id=TOKEN_PROGRAM_ID, + ) + ) + ) + # this was also broken upstream. it should be minting wrapped SOL, and using the token program ID + transaction.add( + initialize_account( + InitializeAccountParams( + account=wrapped_sol_account.public_key(), + mint=WRAPPED_SOL_MINT, + owner=owner.public_key(), + program_id=TOKEN_PROGRAM_ID, + ) + ) + ) + + transaction.add( + self.make_settle_funds_instruction( + open_orders, + base_wallet if self.state.base_mint() != WRAPPED_SOL_MINT else wrapped_sol_account.public_key(), + quote_wallet if self.state.quote_mint() != WRAPPED_SOL_MINT else wrapped_sol_account.public_key(), + vault_signer, + ) + ) + + if should_wrap_sol: + # close out the account and send the funds home when the trade is completed/cancelled + transaction.add( + close_account( + CloseAccountParams( + account=wrapped_sol_account.public_key(), + owner=owner.public_key(), + dest=owner.public_key(), + program_id=TOKEN_PROGRAM_ID, + ) + ) + ) + return transaction + + def _settle_funds_should_wrap_sol(self) -> bool: + return (self.state.quote_mint() == WRAPPED_SOL_MINT) or (self.state.base_mint() == WRAPPED_SOL_MINT) + + def make_settle_funds_instruction( + self, + open_orders_account: Union[OpenOrdersAccount, AsyncOpenOrdersAccount], + base_wallet: PublicKey, + quote_wallet: PublicKey, + vault_signer: PublicKey, + ) -> TransactionInstruction: + if base_wallet == self.state.base_vault(): + raise ValueError("base_wallet should not be a vault address") + if quote_wallet == self.state.quote_vault(): + raise ValueError("quote_wallet should not be a vault address") + + return instructions.settle_funds( + instructions.SettleFundsParams( + market=self.state.public_key(), + open_orders=open_orders_account.address, + owner=open_orders_account.owner, + base_vault=self.state.base_vault(), + quote_vault=self.state.quote_vault(), + base_wallet=base_wallet, + quote_wallet=quote_wallet, + vault_signer=vault_signer, + program_id=self.state.program_id(), + ) + ) diff --git a/pyserum/market/market.py b/pyserum/market/market.py index 50f9724..d06154c 100644 --- a/pyserum/market/market.py +++ b/pyserum/market/market.py @@ -1,48 +1,41 @@ """Market module to interact with Serum DEX.""" from __future__ import annotations -import itertools -import logging from typing import List from solana.account import Account from solana.publickey import PublicKey from solana.rpc.api import Client from solana.rpc.types import RPCResponse, TxOpts -from solana.system_program import CreateAccountParams, create_account -from solana.transaction import Transaction, TransactionInstruction -from spl.token.constants import ACCOUNT_LEN, TOKEN_PROGRAM_ID, WRAPPED_SOL_MINT -from spl.token.instructions import CloseAccountParams -from spl.token.instructions import InitializeAccountParams, close_account, initialize_account +from solana.transaction import Transaction from pyserum import instructions import pyserum.market.types as t from .._layouts.open_orders import OPEN_ORDERS_LAYOUT -from ..enums import OrderType, SelfTradeBehavior, Side -from ..open_orders_account import OpenOrdersAccount, make_create_account_instruction +from ..enums import OrderType, Side +from ..open_orders_account import OpenOrdersAccount from ..utils import load_bytes_data from ._internal.queue import decode_event_queue, decode_request_queue from .orderbook import OrderBook from .state import MarketState +from .core import MarketCore LAMPORTS_PER_SOL = 1000000000 -# pylint: disable=too-many-public-methods -class Market: +# pylint: disable=too-many-public-methods,abstract-method +class Market(MarketCore): """Represents a Serum Market.""" - logger = logging.getLogger("pyserum.market.Market") - def __init__(self, conn: Client, market_state: MarketState, force_use_request_queue: bool = False) -> None: + super().__init__(market_state=market_state, force_use_request_queue=force_use_request_queue) self._conn = conn - self.state = market_state - self.force_use_request_queue = force_use_request_queue - @staticmethod + @classmethod # pylint: disable=unused-argument def load( + cls, conn: Client, market_address: PublicKey, program_id: PublicKey = instructions.DEFAULT_DEX_PROGRAM_ID, @@ -55,63 +48,29 @@ class Market: :param program_id: The program id of the given market, it will use the default value if not provided. """ market_state = MarketState.load(conn, market_address, program_id) - return Market(conn, market_state, force_use_request_queue) - - def _use_request_queue(self) -> bool: - return ( - # DEX Version 1 - self.state.program_id == PublicKey("4ckmDgGdxQoPDLUkDT3vHgSAkzA3QRdNq5ywwY4sUSJn") - or - # DEX Version 1 - self.state.program_id == PublicKey("BJ3jrUzddfuSrZHXSCxMUUQsjKEyLmuuyZebkcaFp2fg") - or - # DEX Version 2 - self.state.program_id == PublicKey("EUqojwWA2rd19FZrzeBncJsm38Jm1hEhE3zsmX3bRc2o") - or self.force_use_request_queue - ) - - def support_srm_fee_discounts(self) -> bool: - raise NotImplementedError("support_srm_fee_discounts not implemented") - - def find_fee_discount_keys(self, owner: PublicKey, cache_duration: int): - raise NotImplementedError("find_fee_discount_keys not implemented") - - def find_best_fee_discount_key(self, owner: PublicKey, cache_duration: int): - raise NotImplementedError("find_best_fee_discount_key not implemented") + return cls(conn, market_state, force_use_request_queue) def find_open_orders_accounts_for_owner(self, owner_address: PublicKey) -> List[OpenOrdersAccount]: return OpenOrdersAccount.find_for_market_and_owner( self._conn, self.state.public_key(), owner_address, self.state.program_id() ) - def find_quote_token_accounts_for_owner(self, owner_address: PublicKey, include_unwrapped_sol: bool = False): - raise NotImplementedError("find_quote_token_accounts_for_owner not implemented") - def load_bids(self) -> OrderBook: """Load the bid order book""" bytes_data = load_bytes_data(self.state.bids(), self._conn) - return OrderBook.from_bytes(self.state, bytes_data) + return self._parse_bids_or_asks(bytes_data) def load_asks(self) -> OrderBook: """Load the ask order book.""" bytes_data = load_bytes_data(self.state.asks(), self._conn) - return OrderBook.from_bytes(self.state, bytes_data) + return self._parse_bids_or_asks(bytes_data) def load_orders_for_owner(self, owner_address: PublicKey) -> List[t.Order]: """Load orders for owner.""" bids = self.load_bids() asks = self.load_asks() open_orders_accounts = self.find_open_orders_accounts_for_owner(owner_address) - if not open_orders_accounts: - return [] - - all_orders = itertools.chain(bids.orders(), asks.orders()) - open_orders_addresses = {str(o.address) for o in open_orders_accounts} - orders = [o for o in all_orders if str(o.open_order_address) in open_orders_addresses] - return orders - - def load_base_token_for_owner(self): - raise NotImplementedError("load_base_token_for_owner not implemented") + return self._parse_orders_for_owner(bids, asks, open_orders_accounts) def load_event_queue(self) -> List[t.Event]: """Load the event queue which includes the fill item and out item. For any trades two fill items are added to @@ -127,40 +86,7 @@ class Market: def load_fills(self, limit=100) -> List[t.FilledOrder]: bytes_data = load_bytes_data(self.state.event_queue(), self._conn) - events = decode_event_queue(bytes_data, limit) - return [ - self.parse_fill_event(event) - for event in events - if event.event_flags.fill and event.native_quantity_paid > 0 - ] - - def parse_fill_event(self, event) -> t.FilledOrder: - if event.event_flags.bid: - side = Side.BUY - price_before_fees = ( - event.native_quantity_released + event.native_fee_or_rebate - if event.event_flags.maker - else event.native_quantity_released - event.native_fee_or_rebate - ) - else: - side = Side.SELL - price_before_fees = ( - event.native_quantity_released - event.native_fee_or_rebate - if event.event_flags.maker - else event.native_quantity_released + event.native_fee_or_rebate - ) - - price = (price_before_fees * self.state.base_spl_token_multiplier()) / ( - self.state.quote_spl_token_multiplier() * event.native_quantity_paid - ) - size = event.native_quantity_paid / self.state.base_spl_token_multiplier() - return t.FilledOrder( - order_id=event.order_id, - side=side, - price=price, - size=size, - fee_cost=event.native_fee_or_rebate * (1 if event.event_flags.maker else -1), - ) + return self._parse_fills(bytes_data, limit) def place_order( # pylint: disable=too-many-arguments,too-many-locals self, @@ -176,250 +102,47 @@ class Market: transaction = Transaction() signers: List[Account] = [owner] open_order_accounts = self.find_open_orders_accounts_for_owner(owner.public_key()) - if not open_order_accounts: - new_open_orders_account = Account() - place_order_open_order_account = new_open_orders_account.public_key() - mbfre_resp = self._conn.get_minimum_balance_for_rent_exemption(OPEN_ORDERS_LAYOUT.sizeof()) - balanced_needed = mbfre_resp["result"] - transaction.add( - make_create_account_instruction( - owner_address=owner.public_key(), - new_account_address=new_open_orders_account.public_key(), - lamports=balanced_needed, - program_id=self.state.program_id(), - ) - ) - signers.append(new_open_orders_account) - # TODO: Cache new_open_orders_account - else: + if open_order_accounts: place_order_open_order_account = open_order_accounts[0].address + else: + mbfre_resp = self._conn.get_minimum_balance_for_rent_exemption(OPEN_ORDERS_LAYOUT.sizeof()) + place_order_open_order_account = self._after_oo_mbfre_resp( + mbfre_resp=mbfre_resp, owner=owner, signers=signers, transaction=transaction + ) + # TODO: Cache new_open_orders_account # TODO: Handle fee_discount_pubkey - # unwrapped SOL cannot be used for payment - if payer == owner.public_key(): - raise ValueError("Invalid payer account. Cannot use unwrapped SOL.") - - # TODO: add integration test for SOL wrapping. - should_wrap_sol = (side == Side.BUY and self.state.quote_mint() == WRAPPED_SOL_MINT) or ( - side == Side.SELL and self.state.base_mint() == WRAPPED_SOL_MINT + self._prepare_order_transaction( + transaction=transaction, + payer=payer, + owner=owner, + order_type=order_type, + side=side, + signers=signers, + limit_price=limit_price, + max_quantity=max_quantity, + client_id=client_id, + open_order_accounts=open_order_accounts, + place_order_open_order_account=place_order_open_order_account, ) - - if should_wrap_sol: - wrapped_sol_account = Account() - payer = wrapped_sol_account.public_key() - signers.append(wrapped_sol_account) - transaction.add( - create_account( - CreateAccountParams( - from_pubkey=owner.public_key(), - new_account_pubkey=wrapped_sol_account.public_key(), - lamports=Market._get_lamport_need_for_sol_wrapping( - limit_price, max_quantity, side, open_order_accounts - ), - space=ACCOUNT_LEN, - program_id=TOKEN_PROGRAM_ID, - ) - ) - ) - transaction.add( - initialize_account( - InitializeAccountParams( - account=wrapped_sol_account.public_key(), - mint=WRAPPED_SOL_MINT, - owner=owner.public_key(), - program_id=TOKEN_PROGRAM_ID, - ) - ) - ) - - transaction.add( - self.make_place_order_instruction( - payer=payer, - owner=owner, - order_type=order_type, - side=side, - limit_price=limit_price, - max_quantity=max_quantity, - client_id=client_id, - open_order_account=place_order_open_order_account, - ) - ) - - if should_wrap_sol: - transaction.add( - close_account( - CloseAccountParams( - account=wrapped_sol_account.public_key(), - owner=owner.public_key(), - dest=owner.public_key(), - program_id=TOKEN_PROGRAM_ID, - ) - ) - ) - # TODO: extract `make_place_order_transaction`. return self._conn.send_transaction(transaction, *signers, opts=opts) - @staticmethod - def _get_lamport_need_for_sol_wrapping( - price: float, size: float, side: Side, open_orders_accounts: List[OpenOrdersAccount] - ) -> int: - lamports = 0 - if side == Side.BUY: - lamports = round(price * size * 1.01 * LAMPORTS_PER_SOL) - if open_orders_accounts: - lamports -= open_orders_accounts[0].quote_token_free - else: - lamports = round(size * LAMPORTS_PER_SOL) - if open_orders_accounts: - lamports -= open_orders_accounts[0].base_token_free - - return max(lamports, 0) + 10000000 - - def make_place_order_instruction( # pylint: disable=too-many-arguments - self, - payer: PublicKey, - owner: Account, - order_type: OrderType, - side: Side, - limit_price: float, - max_quantity: float, - client_id: int, - open_order_account: PublicKey, - fee_discount_pubkey: PublicKey = None, - ) -> TransactionInstruction: - if self.state.base_size_number_to_lots(max_quantity) < 0: - raise Exception("Size lot %d is too small" % max_quantity) - if self.state.price_number_to_lots(limit_price) < 0: - raise Exception("Price lot %d is too small" % limit_price) - if self._use_request_queue(): - return instructions.new_order( - instructions.NewOrderParams( - market=self.state.public_key(), - open_orders=open_order_account, - payer=payer, - owner=owner.public_key(), - request_queue=self.state.request_queue(), - base_vault=self.state.base_vault(), - quote_vault=self.state.quote_vault(), - side=side, - limit_price=self.state.price_number_to_lots(limit_price), - max_quantity=self.state.base_size_number_to_lots(max_quantity), - order_type=order_type, - client_id=client_id, - program_id=self.state.program_id(), - ) - ) - return instructions.new_order_v3( - instructions.NewOrderV3Params( - market=self.state.public_key(), - open_orders=open_order_account, - payer=payer, - owner=owner.public_key(), - request_queue=self.state.request_queue(), - event_queue=self.state.event_queue(), - bids=self.state.bids(), - asks=self.state.asks(), - base_vault=self.state.base_vault(), - quote_vault=self.state.quote_vault(), - side=side, - limit_price=self.state.price_number_to_lots(limit_price), - max_base_quantity=self.state.base_size_number_to_lots(max_quantity), - max_quote_quantity=self.state.base_size_number_to_lots(max_quantity) - * self.state.quote_lot_size() - * self.state.price_number_to_lots(limit_price), - order_type=order_type, - client_id=client_id, - program_id=self.state.program_id(), - self_trade_behavior=SelfTradeBehavior.DECREMENT_TAKE, - fee_discount_pubkey=fee_discount_pubkey, - limit=65535, - ) - ) - def cancel_order_by_client_id( self, owner: Account, open_orders_account: PublicKey, client_id: int, opts: TxOpts = TxOpts() ) -> RPCResponse: - txs = Transaction().add(self.make_cancel_order_by_client_id_instruction(owner, open_orders_account, client_id)) + txs = self._build_cancel_order_by_client_id_tx( + owner=owner, open_orders_account=open_orders_account, client_id=client_id + ) return self._conn.send_transaction(txs, owner, opts=opts) - def make_cancel_order_by_client_id_instruction( - self, owner: Account, open_orders_account: PublicKey, client_id: int - ) -> TransactionInstruction: - if self._use_request_queue(): - return instructions.cancel_order_by_client_id( - instructions.CancelOrderByClientIDParams( - market=self.state.public_key(), - owner=owner.public_key(), - open_orders=open_orders_account, - request_queue=self.state.request_queue(), - client_id=client_id, - program_id=self.state.program_id(), - ) - ) - return instructions.cancel_order_by_client_id_v2( - instructions.CancelOrderByClientIDV2Params( - market=self.state.public_key(), - owner=owner.public_key(), - open_orders=open_orders_account, - bids=self.state.bids(), - asks=self.state.asks(), - event_queue=self.state.event_queue(), - client_id=client_id, - program_id=self.state.program_id(), - ) - ) - def cancel_order(self, owner: Account, order: t.Order, opts: TxOpts = TxOpts()) -> RPCResponse: - txn = Transaction().add(self.make_cancel_order_instruction(owner.public_key(), order)) + txn = self._build_cancel_order_tx(owner=owner, order=order) return self._conn.send_transaction(txn, owner, opts=opts) - def make_cancel_order_instruction(self, owner: PublicKey, order: t.Order) -> TransactionInstruction: - if self._use_request_queue(): - return instructions.cancel_order( - instructions.CancelOrderParams( - market=self.state.public_key(), - owner=owner, - open_orders=order.open_order_address, - request_queue=self.state.request_queue(), - side=order.side, - order_id=order.order_id, - open_orders_slot=order.open_order_slot, - program_id=self.state.program_id(), - ) - ) - return instructions.cancel_order_v2( - instructions.CancelOrderV2Params( - market=self.state.public_key(), - owner=owner, - open_orders=order.open_order_address, - bids=self.state.bids(), - asks=self.state.asks(), - event_queue=self.state.event_queue(), - side=order.side, - order_id=order.order_id, - open_orders_slot=order.open_order_slot, - program_id=self.state.program_id(), - ) - ) - def match_orders(self, fee_payer: Account, limit: int, opts: TxOpts = TxOpts()) -> RPCResponse: - txn = Transaction().add(self.make_match_orders_instruction(limit)) + txn = self._build_match_orders_tx(limit) return self._conn.send_transaction(txn, fee_payer, opts=opts) - def make_match_orders_instruction(self, limit: int) -> TransactionInstruction: - params = instructions.MatchOrdersParams( - market=self.state.public_key(), - request_queue=self.state.request_queue(), - event_queue=self.state.event_queue(), - bids=self.state.bids(), - asks=self.state.asks(), - base_vault=self.state.base_vault(), - quote_vault=self.state.quote_vault(), - limit=limit, - program_id=self.state.program_id(), - ) - return instructions.match_orders(params) - def settle_funds( # pylint: disable=too-many-arguments self, owner: Account, @@ -429,90 +152,16 @@ class Market: opts: TxOpts = TxOpts(), ) -> RPCResponse: # TODO: Handle wrapped sol accounts - if open_orders.owner != owner.public_key(): - raise Exception("Invalid open orders account") - vault_signer = PublicKey.create_program_address( - [bytes(self.state.public_key()), self.state.vault_signer_nonce().to_bytes(8, byteorder="little")], - self.state.program_id(), + should_wrap_sol = self._settle_funds_should_wrap_sol() + min_bal_for_rent_exemption = ( + self._conn.get_minimum_balance_for_rent_exemption(165)["result"] if should_wrap_sol else 0 + ) # value only matters if should_wrap_sol + transaction = self._build_settle_funds_tx( + owner=owner, + open_orders=open_orders, + base_wallet=base_wallet, + quote_wallet=quote_wallet, + min_bal_for_rent_exemption=min_bal_for_rent_exemption, + should_wrap_sol=should_wrap_sol, ) - transaction = Transaction() - signers: List[Account] = [owner] - - should_wrap_sol = (self.state.quote_mint() == WRAPPED_SOL_MINT) or (self.state.base_mint() == WRAPPED_SOL_MINT) - - if should_wrap_sol: - wrapped_sol_account = Account() - signers.append(wrapped_sol_account) - # make a wrapped SOL account with enough balance to - # fund the trade, run the program, then send itself back home - transaction.add( - create_account( - CreateAccountParams( - from_pubkey=owner.public_key(), - new_account_pubkey=wrapped_sol_account.public_key(), - lamports=self._conn.get_minimum_balance_for_rent_exemption(165)["result"], - space=ACCOUNT_LEN, - program_id=TOKEN_PROGRAM_ID, - ) - ) - ) - # this was also broken upstream. it should be minting wrapped SOL, and using the token program ID - transaction.add( - initialize_account( - InitializeAccountParams( - account=wrapped_sol_account.public_key(), - mint=WRAPPED_SOL_MINT, - owner=owner.public_key(), - program_id=TOKEN_PROGRAM_ID, - ) - ) - ) - - transaction.add( - self.make_settle_funds_instruction( - open_orders, - base_wallet if self.state.base_mint() != WRAPPED_SOL_MINT else wrapped_sol_account.public_key(), - quote_wallet if self.state.quote_mint() != WRAPPED_SOL_MINT else wrapped_sol_account.public_key(), - vault_signer, - ) - ) - - if should_wrap_sol: - # close out the account and send the funds home when the trade is completed/cancelled - transaction.add( - close_account( - CloseAccountParams( - account=wrapped_sol_account.public_key(), - owner=owner.public_key(), - dest=owner.public_key(), - program_id=TOKEN_PROGRAM_ID, - ) - ) - ) return self._conn.send_transaction(transaction, owner, opts=opts) - - def make_settle_funds_instruction( - self, - open_orders_account: OpenOrdersAccount, - base_wallet: PublicKey, - quote_wallet: PublicKey, - vault_signer: PublicKey, - ) -> TransactionInstruction: - if base_wallet == self.state.base_vault(): - raise ValueError("base_wallet should not be a vault address") - if quote_wallet == self.state.quote_vault(): - raise ValueError("quote_wallet should not be a vault address") - - return instructions.settle_funds( - instructions.SettleFundsParams( - market=self.state.public_key(), - open_orders=open_orders_account.address, - owner=open_orders_account.owner, - base_vault=self.state.base_vault(), - quote_vault=self.state.quote_vault(), - base_wallet=base_wallet, - quote_wallet=quote_wallet, - vault_signer=vault_signer, - program_id=self.state.program_id(), - ) - ) diff --git a/pyserum/market/state.py b/pyserum/market/state.py index 970f433..4e3861e 100644 --- a/pyserum/market/state.py +++ b/pyserum/market/state.py @@ -5,8 +5,9 @@ import math from construct import Container, Struct from solana.publickey import PublicKey from solana.rpc.api import Client +from solana.rpc.async_api import AsyncClient -from pyserum.utils import get_mint_decimals, load_bytes_data +from pyserum import utils, async_utils from .._layouts.market import MARKET_LAYOUT from .types import AccountFlags @@ -27,21 +28,34 @@ class MarketState: # pylint: disable=too-many-public-methods return MARKET_LAYOUT @staticmethod - def load(conn: Client, market_address: PublicKey, program_id: PublicKey) -> MarketState: - bytes_data = load_bytes_data(market_address, conn) + def _make_parsed_market(bytes_data: bytes) -> Container: parsed_market = MARKET_LAYOUT.parse(bytes_data) # TODO: add ownAddress check! if not parsed_market.account_flags.initialized or not parsed_market.account_flags.market: raise Exception("Invalid market") + return parsed_market - base_mint_decimals = get_mint_decimals(conn, PublicKey(parsed_market.base_mint)) - quote_mint_decimals = get_mint_decimals(conn, PublicKey(parsed_market.quote_mint)) - return MarketState(parsed_market, program_id, base_mint_decimals, quote_mint_decimals) + @classmethod + def load(cls, conn: Client, market_address: PublicKey, program_id: PublicKey) -> MarketState: + bytes_data = utils.load_bytes_data(market_address, conn) + parsed_market = cls._make_parsed_market(bytes_data) - @staticmethod + base_mint_decimals = utils.get_mint_decimals(conn, PublicKey(parsed_market.base_mint)) + quote_mint_decimals = utils.get_mint_decimals(conn, PublicKey(parsed_market.quote_mint)) + return cls(parsed_market, program_id, base_mint_decimals, quote_mint_decimals) + + @classmethod + async def async_load(cls, conn: AsyncClient, market_address: PublicKey, program_id: PublicKey) -> MarketState: + bytes_data = await async_utils.load_bytes_data(market_address, conn) + parsed_market = cls._make_parsed_market(bytes_data) + base_mint_decimals = await async_utils.get_mint_decimals(conn, PublicKey(parsed_market.base_mint)) + quote_mint_decimals = await async_utils.get_mint_decimals(conn, PublicKey(parsed_market.quote_mint)) + return cls(parsed_market, program_id, base_mint_decimals, quote_mint_decimals) + + @classmethod def from_bytes( - program_id: PublicKey, base_mint_decimals: int, quote_mint_decimals: int, buffer: bytes + cls, program_id: PublicKey, base_mint_decimals: int, quote_mint_decimals: int, buffer: bytes ) -> MarketState: parsed_market = MARKET_LAYOUT.parse(buffer) # TODO: add ownAddress check! @@ -49,7 +63,7 @@ class MarketState: # pylint: disable=too-many-public-methods if not parsed_market.account_flags.initialized or not parsed_market.account_flags.market: raise Exception("Invalid market") - return MarketState(parsed_market, program_id, base_mint_decimals, quote_mint_decimals) + return cls(parsed_market, program_id, base_mint_decimals, quote_mint_decimals) def program_id(self) -> PublicKey: return self._program_id diff --git a/pyserum/open_orders_account.py b/pyserum/open_orders_account.py index a6d3fe3..c973a56 100644 --- a/pyserum/open_orders_account.py +++ b/pyserum/open_orders_account.py @@ -1,12 +1,12 @@ from __future__ import annotations import base64 -from typing import List, NamedTuple +from typing import List, NamedTuple, TypeVar, Type, Tuple from solana.publickey import PublicKey from solana.rpc.api import Client from solana.rpc.commitment import Recent -from solana.rpc.types import Commitment, MemcmpOpts +from solana.rpc.types import Commitment, MemcmpOpts, RPCResponse from solana.system_program import CreateAccountParams, create_account from solana.transaction import TransactionInstruction @@ -23,9 +23,11 @@ class ProgramAccount(NamedTuple): owner: PublicKey -class OpenOrdersAccount: +_T = TypeVar("_T", bound="_OpenOrdersAccountCore") + + +class _OpenOrdersAccountCore: # pylint: disable=too-many-instance-attributes,too-few-public-methods # pylint: disable=too-many-arguments - # pylint: disable=too-many-instance-attributes def __init__( self, address: PublicKey, @@ -52,13 +54,13 @@ class OpenOrdersAccount: self.orders = orders self.client_ids = client_ids - @staticmethod - def from_bytes(address: PublicKey, buffer: bytes) -> OpenOrdersAccount: + @classmethod + def from_bytes(cls: Type[_T], address: PublicKey, buffer: bytes) -> _T: open_order_decoded = OPEN_ORDERS_LAYOUT.parse(buffer) if not open_order_decoded.account_flags.open_orders or not open_order_decoded.account_flags.initialized: raise Exception("Not an open order account or not initialized.") - return OpenOrdersAccount( + return cls( address=address, market=PublicKey(open_order_decoded.market), owner=PublicKey(open_order_decoded.owner), @@ -72,27 +74,8 @@ class OpenOrdersAccount: client_ids=open_order_decoded.client_ids, ) - @staticmethod - def find_for_market_and_owner( - conn: Client, market: PublicKey, owner: PublicKey, program_id: PublicKey, commitment: Commitment = Recent - ) -> List[OpenOrdersAccount]: - filters = [ - MemcmpOpts( - offset=5 + 8, # 5 bytes of padding, 8 bytes of account flag - bytes=str(market), - ), - MemcmpOpts( - offset=5 + 8 + 32, # 5 bytes of padding, 8 bytes of account flag, 32 bytes of market public key - bytes=str(owner), - ), - ] - resp = conn.get_program_accounts( - program_id, - commitment=commitment, - encoding="base64", - memcmp_opts=filters, - data_size=OPEN_ORDERS_LAYOUT.sizeof(), - ) + @classmethod + def _process_get_program_accounts_resp(cls: Type[_T], resp: RPCResponse) -> List[_T]: accounts = [] for account in resp["result"]: account_details = account["account"] @@ -106,13 +89,49 @@ class OpenOrdersAccount: ) ) - return [OpenOrdersAccount.from_bytes(account.public_key, account.data) for account in accounts] + return [cls.from_bytes(account.public_key, account.data) for account in accounts] @staticmethod - def load(conn: Client, address: str) -> OpenOrdersAccount: + def _build_get_program_accounts_args( + market: PublicKey, program_id: PublicKey, owner: PublicKey, commitment: Commitment + ) -> Tuple[PublicKey, Commitment, str, None, int, List[MemcmpOpts]]: + filters = [ + MemcmpOpts( + offset=5 + 8, # 5 bytes of padding, 8 bytes of account flag + bytes=str(market), + ), + MemcmpOpts( + offset=5 + 8 + 32, # 5 bytes of padding, 8 bytes of account flag, 32 bytes of market public key + bytes=str(owner), + ), + ] + data_slice = None + return ( + program_id, + commitment, + "base64", + data_slice, + OPEN_ORDERS_LAYOUT.sizeof(), + filters, + ) + + +class OpenOrdersAccount(_OpenOrdersAccountCore): + @classmethod + def find_for_market_and_owner( # pylint: disable=too-many-arguments + cls, conn: Client, market: PublicKey, owner: PublicKey, program_id: PublicKey, commitment: Commitment = Recent + ) -> List[OpenOrdersAccount]: + args = cls._build_get_program_accounts_args( + market=market, program_id=program_id, owner=owner, commitment=commitment + ) + resp = conn.get_program_accounts(*args) + return cls._process_get_program_accounts_resp(resp) + + @classmethod + def load(cls, conn: Client, address: str) -> OpenOrdersAccount: addr_pub_key = PublicKey(address) bytes_data = load_bytes_data(addr_pub_key, conn) - return OpenOrdersAccount.from_bytes(addr_pub_key, bytes_data) + return cls.from_bytes(addr_pub_key, bytes_data) def make_create_account_instruction( diff --git a/pyserum/utils.py b/pyserum/utils.py index 1596dde..5c87db4 100644 --- a/pyserum/utils.py +++ b/pyserum/utils.py @@ -2,23 +2,32 @@ import base64 from solana.publickey import PublicKey from solana.rpc.api import Client +from solana.rpc.types import RPCResponse from spl.token.constants import WRAPPED_SOL_MINT from pyserum._layouts.market import MINT_LAYOUT -def load_bytes_data(addr: PublicKey, conn: Client): - res = conn.get_account_info(addr) +def parse_bytes_data(res: RPCResponse) -> bytes: if ("result" not in res) or ("value" not in res["result"]) or ("data" not in res["result"]["value"]): raise Exception("Cannot load byte data.") data = res["result"]["value"]["data"][0] return base64.decodebytes(data.encode("ascii")) +def load_bytes_data(addr: PublicKey, conn: Client) -> bytes: + res = conn.get_account_info(addr) + return parse_bytes_data(res) + + +def parse_mint_decimals(bytes_data: bytes) -> int: + return MINT_LAYOUT.parse(bytes_data).decimals + + def get_mint_decimals(conn: Client, mint_pub_key: PublicKey) -> int: """Get the mint decimals for a token mint""" if mint_pub_key == WRAPPED_SOL_MINT: return 9 bytes_data = load_bytes_data(mint_pub_key, conn) - return MINT_LAYOUT.parse(bytes_data).decimals + return parse_mint_decimals(bytes_data) diff --git a/pytest.ini b/pytest.ini index f06c069..a43aee8 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ [pytest] +addopts = -p no:anyio markers = integration: mark a test as a integration test. + async_integration: mark a test as an async_integration test. diff --git a/scripts/bootstrap_dex.sh b/scripts/bootstrap_dex.sh index 08215b2..73eb789 100755 --- a/scripts/bootstrap_dex.sh +++ b/scripts/bootstrap_dex.sh @@ -18,11 +18,13 @@ if ! hash solana 2>/dev/null; then echo Installing Solana tool suite ... curl -sSf https://raw.githubusercontent.com/solana-labs/solana/v1.5.8/install/solana-install-init.sh | SOLANA_RELEASE=v1.5.8 sh -s - v1.5.8 export PATH="/home/runner/.local/share/solana/install/active_release/bin:$PATH" - echo Generating keypair ... - solana-keygen new -o ~/.config/solana/id.json --no-passphrase --silent + if [ ! -f ~/.config/solana/id.json ]; then + echo Generating keypair ... + solana-keygen new -o ~/.config/solana/id.json --no-passphrase --silent + fi fi -solana-test-validator & +solana-test-validator & echo $! > solana_test_validator.pid solana config set --url "http://127.0.0.1:8899" curl -s -L "https://github.com/serum-community/serum-dex/releases/download/v2/serum_dex-$os_type.so" > serum_dex.so sleep 1 diff --git a/scripts/clean_up.sh b/scripts/clean_up.sh index 4bdfb28..f1da4d1 100644 --- a/scripts/clean_up.sh +++ b/scripts/clean_up.sh @@ -5,3 +5,6 @@ if [[ $KEEP_ARTIFACTS == "" ]]; then rm -rf tests/crank.log crank serum_dex.so fi docker-compose down +kill $(<"solana_test_validator.pid") +rm -rf test-ledger +rm -rf solana_test_validator.pid diff --git a/scripts/run_async_int_tests.sh b/scripts/run_async_int_tests.sh new file mode 100755 index 0000000..2be5699 --- /dev/null +++ b/scripts/run_async_int_tests.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +bash scripts/bootstrap_dex.sh + +wait_time=20 +echo "Waiting $wait_time seconds to make sure the market has started" +sleep $wait_time + + + +exit_code=1 +if (pipenv run pytest -vv -m async_integration); then + echo "The script ran ok" + exit_code=0 +fi + +bash scripts/clean_up.sh + +exit $exit_code diff --git a/scripts/run_coverage.sh b/scripts/run_coverage.sh index b7b9d71..8cea9a5 100755 --- a/scripts/run_coverage.sh +++ b/scripts/run_coverage.sh @@ -1,5 +1,7 @@ #!/bin/bash +pipenv run pytest -m "not integration and not async_integration" --cov=./ --cov-report=xml --cov-append + bash scripts/bootstrap_dex.sh wait_time=20 @@ -7,9 +9,18 @@ echo "Waiting $wait_time seconds to make sure the market has started" sleep $wait_time +pipenv run pytest -m integration --cov=./ --cov-report=xml --cov-append + +bash scripts/clean_up.sh + +bash scripts/bootstrap_dex.sh + +wait_time=20 +echo "Waiting $wait_time seconds to make sure the market has started" +sleep $wait_time exit_code=1 -if (pipenv run pytest --cov=./ --cov-report=xml); then +if (pipenv run pytest -m async_integration --cov=./ --cov-report=xml --cov-append); then echo "The script ran ok" exit_code=0 fi diff --git a/setup.py b/setup.py index bb4fd00..6fc3526 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ from setuptools import find_packages, setup setup( name="pyserum", - version="0.3.5a1", + version="0.4.0a1", author="serum-community", description="""Python client library for interacting with the Project Serum DEX.""", install_requires=[ diff --git a/tests/conftest.py b/tests/conftest.py index 5665770..413ff27 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,14 @@ from typing import Dict +import asyncio import pytest from solana.account import Account from solana.publickey import PublicKey from solana.rpc.api import Client +from solana.rpc.async_api import AsyncClient from pyserum.connection import conn +from pyserum.async_connection import async_conn @pytest.mark.integration @@ -150,3 +153,24 @@ def http_client() -> Client: if not cc.is_connected(): raise Exception("Could not connect to local node. Please run `make int-tests` to run integration tests.") return cc + + +@pytest.fixture(scope="session") +def event_loop(): + """Event loop for pytest-asyncio.""" + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.mark.async_integration +@pytest.fixture(scope="session") +def async_http_client(event_loop) -> AsyncClient: # pylint: disable=redefined-outer-name + """Solana async http client.""" + cc = async_conn("http://localhost:8899") # pylint: disable=invalid-name + if not event_loop.run_until_complete(cc.is_connected()): + raise Exception( + "Could not connect to local node. Please run `make async-int-tests` to run async integration tests." + ) + yield cc + event_loop.run_until_complete(cc.close()) diff --git a/tests/integration/test_async_connection.py b/tests/integration/test_async_connection.py new file mode 100644 index 0000000..699170f --- /dev/null +++ b/tests/integration/test_async_connection.py @@ -0,0 +1,24 @@ +# pylint: disable=R0801 +import pytest +import httpx + +from pyserum.async_connection import get_live_markets, get_token_mints +from pyserum.market.types import MarketInfo, TokenInfo + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_get_live_markets(): + """Test get_live_markets.""" + async with httpx.AsyncClient() as client: + resp = await get_live_markets(client) + assert all(isinstance(market_info, MarketInfo) for market_info in resp) + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_get_token_mints(): + """Test get_token_mints.""" + async with httpx.AsyncClient() as client: + resp = await get_token_mints(client) + assert all(isinstance(token_info, TokenInfo) for token_info in resp) diff --git a/tests/integration/test_async_market.py b/tests/integration/test_async_market.py new file mode 100644 index 0000000..aff09c9 --- /dev/null +++ b/tests/integration/test_async_market.py @@ -0,0 +1,204 @@ +# pylint: disable=redefined-outer-name + +import pytest +from solana.account import Account +from solana.publickey import PublicKey +from solana.rpc.async_api import AsyncClient +from solana.rpc.types import TxOpts + +from pyserum.enums import OrderType, Side +from pyserum.market import AsyncMarket + + +@pytest.mark.async_integration +@pytest.fixture(scope="module") +def bootstrapped_market( + async_http_client: AsyncClient, stubbed_market_pk: PublicKey, stubbed_dex_program_pk: PublicKey, event_loop +) -> AsyncMarket: + return event_loop.run_until_complete( + AsyncMarket.load(async_http_client, stubbed_market_pk, stubbed_dex_program_pk, force_use_request_queue=True) + ) + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_bootstrapped_market( + bootstrapped_market: AsyncMarket, + stubbed_market_pk: PublicKey, + stubbed_dex_program_pk: PublicKey, + stubbed_base_mint: PublicKey, + stubbed_quote_mint: PublicKey, +): + assert isinstance(bootstrapped_market, AsyncMarket) + assert bootstrapped_market.state.public_key() == stubbed_market_pk + assert bootstrapped_market.state.program_id() == stubbed_dex_program_pk + assert bootstrapped_market.state.base_mint() == stubbed_base_mint.public_key() + assert bootstrapped_market.state.quote_mint() == stubbed_quote_mint.public_key() + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_market_load_bid(bootstrapped_market: AsyncMarket): + # TODO: test for non-zero order case. + bids = await bootstrapped_market.load_bids() + assert sum(1 for _ in bids) == 0 + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_market_load_asks(bootstrapped_market: AsyncMarket): + # TODO: test for non-zero order case. + asks = await bootstrapped_market.load_asks() + assert sum(1 for _ in asks) == 0 + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_market_load_events(bootstrapped_market: AsyncMarket): + event_queue = await bootstrapped_market.load_event_queue() + assert len(event_queue) == 0 + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_market_load_requests(bootstrapped_market: AsyncMarket): + request_queue = await bootstrapped_market.load_request_queue() + # 2 requests in the request queue in the beginning with one bid and one ask + assert len(request_queue) == 2 + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_match_order(bootstrapped_market: AsyncMarket, stubbed_payer: Account): + await bootstrapped_market.match_orders(stubbed_payer, 2, TxOpts(skip_confirmation=False)) + + request_queue = await bootstrapped_market.load_request_queue() + # 0 request after matching. + assert len(request_queue) == 0 + + event_queue = await bootstrapped_market.load_event_queue() + # 5 event after the order is matched, including 2 fill events. + assert len(event_queue) == 5 + + # There should be no bid order. + bids = await bootstrapped_market.load_bids() + assert sum(1 for _ in bids) == 0 + + # There should be no ask order. + asks = await bootstrapped_market.load_asks() + assert sum(1 for _ in asks) == 0 + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_settle_fund( + bootstrapped_market: AsyncMarket, + stubbed_payer: Account, + stubbed_quote_wallet: Account, + stubbed_base_wallet: Account, +): + open_order_accounts = await bootstrapped_market.find_open_orders_accounts_for_owner(stubbed_payer.public_key()) + + with pytest.raises(ValueError): + # Should not allow base_wallet to be base_vault + await bootstrapped_market.settle_funds( + stubbed_payer, + open_order_accounts[0], + bootstrapped_market.state.base_vault(), + stubbed_quote_wallet.public_key(), + ) + + with pytest.raises(ValueError): + # Should not allow quote_wallet to be wallet_vault + await bootstrapped_market.settle_funds( + stubbed_payer, + open_order_accounts[0], + stubbed_base_wallet.public_key(), + bootstrapped_market.state.quote_vault(), + ) + + for open_order_account in open_order_accounts: + assert "error" not in await bootstrapped_market.settle_funds( + stubbed_payer, + open_order_account, + stubbed_base_wallet.public_key(), + stubbed_quote_wallet.public_key(), + opts=TxOpts(skip_confirmation=False), + ) + + # TODO: Check account states after settling funds + + +@pytest.mark.async_integration +@pytest.mark.asyncio +async def test_order_placement_cancellation_cycle( + bootstrapped_market: AsyncMarket, + stubbed_payer: Account, + stubbed_quote_wallet: Account, + stubbed_base_wallet: Account, +): + initial_request_len = len(await bootstrapped_market.load_request_queue()) + await bootstrapped_market.place_order( + payer=stubbed_quote_wallet.public_key(), + owner=stubbed_payer, + side=Side.BUY, + order_type=OrderType.LIMIT, + limit_price=1000, + max_quantity=3000, + opts=TxOpts(skip_confirmation=False), + ) + + request_queue = await bootstrapped_market.load_request_queue() + # 0 request after matching. + assert len(request_queue) == initial_request_len + 1 + + # There should be no bid order. + bids = await bootstrapped_market.load_bids() + assert sum(1 for _ in bids) == 0 + + # There should be no ask order. + asks = await bootstrapped_market.load_asks() + assert sum(1 for _ in asks) == 0 + + await bootstrapped_market.place_order( + payer=stubbed_base_wallet.public_key(), + owner=stubbed_payer, + side=Side.SELL, + order_type=OrderType.LIMIT, + limit_price=1500, + max_quantity=3000, + opts=TxOpts(skip_confirmation=False), + ) + + # The two order shouldn't get executed since there is a price difference of 1 + await bootstrapped_market.match_orders( + stubbed_payer, + 2, + opts=TxOpts(skip_confirmation=False), + ) + + # There should be 1 bid order that we sent earlier. + bids = await bootstrapped_market.load_bids() + assert sum(1 for _ in bids) == 1 + + # There should be 1 ask order that we sent earlier. + asks = await bootstrapped_market.load_asks() + assert sum(1 for _ in asks) == 1 + + for bid in bids: + await bootstrapped_market.cancel_order(stubbed_payer, bid, opts=TxOpts(skip_confirmation=False)) + + await bootstrapped_market.match_orders(stubbed_payer, 1, opts=TxOpts(skip_confirmation=False)) + + # All bid order should have been cancelled. + bids = await bootstrapped_market.load_bids() + assert sum(1 for _ in bids) == 0 + + for ask in asks: + await bootstrapped_market.cancel_order(stubbed_payer, ask, opts=TxOpts(skip_confirmation=False)) + + await bootstrapped_market.match_orders(stubbed_payer, 1, opts=TxOpts(skip_confirmation=False)) + + # All ask order should have been cancelled. + asks = await bootstrapped_market.load_asks() + assert sum(1 for _ in asks) == 0 diff --git a/tests/integration/test_connection.py b/tests/integration/test_connection.py index 8d17dcd..e93198a 100644 --- a/tests/integration/test_connection.py +++ b/tests/integration/test_connection.py @@ -1,12 +1,16 @@ +import pytest + from pyserum.connection import get_live_markets, get_token_mints from pyserum.market.types import MarketInfo, TokenInfo +@pytest.mark.integration def test_get_live_markets(): """Test get_live_markets.""" assert all(isinstance(market_info, MarketInfo) for market_info in get_live_markets()) +@pytest.mark.integration def test_get_token_mints(): """Test get_token_mints.""" assert all(isinstance(token_info, TokenInfo) for token_info in get_token_mints())