# Source: pyth-crosschain/express_relay/sdk/python/express_relay/client.py
# Listing metadata: 443 lines, 16 KiB, Python

import asyncio
from asyncio import Task
import json
import urllib.parse
from typing import Callable, Any
from collections.abc import Coroutine
from uuid import UUID
import httpx
import web3
import websockets
from websockets.client import WebSocketClientProtocol
from eth_abi import encode
from eth_account.account import Account
from web3.auto import w3
from express_relay.types import (
Opportunity,
BidStatusUpdate,
ClientMessage,
BidStatus,
Bid,
OpportunityBid,
OpportunityParams,
)
class ExpressRelayClientException(Exception):
    """Error raised by the Express Relay client, e.g. for a failed websocket response."""
class ExpressRelayClient:
    """
    Client for an auction server, speaking both HTTP and websocket.

    Websocket requests are correlated with their responses through a
    per-client message counter: every outgoing message gets an "id", and the
    handler loop resolves the future registered under that id when a message
    carrying the same id is received.
    """

    def __init__(
        self,
        server_url: str,
        opportunity_callback: (
            Callable[[Opportunity], Coroutine[Any, Any, Any]] | None
        ) = None,
        bid_status_callback: (
            Callable[[BidStatusUpdate], Coroutine[Any, Any, Any]] | None
        ) = None,
        timeout_response_secs: int = 10,
        ws_options: dict[str, Any] | None = None,
        http_options: dict[str, Any] | None = None,
    ):
        """
        Args:
            server_url: The URL of the auction server.
            opportunity_callback: An async function that serves as the callback on a new opportunity. Should take in one external argument of type Opportunity.
            bid_status_callback: An async function that serves as the callback on a new bid status update. Should take in one external argument of type BidStatusUpdate.
            timeout_response_secs: The number of seconds to wait for a response message from the server.
            ws_options: Keyword arguments to pass to the websocket connection.
            http_options: Keyword arguments to pass to the HTTP client.
        Raises:
            ValueError: If the scheme of server_url is neither "http" nor "https".
        """
        parsed_url = urllib.parse.urlparse(server_url)
        if parsed_url.scheme == "https":
            ws_scheme = "wss"
        elif parsed_url.scheme == "http":
            ws_scheme = "ws"
        else:
            raise ValueError("Invalid server URL")

        self.server_url = server_url
        self.ws_endpoint = parsed_url._replace(scheme=ws_scheme, path="/v1/ws").geturl()
        self.ws_msg_counter = 0
        # `ws` and `ws_loop` are only declared here; start_ws() assigns them
        # lazily and uses hasattr() to tell whether that has happened yet.
        self.ws: WebSocketClientProtocol
        self.ws_lock = asyncio.Lock()
        self.ws_loop: Task[Any]
        # Futures awaiting a websocket response, keyed by message id.
        self.ws_msg_futures: dict[str, asyncio.Future] = {}
        self.timeout_response_secs = timeout_response_secs
        self.ws_options = {} if ws_options is None else ws_options
        self.http_options = {} if http_options is None else http_options
        self.opportunity_callback = opportunity_callback
        self.bid_status_callback = bid_status_callback
        # Strong references to in-flight callback tasks: the event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._callback_tasks: set[Task[Any]] = set()

    def _http_url(self, path: str) -> str:
        """Returns server_url with its path component replaced by `path`."""
        return urllib.parse.urlparse(self.server_url)._replace(path=path).geturl()

    def _spawn_callback(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Schedules a user callback coroutine as a task and retains it until it completes."""
        task = asyncio.create_task(coro)
        self._callback_tasks.add(task)
        task.add_done_callback(self._callback_tasks.discard)

    async def start_ws(self):
        """
        Initializes the websocket connection to the server, if not already connected.
        Also starts the handler loop task if it has not been started yet.
        """
        async with self.ws_lock:
            if not hasattr(self, "ws"):
                self.ws = await websockets.connect(self.ws_endpoint, **self.ws_options)

            if not hasattr(self, "ws_loop"):
                ws_call = self.ws_handler(
                    self.opportunity_callback, self.bid_status_callback
                )
                self.ws_loop = asyncio.create_task(ws_call)

    async def close_ws(self):
        """
        Closes the websocket connection to the server.
        Does nothing if the connection was never opened.
        """
        async with self.ws_lock:
            if hasattr(self, "ws"):
                await self.ws.close()

    async def get_ws_loop(self) -> asyncio.Task:
        """
        Returns the websocket handler loop, starting the websocket first if needed.
        """
        await self.start_ws()
        return self.ws_loop

    def convert_client_msg_to_server(self, client_msg: ClientMessage) -> dict:
        """
        Converts the params of a ClientMessage model dict to the format expected by the server.
        Assigns the next message id from the client's counter and increments the counter.
        Args:
            client_msg: The message to send to the server.
        Returns:
            The message as a dict with the params converted to the format expected by the server.
        """
        msg = client_msg.model_dump()
        flat_params = msg["params"]
        method = flat_params["method"]

        msg["id"] = str(self.ws_msg_counter)
        self.ws_msg_counter += 1

        # Bid methods nest their fields; any other method keeps its params as dumped.
        if method == "post_bid":
            bid_fields = (
                "amount",
                "target_contract",
                "chain_id",
                "target_calldata",
                "permission_key",
            )
            msg["params"] = {"bid": {name: flat_params[name] for name in bid_fields}}
        elif method == "post_opportunity_bid":
            opportunity_bid_fields = (
                "amount",
                "executor",
                "permission_key",
                "signature",
                "valid_until",
            )
            msg["params"] = {
                "opportunity_id": flat_params["opportunity_id"],
                "opportunity_bid": {
                    name: flat_params[name] for name in opportunity_bid_fields
                },
            }

        msg["method"] = method
        return msg

    async def send_ws_msg(self, client_msg: ClientMessage) -> dict:
        """
        Sends a message to the server via websocket and waits for its response.
        Args:
            client_msg: The message to send.
        Returns:
            The result of the response message from the server.
        Raises:
            asyncio.TimeoutError: If no response arrives within timeout_response_secs.
            ExpressRelayClientException: If the response reports a non-success status.
        """
        await self.start_ws()

        msg = self.convert_client_msg_to_server(client_msg)
        msg_id = msg["id"]

        future = asyncio.get_running_loop().create_future()
        self.ws_msg_futures[msg_id] = future
        try:
            await self.ws.send(json.dumps(msg))
            # await the response for the sent ws message from the server
            msg_response = await asyncio.wait_for(
                future, timeout=self.timeout_response_secs
            )
        finally:
            # Always drop the registration: on timeout or send failure the
            # handler never pops it, and the entry would otherwise leak.
            self.ws_msg_futures.pop(msg_id, None)

        return self.process_response_msg(msg_response)

    def process_response_msg(self, msg: dict) -> dict:
        """
        Processes a response message received from the server via websocket.
        Args:
            msg: The message to process.
        Returns:
            The result field of the message.
        Raises:
            ExpressRelayClientException: If the message has a status other than "success".
        """
        status = msg.get("status")
        if status and status != "success":
            raise ExpressRelayClientException(
                f"Error in websocket response with message id {msg.get('id')}: {msg.get('result')}"
            )
        return msg["result"]

    async def subscribe_chains(self, chain_ids: list[str]):
        """
        Subscribes websocket to a list of chain IDs for new opportunities.
        Args:
            chain_ids: A list of chain IDs to subscribe to.
        """
        params = {
            "method": "subscribe",
            "chain_ids": chain_ids,
        }
        client_msg = ClientMessage.model_validate({"params": params})
        await self.send_ws_msg(client_msg)

    async def unsubscribe_chains(self, chain_ids: list[str]):
        """
        Unsubscribes websocket from a list of chain IDs for new opportunities.
        Args:
            chain_ids: A list of chain IDs to unsubscribe from.
        """
        params = {
            "method": "unsubscribe",
            "chain_ids": chain_ids,
        }
        client_msg = ClientMessage.model_validate({"params": params})
        await self.send_ws_msg(client_msg)

    async def submit_bid(self, bid: Bid, subscribe_to_updates: bool = True) -> UUID:
        """
        Submits a bid to the auction server.
        The bid is sent over the websocket when subscribe_to_updates is True,
        and via an HTTP POST to /v1/bids otherwise.
        Args:
            bid: An object representing the bid to submit.
            subscribe_to_updates: A boolean indicating whether to subscribe to the bid status updates.
        Returns:
            The ID of the submitted bid.
        """
        bid_dict = bid.model_dump()
        if subscribe_to_updates:
            bid_dict["method"] = "post_bid"
            client_msg = ClientMessage.model_validate({"params": bid_dict})
            result = await self.send_ws_msg(client_msg)
            return UUID(result.get("id"))

        async with httpx.AsyncClient(**self.http_options) as client:
            resp = await client.post(self._http_url("/v1/bids"), json=bid_dict)

        resp.raise_for_status()
        return UUID(resp.json().get("id"))

    async def submit_opportunity_bid(
        self,
        opportunity_bid: OpportunityBid,
        subscribe_to_updates: bool = True,
    ) -> UUID:
        """
        Submits a bid on an opportunity to the server.
        The bid is sent over the websocket when subscribe_to_updates is True,
        and via an HTTP POST to /v1/opportunities/<opportunity_id>/bids otherwise.
        Args:
            opportunity_bid: An object representing the bid to submit on an opportunity.
            subscribe_to_updates: A boolean indicating whether to subscribe to the bid status updates.
        Returns:
            The ID of the submitted bid.
        """
        if subscribe_to_updates:
            params = {
                "method": "post_opportunity_bid",
                "opportunity_id": opportunity_bid.opportunity_id,
                "amount": opportunity_bid.amount,
                "executor": opportunity_bid.executor,
                "permission_key": opportunity_bid.permission_key,
                "signature": opportunity_bid.signature,
                "valid_until": opportunity_bid.valid_until,
            }
            client_msg = ClientMessage.model_validate({"params": params})
            result = await self.send_ws_msg(client_msg)
            return UUID(result.get("id"))

        # The dumped model is only needed as the HTTP request body.
        opportunity_bid_dict = opportunity_bid.model_dump()
        async with httpx.AsyncClient(**self.http_options) as client:
            resp = await client.post(
                self._http_url(
                    f"/v1/opportunities/{opportunity_bid.opportunity_id}/bids"
                ),
                json=opportunity_bid_dict,
            )

        resp.raise_for_status()
        return UUID(resp.json().get("id"))

    async def ws_handler(
        self,
        opportunity_callback: (
            Callable[[Opportunity], Coroutine[Any, Any, Any]] | None
        ) = None,
        bid_status_callback: (
            Callable[[BidStatusUpdate], Coroutine[Any, Any, Any]] | None
        ) = None,
    ):
        """
        Continually handles new ws messages as they are received from the server via websocket.
        Messages with a "type" are dispatched to the matching callback; messages
        without a type but with an "id" resolve the future of the pending request
        with that id.
        Args:
            opportunity_callback: An async function that serves as the callback on a new opportunity. Should take in one external argument of type Opportunity.
            bid_status_callback: An async function that serves as the callback on a new bid status update. Should take in one external argument of type BidStatusUpdate.
        Raises:
            ExpressRelayClientException: If the websocket attribute is falsy.
        """
        if not self.ws:
            raise ExpressRelayClientException("Websocket not connected")

        async for msg in self.ws:
            msg_json = json.loads(msg)
            msg_type = msg_json.get("type")

            if msg_type:
                if msg_type == "new_opportunity":
                    if opportunity_callback is not None:
                        opportunity = Opportunity.process_opportunity_dict(
                            msg_json["opportunity"]
                        )
                        # A falsy result means there is nothing to hand to the callback.
                        if opportunity:
                            self._spawn_callback(opportunity_callback(opportunity))
                elif msg_type == "bid_status_update":
                    if bid_status_callback is not None:
                        status = msg_json["status"]
                        bid_id = status["id"]
                        bid_status = status["bid_status"]["status"]
                        result = status["bid_status"].get("result")
                        bid_status_update = BidStatusUpdate(
                            id=bid_id, bid_status=BidStatus(bid_status), result=result
                        )
                        self._spawn_callback(bid_status_callback(bid_status_update))
            elif msg_json.get("id"):
                # The requester may already have given up (timeout cancels the
                # future and removes the entry), or the id may be unknown.
                # Neither case may be allowed to terminate the handler loop.
                future = self.ws_msg_futures.pop(msg_json["id"], None)
                if future is not None and not future.done():
                    future.set_result(msg_json)

    async def get_opportunities(self, chain_id: str | None = None) -> list[Opportunity]:
        """
        Connects to the server and fetches opportunities.
        Args:
            chain_id: The chain ID to fetch opportunities for. If None, fetches opportunities across all chains.
        Returns:
            A list of opportunities.
        """
        params: dict[str, str] = {}
        if chain_id:
            params["chain_id"] = chain_id

        async with httpx.AsyncClient(**self.http_options) as client:
            resp = await client.get(self._http_url("/v1/opportunities"), params=params)

        resp.raise_for_status()

        opportunities = []
        for opportunity in resp.json():
            opportunity_processed = Opportunity.process_opportunity_dict(opportunity)
            # Entries for which processing yields a falsy value are skipped.
            if opportunity_processed:
                opportunities.append(opportunity_processed)

        return opportunities

    async def submit_opportunity(self, opportunity: OpportunityParams) -> UUID:
        """
        Submits an opportunity to the server.
        Args:
            opportunity: An object representing the opportunity to submit.
        Returns:
            The ID of the submitted opportunity.
        """
        async with httpx.AsyncClient(**self.http_options) as client:
            resp = await client.post(
                self._http_url("/v1/opportunities"),
                json=opportunity.params.model_dump(),
            )

        resp.raise_for_status()
        return UUID(resp.json()["opportunity_id"])
def sign_bid(
    opportunity: Opportunity,
    bid_amount: int,
    valid_until: int,
    private_key: str,
) -> OpportunityBid:
    """
    Signs a searcher's bid on an opportunity and wraps it in an OpportunityBid ready for submission.

    The signed payload is the keccak hash of the ABI encoding of, in order: the
    sell tokens, the buy tokens, the target contract, the target calldata, the
    target call value, the bid amount and the validity timestamp.

    Args:
        opportunity: An object representing the opportunity, of type Opportunity.
        bid_amount: An integer representing the amount of the bid (in wei).
        valid_until: An integer representing the unix timestamp until which the bid is valid.
        private_key: A 0x-prefixed hex string representing the searcher's private key.
    Returns:
        A OpportunityBid object, representing the transaction to submit to the server. This object contains the searcher's signature.
    """

    def _as_pairs(tokens):
        # (token address, integer amount) tuples, matching "(address,uint256)[]".
        return [(entry.token, int(entry.amount)) for entry in tokens]

    sold = _as_pairs(opportunity.sell_tokens)
    bought = _as_pairs(opportunity.buy_tokens)
    calldata_bytes = bytes.fromhex(opportunity.target_calldata.replace("0x", ""))

    abi_types = [
        "(address,uint256)[]",
        "(address,uint256)[]",
        "address",
        "bytes",
        "uint256",
        "uint256",
        "uint256",
    ]
    abi_values = [
        sold,
        bought,
        opportunity.target_contract,
        calldata_bytes,
        opportunity.target_call_value,
        bid_amount,
        valid_until,
    ]
    encoded_payload = encode(abi_types, abi_values)

    payload_hash = web3.Web3.solidity_keccak(["bytes"], [encoded_payload])
    signed = w3.eth.account.signHash(payload_hash, private_key=private_key)

    # The executor is the address derived from the signing key.
    return OpportunityBid(
        opportunity_id=opportunity.opportunity_id,
        permission_key=opportunity.permission_key,
        amount=bid_amount,
        valid_until=valid_until,
        executor=Account.from_key(private_key).address,
        signature=signed,
    )