refactored getResponse to pure function

This commit is contained in:
Matt Johnstone 2024-10-08 17:16:10 +02:00
parent c65c347504
commit dd6f72c822
No known key found for this signature in database
GPG Key ID: BE985FBB9BE7D3BB
2 changed files with 23 additions and 30 deletions

View File

@ -16,11 +16,6 @@ type (
rpcAddr string
}
rpcError struct {
Message string `json:"message"`
Code int64 `json:"code"`
}
rpcRequest struct {
Version string `json:"jsonrpc"`
ID int `json:"id"`
@ -99,7 +94,9 @@ func NewRPCClient(rpcAddr string) *Client {
return &Client{httpClient: http.Client{}, rpcAddr: rpcAddr}
}
func (c *Client) getResponse(ctx context.Context, method string, params []any, result HasRPCError) error {
func getResponse[T any](
ctx context.Context, httpClient http.Client, url string, method string, params []any, rpcResponse *response[T],
) error {
// format request:
request := &rpcRequest{Version: "2.0", ID: 1, Method: method, Params: params}
buffer, err := json.Marshal(request)
@ -109,13 +106,13 @@ func (c *Client) getResponse(ctx context.Context, method string, params []any, r
klog.V(2).Infof("jsonrpc request: %s", string(buffer))
// make request:
req, err := http.NewRequestWithContext(ctx, "POST", c.rpcAddr, bytes.NewBuffer(buffer))
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(buffer))
if err != nil {
klog.Fatalf("failed to create request: %v", err)
}
req.Header.Set("content-type", "application/json")
resp, err := c.httpClient.Do(req)
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("%s RPC call failed: %w", method, err)
}
@ -130,15 +127,14 @@ func (c *Client) getResponse(ctx context.Context, method string, params []any, r
klog.V(2).Infof("%s response: %v", method, string(body))
// unmarshal the response into the predicted format
if err = json.Unmarshal(body, result); err != nil {
if err = json.Unmarshal(body, rpcResponse); err != nil {
return fmt.Errorf("failed to decode %s response body: %w", method, err)
}
// last error check:
if result.getError().Code != 0 {
return fmt.Errorf("RPC error: %d %v", result.getError().Code, result.getError().Message)
if rpcResponse.Error.Code != 0 {
return fmt.Errorf("RPC error: %d %v", rpcResponse.Error.Code, rpcResponse.Error.Message)
}
return nil
}
@ -146,7 +142,7 @@ func (c *Client) getResponse(ctx context.Context, method string, params []any, r
// See API docs: https://solana.com/docs/rpc/http/getepochinfo
func (c *Client) GetEpochInfo(ctx context.Context, commitment Commitment) (*EpochInfo, error) {
var resp response[EpochInfo]
if err := c.getResponse(ctx, "getEpochInfo", []any{commitment}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getEpochInfo", []any{commitment}, &resp); err != nil {
return nil, err
}
return &resp.Result, nil
@ -164,7 +160,7 @@ func (c *Client) GetVoteAccounts(
}
var resp response[VoteAccounts]
if err := c.getResponse(ctx, "getVoteAccounts", []any{config}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getVoteAccounts", []any{config}, &resp); err != nil {
return nil, err
}
return &resp.Result, nil
@ -176,7 +172,7 @@ func (c *Client) GetVersion(ctx context.Context) (string, error) {
var resp response[struct {
Version string `json:"solana-core"`
}]
if err := c.getResponse(ctx, "getVersion", []any{}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getVersion", []any{}, &resp); err != nil {
return "", err
}
return resp.Result.Version, nil
@ -187,7 +183,7 @@ func (c *Client) GetVersion(ctx context.Context) (string, error) {
func (c *Client) GetSlot(ctx context.Context, commitment Commitment) (int64, error) {
config := map[string]string{"commitment": string(commitment)}
var resp response[int64]
if err := c.getResponse(ctx, "getSlot", []any{config}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getSlot", []any{config}, &resp); err != nil {
return 0, err
}
return resp.Result, nil
@ -223,7 +219,7 @@ func (c *Client) GetBlockProduction(
// make request:
var resp response[contextualResult[BlockProduction]]
if err := c.getResponse(ctx, "getBlockProduction", []any{config}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getBlockProduction", []any{config}, &resp); err != nil {
return nil, err
}
return &resp.Result.Value, nil
@ -234,7 +230,7 @@ func (c *Client) GetBlockProduction(
func (c *Client) GetBalance(ctx context.Context, commitment Commitment, address string) (float64, error) {
config := map[string]string{"commitment": string(commitment)}
var resp response[contextualResult[int64]]
if err := c.getResponse(ctx, "getBalance", []any{address, config}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getBalance", []any{address, config}, &resp); err != nil {
return 0, err
}
return float64(resp.Result.Value) / float64(LamportsInSol), nil
@ -255,7 +251,7 @@ func (c *Client) GetInflationReward(
}
var resp response[[]InflationReward]
if err := c.getResponse(ctx, "getInflationReward", []any{addresses, config}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getInflationReward", []any{addresses, config}, &resp); err != nil {
return nil, err
}
return resp.Result, nil
@ -266,7 +262,7 @@ func (c *Client) GetInflationReward(
func (c *Client) GetLeaderSchedule(ctx context.Context, commitment Commitment, slot int64) (map[string][]int64, error) {
config := map[string]any{"commitment": string(commitment)}
var resp response[map[string][]int64]
if err := c.getResponse(ctx, "getLeaderSchedule", []any{slot, config}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getLeaderSchedule", []any{slot, config}, &resp); err != nil {
return nil, err
}
return resp.Result, nil
@ -285,7 +281,7 @@ func (c *Client) GetBlock(ctx context.Context, commitment Commitment, slot int64
"rewards": true, // what we here for!
}
var resp response[Block]
if err := c.getResponse(ctx, "getBlock", []any{slot, config}, &resp); err != nil {
if err := getResponse(ctx, c.httpClient, c.rpcAddr, "getBlock", []any{slot, config}, &resp); err != nil {
return nil, err
}
return &resp.Result, nil

View File

@ -6,10 +6,15 @@ import (
)
type (
RPCError struct {
Message string `json:"message"`
Code int64 `json:"code"`
}
response[T any] struct {
jsonrpc string
Result T `json:"result"`
Error rpcError `json:"error"`
Error RPCError `json:"error"`
Id int `json:"id"`
}
@ -104,11 +109,3 @@ func (hp *HostProduction) UnmarshalJSON(data []byte) error {
hp.BlocksProduced = arr[1]
return nil
}
func (r response[T]) getError() rpcError {
return r.Error
}
type HasRPCError interface {
getError() rpcError
}