package rpc

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/asymmetric-research/solana_exporter/pkg/slog"
	"go.uber.org/zap"
)

// MockOpt selects which piece of MockServer state SetOpt updates.
type MockOpt int

const (
	// BalanceOpt sets a getBalance result: key is the address (string), value the balance (int).
	BalanceOpt MockOpt = iota
	// InflationRewardsOpt sets a getInflationReward amount: key is the address (string), value the amount (int).
	InflationRewardsOpt
	// EasyResultsOpt sets a canned result for any method: key is the method name (string), value is the result.
	EasyResultsOpt
	// SlotInfosOpt sets per-slot data for getBlock/getBlockProduction: key is the slot (int), value a MockSlotInfo.
	SlotInfosOpt
	// ValidatorInfoOpt sets per-validator data for getVoteAccounts/getBlockProduction: key is the node key (string), value a MockValidatorInfo.
	ValidatorInfoOpt
)

type (
	// MockServer represents a mock Solana RPC server for testing. It answers
	// JSON-RPC POST requests on a loopback listener from the in-memory maps
	// below, which are guarded by mu.
	MockServer struct {
		server   *http.Server
		listener net.Listener
		mu       sync.RWMutex
		logger   *zap.SugaredLogger

		balances         map[string]int
		inflationRewards map[string]int
		easyResults      map[string]any
		slotInfos        map[int]MockSlotInfo
		validatorInfos   map[string]MockValidatorInfo
	}

	// MockBlockInfo describes a produced block. Fee is reported as the
	// leader's "fee" reward by getBlock; each element of Transactions is the
	// list of account keys of one transaction.
	MockBlockInfo struct {
		Fee          int
		Transactions [][]string
	}

	// MockSlotInfo describes a slot: Leader is the node key of the slot
	// leader, and a nil Block marks the slot as skipped.
	MockSlotInfo struct {
		Leader string
		Block  *MockBlockInfo
	}

	// MockValidatorInfo describes a validator as reported by getVoteAccounts.
	// MockServer stores these keyed by node key.
	MockValidatorInfo struct {
		Votekey    string
		Stake      int
		LastVote   int
		Delinquent bool
	}
)

// NewMockServer creates a mock server listening on a random loopback port and
// starts serving it in a background goroutine. The given maps (any of which
// may be nil) become the server's initial state; they are stored, not copied,
// so callers should change them afterwards only through SetOpt.
func NewMockServer(
	easyResults map[string]any,
	balances map[string]int,
	inflationRewards map[string]int,
	slotInfos map[int]MockSlotInfo,
	validatorInfos map[string]MockValidatorInfo,
) (*MockServer, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("failed to create listener: %w", err)
	}

	ms := &MockServer{
		listener:         listener,
		logger:           slog.Get(),
		easyResults:      easyResults,
		balances:         balances,
		inflationRewards: inflationRewards,
		slotInfos:        slotInfos,
		validatorInfos:   validatorInfos,
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/", ms.handleRPCRequest)
	ms.server = &http.Server{Handler: mux, ReadHeaderTimeout: 5 * time.Second}

	go func() {
		// Serve always returns a non-nil error; http.ErrServerClosed is the
		// expected outcome of Close and is not worth reporting.
		if err := ms.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			ms.logger.Errorf("mock server stopped unexpectedly: %v", err)
		}
	}()

	return ms, nil
}

// URL returns the base URL of the mock server.
func (s *MockServer) URL() string {
	return fmt.Sprintf("http://%s", s.listener.Addr().String())
}

// Close gracefully shuts down the mock server, waiting at most 5 seconds for
// in-flight requests to finish.
func (s *MockServer) Close() error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return s.server.Shutdown(ctx)
}

// MustClose is like Close but panics if shutdown fails.
func (s *MockServer) MustClose() {
	if err := s.Close(); err != nil {
		panic(err)
	}
}

// SetOpt sets one entry of the mock state selected by opt, creating the
// underlying map if it is nil. key and value must have the dynamic types
// documented on the MockOpt constants, otherwise the type assertion panics.
// Unknown opts are ignored.
func (s *MockServer) SetOpt(opt MockOpt, key any, value any) {
	s.mu.Lock()
	defer s.mu.Unlock()

	switch opt {
	case BalanceOpt:
		if s.balances == nil {
			s.balances = make(map[string]int)
		}
		s.balances[key.(string)] = value.(int)
	case InflationRewardsOpt:
		if s.inflationRewards == nil {
			s.inflationRewards = make(map[string]int)
		}
		s.inflationRewards[key.(string)] = value.(int)
	case EasyResultsOpt:
		if s.easyResults == nil {
			s.easyResults = make(map[string]any)
		}
		s.easyResults[key.(string)] = value
	case SlotInfosOpt:
		if s.slotInfos == nil {
			s.slotInfos = make(map[int]MockSlotInfo)
		}
		s.slotInfos[key.(int)] = value.(MockSlotInfo)
	case ValidatorInfoOpt:
		if s.validatorInfos == nil {
			s.validatorInfos = make(map[string]MockValidatorInfo)
		}
		s.validatorInfos[key.(string)] = value.(MockValidatorInfo)
	}
}

// GetValidatorInfo returns the stored info for nodekey, or the zero
// MockValidatorInfo if there is none.
func (s *MockServer) GetValidatorInfo(nodekey string) MockValidatorInfo {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.validatorInfos[nodekey]
}

// getResult computes the JSON-RPC result for method while holding the read
// lock. Methods with dedicated simulation are only simulated when their
// backing map is non-nil; otherwise, and for every other method, the canned
// value from easyResults is returned, or a "Method not found" error.
func (s *MockServer) getResult(method string, params ...any) (any, *RPCError) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	switch {
	case method == "getBalance" && s.balances != nil:
		return s.balanceResult(params), nil
	case method == "getInflationReward" && s.inflationRewards != nil:
		return s.inflationRewardResult(params), nil
	case method == "getBlock" && s.slotInfos != nil:
		return s.blockResult(params)
	case method == "getBlockProduction" && s.slotInfos != nil:
		return s.blockProductionResult(params), nil
	case method == "getVoteAccounts" && s.validatorInfos != nil:
		return s.voteAccountsResult(), nil
	}

	// default is to use easy results:
	result, ok := s.easyResults[method]
	if !ok {
		return nil, &RPCError{Code: -32601, Message: "Method not found"}
	}
	return result, nil
}

// balanceResult simulates getBalance; params[0] is the address. Unknown
// addresses report a balance of 0. Callers must hold s.mu.
func (s *MockServer) balanceResult(params []any) map[string]any {
	address := params[0].(string)
	return map[string]any{
		"context": map[string]int{"slot": 1},
		"value":   s.balances[address],
	}
}

// inflationRewardResult simulates getInflationReward; params[0] is the list of
// addresses and params[1] a config object carrying "epoch". Callers must hold
// s.mu.
func (s *MockServer) inflationRewardResult(params []any) []map[string]int {
	addresses := params[0].([]any)
	config := params[1].(map[string]any)
	// encoding/json decodes JSON numbers into float64 when the target is any.
	epoch := int(config["epoch"].(float64))

	rewards := make([]map[string]int, len(addresses))
	for i, item := range addresses {
		address := item.(string)
		// TODO: make inflation rewards fetchable by epoch
		rewards[i] = map[string]int{"amount": s.inflationRewards[address], "epoch": epoch}
	}
	return rewards
}

// blockResult simulates getBlock; params[0] is the slot and params[1] a config
// object with "transactionDetails" and "rewards". A slot with no MockSlotInfo
// yields BlockCleanedUpCode and a slot with a nil Block yields
// SlotSkippedCode. Callers must hold s.mu.
func (s *MockServer) blockResult(params []any) (any, *RPCError) {
	slot := int(params[0].(float64))
	config := params[1].(map[string]any)
	transactionDetails, rewardsIncluded := config["transactionDetails"].(string), config["rewards"].(bool)

	slotInfo, ok := s.slotInfos[slot]
	if !ok {
		s.logger.Warnf("no slot info for slot %d", slot)
		return nil, &RPCError{Code: BlockCleanedUpCode, Message: "Block cleaned up."}
	}
	if slotInfo.Block == nil {
		return nil, &RPCError{Code: SlotSkippedCode, Message: "Slot skipped."}
	}

	var (
		transactions []map[string]any
		rewards      []map[string]any
	)
	if transactionDetails == "full" {
		for _, tx := range slotInfo.Block.Transactions {
			transactions = append(transactions, map[string]any{
				"transaction": map[string]map[string][]string{"message": {"accountKeys": tx}},
			})
		}
	}
	if rewardsIncluded {
		rewards = append(rewards, map[string]any{
			"pubkey": slotInfo.Leader, "lamports": slotInfo.Block.Fee, "rewardType": "fee",
		})
	}
	return map[string]any{"rewards": rewards, "transactions": transactions}, nil
}

// blockProductionResult simulates getBlockProduction over the inclusive slot
// range params[0]["range"]. byIdentity maps each leader's node key to
// [leader slots, blocks produced]; every known validator gets an entry even if
// it led no slot in the range. Callers must hold s.mu.
func (s *MockServer) blockProductionResult(params []any) map[string]any {
	config := params[0].(map[string]any)
	slotRange := config["range"].(map[string]any)
	firstSlot, lastSlot := int(slotRange["firstSlot"].(float64)), int(slotRange["lastSlot"].(float64))

	byIdentity := make(map[string][]int, len(s.validatorInfos))
	for nodekey := range s.validatorInfos {
		byIdentity[nodekey] = []int{0, 0}
	}
	for slot := firstSlot; slot <= lastSlot; slot++ {
		info, ok := s.slotInfos[slot]
		if !ok {
			// Without slot info the leader is unknown, so there is nothing to
			// attribute this slot to.
			continue
		}
		production, ok := byIdentity[info.Leader]
		if !ok {
			// The leader need not be a registered validator; start its
			// counters instead of indexing into a nil slice.
			production = []int{0, 0}
			byIdentity[info.Leader] = production
		}
		production[0]++
		if info.Block != nil {
			production[1]++
		}
	}

	return map[string]any{
		"context": map[string]int{"slot": 1},
		"value":   map[string]any{"byIdentity": byIdentity, "range": slotRange},
	}
}

// voteAccountsResult simulates getVoteAccounts, splitting validators into
// "current" and "delinquent" by their Delinquent flag. Validators are emitted
// in node-key order so responses are deterministic. Callers must hold s.mu.
func (s *MockServer) voteAccountsResult() map[string][]map[string]any {
	nodekeys := make([]string, 0, len(s.validatorInfos))
	for nodekey := range s.validatorInfos {
		nodekeys = append(nodekeys, nodekey)
	}
	sort.Strings(nodekeys)

	var currentVoteAccounts, delinquentVoteAccounts []map[string]any
	for _, nodekey := range nodekeys {
		info := s.validatorInfos[nodekey]
		voteAccount := map[string]any{
			"activatedStake": int64(info.Stake),
			"lastVote":       info.LastVote,
			"nodePubkey":     nodekey,
			"rootSlot":       0,
			"votePubkey":     info.Votekey,
		}
		if info.Delinquent {
			delinquentVoteAccounts = append(delinquentVoteAccounts, voteAccount)
		} else {
			currentVoteAccounts = append(currentVoteAccounts, voteAccount)
		}
	}
	return map[string][]map[string]any{
		"current":    currentVoteAccounts,
		"delinquent": delinquentVoteAccounts,
	}
}

// handleRPCRequest serves one JSON-RPC request: it decodes the POST body,
// computes the result via getResult and writes a JSON-RPC 2.0 response that
// carries either the result or the RPC error.
func (s *MockServer) handleRPCRequest(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Only POST method is allowed", http.StatusMethodNotAllowed)
		return
	}

	var request Request
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	response := Response[any]{Jsonrpc: "2.0", Id: request.Id}
	result, rpcErr := s.getResult(request.Method, request.Params...)
	if rpcErr != nil {
		response.Error = *rpcErr
	} else {
		response.Result = result
	}

	// Marshal before writing anything: once the body has started the status
	// code is fixed, so an encoding failure could no longer be reported as 500.
	body, err := json.Marshal(response)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Keep the trailing newline that json.Encoder would have written.
	body = append(body, '\n')

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(body); err != nil {
		s.logger.Warnf("failed to write mock RPC response: %v", err)
	}
}

// NewMockClient creates a mock server seeded with the given state, registers
// its shutdown with t.Cleanup, and returns it together with a client created
// by NewRPCClient for the server's URL (second argument time.Second —
// presumably the request timeout).
func NewMockClient(
	t *testing.T,
	easyResults map[string]any,
	balances map[string]int,
	inflationRewards map[string]int,
	slotInfos map[int]MockSlotInfo,
	validatorInfos map[string]MockValidatorInfo,
) (*MockServer, *Client) {
	t.Helper()

	server, err := NewMockServer(easyResults, balances, inflationRewards, slotInfos, validatorInfos)
	if err != nil {
		t.Fatalf("failed to create mock server: %v", err)
	}

	t.Cleanup(func() {
		if err := server.Close(); err != nil {
			t.Errorf("failed to close mock server: %v", err)
		}
	})

	client := NewRPCClient(server.URL(), time.Second)
	return server, client
}