Added match and cancel order to market class.

Added the following:

- place order API
- cancel order API
- integration tests for existing functionalities
- split the run_int-test script into several files so that we can just start the docker image without running the integration tests
- fix the issue when the test fails but the build is still red.
This commit is contained in:
Leonard G 2020-09-15 11:52:19 +08:00 committed by GitHub
parent bdd3e418ec
commit 969c14f3e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 293 additions and 75 deletions

View File

@ -18,7 +18,7 @@ unit-tests:
pipenv run pytest -v -m "not integration"
int-tests:
sh scripts/run_int_tests.sh
bash scripts/run_int_tests.sh
# Minimal makefile for Sphinx documentation
#

36
scripts/bootstrap_dex.sh Normal file
View File

@ -0,0 +1,36 @@
#!/bin/bash
set -e
os_type=""
if [[ $OSTYPE == "linux-gnu"* ]]; then
os_type="linux"
elif [[ $OSTYPE == "darwin"* ]]; then
os_type="darwin"
else
echo "$OSTYPE is not supported."
exit 1
fi
docker-compose up -d
if ! hash solana 2>/dev/null; then
echo Installing Solana tool suite...
curl -sSf https://raw.githubusercontent.com/solana-labs/solana/v1.3.9/install/solana-install-init.sh | sh -s - v1.3.9
export PATH="/home/runner/.local/share/solana/install/active_release/bin:$PATH"
echo Generating keypair...
solana-keygen new -o ~/.config/solana/id.json --no-passphrase --silent
fi
solana config set --url "http://localhost:8899"
curl -s -L "https://github.com/serum-community/serum-dex/releases/download/refs%2Fheads%2Fmaster/serum_dex-$os_type.so" > serum_dex.so
sleep 1
solana airdrop 10000
DEX_PROGRAM_ID="$(solana deploy --use-deprecated-loader serum_dex.so | jq .programId -r)"
echo DEX_PROGRAM_ID: "$DEX_PROGRAM_ID"
curl -s -L "https://github.com/serum-community/serum-dex/releases/download/refs%2Fheads%2Fmaster/crank-$os_type" > crank
chmod +x crank
./crank l pyserum-setup ~/.config/solana/id.json "$DEX_PROGRAM_ID"
echo "dex_program_id: $DEX_PROGRAM_ID" >> crank.log
mv crank.log tests/crank.log
cat tests/crank.log

7
scripts/clean_up.sh Normal file
View File

@ -0,0 +1,7 @@
#!/bin/bash
if [[ $KEEP_ARTIFACTS == "" ]]; then
echo Deleting artifacts...
rm -rf tests/crank.log crank serum_dex.so
fi
docker-compose down

View File

@ -1,41 +1,19 @@
#!/bin/bash
set -e
bash scripts/bootstrap_dex.sh
os_type=""
wait_time=20
echo "Waiting $wait_time seconds to make sure the market has started"
sleep $wait_time
if [[ $OSTYPE == "linux-gnu"* ]]; then
os_type="linux"
elif [[ $OSTYPE == "darwin"* ]]; then
os_type="darwin"
else
echo "$OSTYPE is not supported."
exit 1
exit_code=1
if (pipenv run pytest -vv -m integration); then
echo "The script ran ok"
exit_code=0
fi
docker-compose up -d
if ! hash solana 2>/dev/null; then
echo Installing Solana tool suite...
curl -sSf https://raw.githubusercontent.com/solana-labs/solana/v1.3.9/install/solana-install-init.sh | sh -s - v1.3.9
export PATH="/home/runner/.local/share/solana/install/active_release/bin:$PATH"
echo Generating keypair...
solana-keygen new -o ~/.config/solana/id.json --no-passphrase --silent
fi
solana config set --url "http://localhost:8899"
curl -s -L "https://github.com/serum-community/serum-dex/releases/download/refs%2Fheads%2Fmaster/serum_dex-$os_type.so" > serum_dex.so
sleep 1
solana airdrop 10000
DEX_PROGRAM_ID="$(solana deploy --use-deprecated-loader serum_dex.so | jq .programId -r)"
echo DEX_PROGRAM_ID: $DEX_PROGRAM_ID
curl -s -L "https://github.com/serum-community/serum-dex/releases/download/refs%2Fheads%2Fmaster/crank-$os_type" > crank
chmod +x crank
./crank l pyserum-setup ~/.config/solana/id.json $DEX_PROGRAM_ID
echo "dex_program_id: $DEX_PROGRAM_ID" >> crank.log
mv crank.log tests/crank.log
cat tests/crank.log
pipenv run pytest -vv -m integration
if [[ $KEEP_ARTIFACTS == "" ]]; then
echo Deleting artifacts...
rm -rf tests/crank.log crank serum_dex.so
fi
docker-compose down
bash scripts/clean_up.sh
exit $exit_code

View File

@ -1,31 +1,32 @@
"""Market module to interact with Serum DEX."""
from __future__ import annotations
import base64
import logging
import math
from typing import Any, Iterable, List, NamedTuple, Tuple
from solana.account import Account
from solana.publickey import PublicKey
from solana.rpc.api import Client
from solana.transaction import Transaction
from solana.transaction import Transaction, TransactionInstruction
from ._layouts.account_flags import ACCOUNT_FLAGS_LAYOUT
from ._layouts.market import MARKET_LAYOUT, MINT_LAYOUT
from ._layouts.slab import Slab
from .enums import Side
from .instructions import DEFAULT_DEX_PROGRAM_ID, NewOrderParams
from .instructions import DEFAULT_DEX_PROGRAM_ID, CancelOrderParams, MatchOrdersParams, NewOrderParams
from .instructions import cancel_order as cancel_order_inst
from .instructions import match_orders as match_order_inst
from .queue_ import decode_event_queue, decode_request_queue
from .utils import load_bytes_data
def _load_bytes_data(addr: PublicKey, endpoint: str):
res = Client(endpoint).get_account_info(addr)
data = res["result"]["value"]["data"][0]
return base64.decodebytes(data.encode("ascii"))
# pylint: disable=too-many-public-methods
class Market:
"""Represents a Serum Market."""
logger = logging.getLogger("serum.market")
_decode: Any
_baseSplTokenDecimals: int
_quoteSplTokenDecimals: int
@ -60,7 +61,7 @@ class Market:
endpoint: str, market_address: str, options: Any, program_id: PublicKey = DEFAULT_DEX_PROGRAM_ID
) -> Market:
"""Factory method to create a Market."""
bytes_data = _load_bytes_data(PublicKey(market_address), endpoint)
bytes_data = load_bytes_data(PublicKey(market_address), endpoint)
market_state = MARKET_LAYOUT.parse(bytes_data)
# TODO: add ownAddress check!
@ -70,19 +71,25 @@ class Market:
base_mint_decimals = Market.get_mint_decimals(endpoint, PublicKey(market_state.base_mint))
quote_mint_decimals = Market.get_mint_decimals(endpoint, PublicKey(market_state.quote_mint))
return Market(market_state, base_mint_decimals, quote_mint_decimals, options, endpoint)
return Market(market_state, base_mint_decimals, quote_mint_decimals, options, endpoint, program_id=program_id)
def address(self) -> PublicKey:
"""Return market address."""
raise NotImplementedError("address is not implemented yet")
return PublicKey(self._decode.own_address)
def public_key(self) -> PublicKey:
return self.address()
def program_id(self) -> PublicKey:
return self._program_id
def base_mint_address(self) -> PublicKey:
"""Returns base mint address."""
raise NotImplementedError("base_mint_address is not implemented yet")
return PublicKey(self._decode.base_mint)
def quote_mint_address(self) -> PublicKey:
"""Returns quote mint address."""
raise NotImplementedError("quote_mint_address is not implemented yet")
return PublicKey(self._decode.quote_mint)
def __base_spl_token_multiplier(self) -> int:
return 10 ** self._base_spl_token_decimals
@ -113,34 +120,34 @@ class Market:
@staticmethod
def get_mint_decimals(endpoint: str, mint_pub_key: PublicKey) -> int:
"""Get the mint decimals from given public key."""
bytes_data = _load_bytes_data(mint_pub_key, endpoint)
bytes_data = load_bytes_data(mint_pub_key, endpoint)
return MINT_LAYOUT.parse(bytes_data).decimals
def load_bids(self):
def load_bids(self) -> OrderBook:
"""Load the bid order book"""
bids_addr = PublicKey(self._decode.bids)
bytes_data = _load_bytes_data(bids_addr, self._endpoint)
bytes_data = load_bytes_data(bids_addr, self._endpoint)
return OrderBook.decode(self, bytes_data)
def load_asks(self):
def load_asks(self) -> OrderBook:
"""Load the Ask order book."""
asks_addr = PublicKey(self._decode.asks)
bytes_data = _load_bytes_data(asks_addr, self._endpoint)
bytes_data = load_bytes_data(asks_addr, self._endpoint)
return OrderBook.decode(self, bytes_data)
def load_event_queue(self):
def load_event_queue(self): # returns raw construct type
event_queue_addr = PublicKey(self._decode.event_queue)
bytes_data = _load_bytes_data(event_queue_addr, self._endpoint)
bytes_data = load_bytes_data(event_queue_addr, self._endpoint)
return decode_event_queue(bytes_data)
def load_request_queue(self):
def load_request_queue(self): # returns raw construct type
request_queue_addr = PublicKey(self._decode.request_queue)
bytes_data = _load_bytes_data(request_queue_addr, self._endpoint)
bytes_data = load_bytes_data(request_queue_addr, self._endpoint)
return decode_request_queue(bytes_data)
def load_fills(self, limit=100):
def load_fills(self, limit=100) -> List[FilledOrder]:
event_queue_addr = PublicKey(self._decode.event_queue)
bytes_data = _load_bytes_data(event_queue_addr, self._endpoint)
bytes_data = load_bytes_data(event_queue_addr, self._endpoint)
events = decode_event_queue(bytes_data, limit)
return [
self.parse_fill_event(event)
@ -148,7 +155,7 @@ class Market:
if event.event_flags.fill and event.native_quantity_paid > 0
]
def parse_fill_event(self, event):
def parse_fill_event(self, event) -> FilledOrder:
if event.event_flags.bid:
side = Side.Buy
price_before_fees = (
@ -185,6 +192,56 @@ class Market:
def find_open_orders_accounts_for_owner(self, owner_address: PublicKey):
pass
def cancel_order_by_client_id(self, owner: str) -> str:
pass
def cancel_order(self, owner: Account, order: Order) -> str:
transaction = Transaction()
transaction.add(self.make_cancel_order_instruction(owner.public_key(), order))
return self._send_transaction(transaction, owner)
def match_orders(self, fee_payer: Account, limit: int) -> str:
transaction = Transaction()
transaction.add(self.make_match_orders_instruction(limit))
return self._send_transaction(transaction, fee_payer)
def make_cancel_order_instruction(self, owner: PublicKey, order: Order) -> TransactionInstruction:
params = CancelOrderParams(
market=self.address(),
owner=owner,
open_orders=order.open_order_address,
request_queue=self._decode.request_queue,
side=order.side,
order_id=order.order_id,
open_orders_slot=order.open_order_slot,
program_id=self._program_id,
)
return cancel_order_inst(params)
def make_match_orders_instruction(self, limit: int) -> TransactionInstruction:
params = MatchOrdersParams(
market=self.address(),
request_queue=PublicKey(self._decode.request_queue),
event_queue=PublicKey(self._decode.event_queue),
bids=PublicKey(self._decode.bids),
asks=PublicKey(self._decode.asks),
base_vault=PublicKey(self._decode.base_vault),
quote_vault=PublicKey(self._decode.quote_vault),
limit=limit,
program_id=self._program_id,
)
return match_order_inst(params)
def _send_transaction(self, transaction: Transaction, *signers: Account) -> str:
connection = Client(self._endpoint)
res = connection.send_transaction(transaction, *signers, skip_preflight=self._skip_preflight)
if self._confirmations > 0:
self.logger.warning("Cannot confirm transaction yet.")
signature = res.get("result")
if not signature:
raise Exception("Transaction not sent successfully.")
return str(signature)
class FilledOrder(NamedTuple):
order_id: int
@ -205,9 +262,10 @@ class Order(NamedTuple):
order_id: int
client_id: int
open_order_address: PublicKey
open_order_slot: int
fee_tier: int
order_info: OrderInfo
side: str
side: Side
# The key is constructed as the (price << 64) + (seq_no if ask_order else !seq_no)
@ -281,5 +339,6 @@ class OrderBook:
size=self._market.base_size_lots_to_number(node.quantity),
size_lots=node.quantity,
),
side="buy" if self._is_bids else "sell",
side=Side.Buy if self._is_bids else Side.Sell,
open_order_slot=node.owner_slot,
)

View File

@ -6,9 +6,10 @@ from solana.publickey import PublicKey
from solana.rpc.api import Client
from ._layouts.open_orders import OPEN_ORDERS_LAYOUT
from .utils import load_bytes_data
class OpenOrder:
class OpenOrderAccount:
# pylint: disable=too-many-arguments
# pylint: disable=too-many-instance-attributes
def __init__(
@ -38,9 +39,9 @@ class OpenOrder:
self.client_ids = client_ids
@staticmethod
def from_bytes(address: PublicKey, data_bytes: bytes) -> OpenOrder:
def from_bytes(address: PublicKey, data_bytes: bytes) -> OpenOrderAccount:
open_order_decoded = OPEN_ORDERS_LAYOUT.parse(data_bytes)
return OpenOrder(
return OpenOrderAccount(
address=address,
market=PublicKey(open_order_decoded.market),
owner=PublicKey(open_order_decoded.owner),
@ -54,5 +55,12 @@ class OpenOrder:
client_ids=open_order_decoded.client_ids,
)
def find_for_market_and_owner(self, connection: Client, market: PublicKey, owner: PublicKey):
@staticmethod
def find_for_market_and_owner(connection: Client, market: PublicKey, owner: PublicKey):
pass
@staticmethod
def load(endpoint: str, address: str) -> OpenOrderAccount:
addr_pub_key = PublicKey(address)
bytes_data = load_bytes_data(addr_pub_key, endpoint)
return OpenOrderAccount.from_bytes(addr_pub_key, bytes_data)

View File

@ -7,14 +7,18 @@ from ._layouts.queue import EVENT_LAYOUT, QUEUE_HEADER_LAYOUT, REQUEST_LAYOUT
# Expect header_layout and node_layout to be construct layout
def _decode_queue(header_layout: Any, node_layout: Any, buffer: bytes, history: Optional[int]) -> Tuple[Any, Any]:
header = header_layout.parse(buffer)
alloc_len = math.floor(float(len(buffer) - header_layout.sizeof() / node_layout.sizeof()))
alloc_len = math.floor(len(buffer) - header_layout.sizeof() / node_layout.sizeof())
nodes = []
num_of_nodes = min(history, alloc_len) if history else header.count
for i in range(num_of_nodes):
node_index = (header.head + header.count + alloc_len - 1 - i) % alloc_len
nodes.append(
node_layout.parse(buffer[header_layout.sizeof() + node_index * node_layout.sizeof() :]) # noqa: E203
)
if history:
for i in range(min(history, alloc_len)):
node_index = (header.head + header.count + alloc_len - 1 - i) % alloc_len
offset = header_layout.sizeof() + node_index * node_layout.sizeof()
nodes.append(node_layout.parse(buffer[offset : offset + node_layout.sizeof()])) # noqa: E203 # noqa: E203
else:
for i in range(header.count):
node_index = (header.head + i) % alloc_len
offset = header_layout.sizeof() + node_index * node_layout.sizeof()
nodes.append(node_layout.parse(buffer[offset : offset + node_layout.sizeof()])) # noqa: E203 # noqa: E203
return header, nodes

12
src/utils.py Normal file
View File

@ -0,0 +1,12 @@
import base64
from solana.publickey import PublicKey
from solana.rpc.api import Client
def load_bytes_data(addr: PublicKey, endpoint: str):
res = Client(endpoint).get_account_info(addr)
if ("result" not in res) or ("value" not in res["result"]) or ("data" not in res["result"]["value"]):
raise Exception("Cannot load byte data.")
data = res["result"]["value"]["data"][0]
return base64.decodebytes(data.encode("ascii"))

View File

@ -3,6 +3,7 @@ from typing import Dict
import pytest
from solana.account import Account
from solana.publickey import PublicKey
from solana.rpc.api import Client
__cached_params = {}
@ -139,3 +140,11 @@ def stubbed_bid_account_pk(__bs_params) -> PublicKey:
def stubbed_ask_account_pk(__bs_params) -> PublicKey:
"""Public key of the initial ask order account."""
return PublicKey(__bs_params["ask_account"])
@pytest.mark.integration
@pytest.fixture(scope="session")
def http_client() -> Client:
"""Solana http client."""
client = Client()
return client

View File

View File

@ -0,0 +1,79 @@
# pylint: disable=redefined-outer-name
import pytest
from solana.account import Account
from solana.publickey import PublicKey
from solana.rpc.api import Client
from src.market import Market
from .utils import confirm_transaction
@pytest.mark.integration
@pytest.fixture(scope="session")
def bootstrapped_market(stubbed_market_pk: PublicKey, stubbed_dex_program_pk: PublicKey) -> Market:
return Market.load("http://localhost:8899", str(stubbed_market_pk), None, program_id=stubbed_dex_program_pk)
@pytest.mark.integration
def test_bootstrapped_market(
bootstrapped_market: Market,
stubbed_market_pk: PublicKey,
stubbed_dex_program_pk: PublicKey,
stubbed_base_mint: PublicKey,
stubbed_quote_mint: PublicKey,
):
assert isinstance(bootstrapped_market, Market)
assert bootstrapped_market.address() == stubbed_market_pk
assert bootstrapped_market.program_id() == stubbed_dex_program_pk
assert bootstrapped_market.base_mint_address() == stubbed_base_mint.public_key()
assert bootstrapped_market.quote_mint_address() == stubbed_quote_mint.public_key()
@pytest.mark.integration
def test_market_load_bid(bootstrapped_market: Market):
# TODO: test for non-zero order case.
bids = bootstrapped_market.load_bids()
assert sum(1 for _ in bids) == 0
@pytest.mark.integration
def test_market_load_asks(bootstrapped_market: Market):
# TODO: test for non-zero order case.
asks = bootstrapped_market.load_asks()
assert sum(1 for _ in asks) == 0
@pytest.mark.integration
def test_market_load_events(bootstrapped_market: Market):
event_queue = bootstrapped_market.load_event_queue()
assert len(event_queue) == 0
@pytest.mark.integration
def test_market_load_requests(bootstrapped_market: Market):
request_queue = bootstrapped_market.load_request_queue()
# 2 requests in the request queue in the beginning with one bid and one ask
assert len(request_queue) == 2
@pytest.mark.integration
def test_match_order(bootstrapped_market: Market, stubbed_payer: Account, http_client: Client):
sig = bootstrapped_market.match_orders(stubbed_payer, 2)
confirm_transaction(http_client, sig)
request_queue = bootstrapped_market.load_request_queue()
# 0 request after matching.
assert len(request_queue) == 0
event_queue = bootstrapped_market.load_event_queue()
# 5 event after the order is matched, including 2 fill events.
assert len(event_queue) == 5
# There should be no bid order.
bids = bootstrapped_market.load_bids()
assert sum(1 for _ in bids) == 0
# There should be no ask order.
asks = bootstrapped_market.load_asks()
assert sum(1 for _ in asks) == 0

View File

@ -0,0 +1,26 @@
import time
from solana.rpc.api import Client
from solana.rpc.types import RPCResponse
DEFAULT_MAX_TIMEOUT = 30 # 30 seconds pylint: disable=invalid-name
def confirm_transaction(client: Client, tx_sig: str, time_out: int = DEFAULT_MAX_TIMEOUT) -> RPCResponse:
"""Confirm a transaction."""
elapsed_time = 0
while elapsed_time < time_out:
sleep_time = 3
if not elapsed_time:
sleep_time = 7
time.sleep(sleep_time)
else:
time.sleep(sleep_time)
resp = client.get_confirmed_transaction(tx_sig)
if resp.get("result"):
break
elapsed_time += sleep_time
if not resp.get("result"):
raise RuntimeError("could not confirm transaction: ", tx_sig)
return resp