diff --git a/mango/client.py b/mango/client.py
index cbec194..8972213 100644
--- a/mango/client.py
+++ b/mango/client.py
@@ -410,6 +410,7 @@ class CompoundRPCCaller(HTTPProvider):
     def __init__(self, providers: typing.Sequence[RPCCaller]):
         self.logger_: logging.Logger = logging.getLogger(self.__class__.__name__)
         self.__providers: typing.Sequence[RPCCaller] = providers
+        self.on_provider_change: typing.Callable[[], None] = lambda: None
 
     @property
     def current(self) -> RPCCaller:
@@ -419,6 +420,19 @@ class CompoundRPCCaller(HTTPProvider):
     def all_providers(self) -> typing.Sequence[RPCCaller]:
         return self.__providers
 
+    def shift_to_next_provider(self) -> None:
+        # This is called when the current provider is raising errors, when the next provider might not.
+        # Typical RPC host errors are trapped and managed via make_request(), but some errors can't be
+        # handled properly there. For example, BlockhashNotFound exceptions can be trapped there, but
+        # the right answer is to switch to the next provider AND THEN fetch a fresh blockhash and retry
+        # the transaction. That's not possible to do atomically (without a lot of nasty, fragile work) so
+        # it's better to handle it at the higher level. That's what this method allows - the higher level
+        # can call this to switch to the next provider, and it can then fetch the fresh blockhash and
+        # resubmit the transaction.
+        if len(self.__providers) > 1:
+            self.__providers = [*self.__providers[1:], *self.__providers[:1]]
+            self.on_provider_change()
+
     def make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:
         last_exception: Exception
         for provider in self.__providers:
@@ -428,12 +442,12 @@ class CompoundRPCCaller(HTTPProvider):
                 if successful_index != 0:
                     # Rebase the providers' list so we continue to use this successful one (until it fails)
                     self.__providers = [*self.__providers[successful_index:], *self.__providers[:successful_index]]
+                    self.on_provider_change()
                 return result
             except (RateLimitException,
                     NodeIsBehindException,
                     StaleSlotException,
-                    FailedToFetchBlockhashException,
-                    BlockhashNotFoundException) as exception:
+                    FailedToFetchBlockhashException) as exception:
                 last_exception = exception
                 self.logger_.info(f"Moving to next provider - {provider} gave {exception}")
 
@@ -502,6 +516,14 @@ class BetterClient:
         client: Client = _MaxRetriesZeroClient(cluster_url, commitment=commitment, blockhash_cache=blockhash_cache)
         client._provider = provider
 
+        def __on_provider_change() -> None:
+            if client.blockhash_cache:
+                # Clear out the blockhash cache on retrying
+                client.blockhash_cache.unused_blockhashes.clear()
+                client.blockhash_cache.used_blockhashes.clear()
+
+        provider.on_provider_change = __on_provider_change
+
         return BetterClient(client, name, cluster_name, commitment, skip_preflight, encoding, blockhash_cache_duration, provider)
 
     @property
@@ -586,28 +608,46 @@ class BetterClient:
         return response["result"]["value"]
 
     def send_transaction(self, transaction: Transaction, *signers: Keypair, opts: TxOpts = TxOpts(preflight_commitment=UnspecifiedCommitment)) -> str:
-        proper_commitment: Commitment = opts.preflight_commitment
-        if proper_commitment == UnspecifiedCommitment:
-            proper_commitment = self.commitment
+        # This method is an exception to the normal exception-handling to fail over to the next RPC provider.
+        #
+        # Normal RPC exceptions just move on to the next RPC provider and try again. That won't work with the
+        # BlockhashNotFoundException, since a stale blockhash will be stale for all providers and it probably
+        # indicates a problem with the current node returning the stale blockhash anyway.
+        #
+        # What we want to do in this situation is: retry the same transaction (which we know for certain failed)
+        # but retry it with the next provider in the list, with a fresh recent_blockhash. (Setting the transaction's
+        # recent_blockhash to None makes the client fetch a fresh one.)
+        last_exception: BlockhashNotFoundException
+        for _ in self.rpc_caller.all_providers:
+            try:
+                proper_commitment: Commitment = opts.preflight_commitment
+                if proper_commitment == UnspecifiedCommitment:
+                    proper_commitment = self.commitment
 
-        proper_opts = TxOpts(preflight_commitment=proper_commitment,
-                             skip_confirmation=opts.skip_confirmation,
-                             skip_preflight=opts.skip_preflight)
+                proper_opts = TxOpts(preflight_commitment=proper_commitment,
+                                     skip_confirmation=opts.skip_confirmation,
+                                     skip_preflight=opts.skip_preflight)
 
-        response = self.compatible_client.send_transaction(transaction, *signers, opts=proper_opts)
-        signature: str = str(response["result"])
+                response = self.compatible_client.send_transaction(transaction, *signers, opts=proper_opts)
+                signature: str = str(response["result"])
 
-        if signature != _STUB_TRANSACTION_SIGNATURE:
-            transaction_status = self.compatible_client.get_signature_statuses([signature])
-            if "result" in transaction_status and "context" in transaction_status["result"] and "slot" in transaction_status["result"]["context"]:
-                slot: int = transaction_status["result"]["context"]["slot"]
-                self.rpc_caller.current.require_data_from_fresh_slot(slot)
-            else:
-                self.logger.error(f"Could not get status for signature {signature}")
-        else:
-            self.logger.error("Could not get status for stub signature")
+                if signature != _STUB_TRANSACTION_SIGNATURE:
+                    transaction_status = self.compatible_client.get_signature_statuses([signature])
+                    if "result" in transaction_status and "context" in transaction_status["result"] and "slot" in transaction_status["result"]["context"]:
+                        slot: int = transaction_status["result"]["context"]["slot"]
+                        self.rpc_caller.current.require_data_from_fresh_slot(slot)
+                    else:
+                        self.logger.error(f"Could not get status for signature {signature}")
+                else:
+                    self.logger.error("Could not get status for stub signature")
 
-        return signature
+                return signature
+            except BlockhashNotFoundException as blockhash_not_found_exception:
+                last_exception = blockhash_not_found_exception
+                transaction.recent_blockhash = None
+                self.rpc_caller.shift_to_next_provider()
+
+        raise last_exception
 
     def wait_for_confirmation(self, transaction_ids: typing.Sequence[str], max_wait_in_seconds: int = 60) -> typing.Sequence[str]:
         self.logger.info(f"Waiting up to {max_wait_in_seconds} seconds for {transaction_ids}.")
diff --git a/tests/test_client.py b/tests/test_client.py
new file mode 100644
index 0000000..62314d4
--- /dev/null
+++ b/tests/test_client.py
@@ -0,0 +1,158 @@
+import pytest
+import typing
+
+from .context import mango
+
+from solana.rpc.types import RPCMethod, RPCResponse
+
+
+__FAKE_RPC_METHOD = RPCMethod("fake")
+
+
+class FakeRPCCaller(mango.RPCCaller):
+    def __init__(self) -> None:
+        super().__init__("Fake", "https://localhost", [0.1, 0.2], mango.SlotHolder(), mango.InstructionReporter())
+        self.called = False
+
+    def make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:
+        self.called = True
+        return {
+            "jsonrpc": "2.0",
+            "id": 0,
+            "result": {}
+        }
+
+
+class RaisingRPCCaller(mango.RPCCaller):
+    def __init__(self) -> None:
+        super().__init__("Fake", "https://localhost", [0.1, 0.2], mango.SlotHolder(), mango.InstructionReporter())
+        self.called = False
+
+    def make_request(self, method: RPCMethod, *params: typing.Any) -> RPCResponse:
+        self.called = True
+        raise mango.TooManyRequestsRateLimitException("Fake", "fake-name", "https://fake")
+
+
+def test_constructor_sets_correct_values() -> None:
+    provider = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider])
+    assert actual is not None
+    assert len(actual.all_providers) == 1
+    assert actual.current == provider
+    assert actual.all_providers[0] == provider
+
+
+def test_constructor_sets_correct_values_with_three_providers() -> None:
+    provider1 = FakeRPCCaller()
+    provider2 = FakeRPCCaller()
+    provider3 = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider1, provider2, provider3])
+    assert actual is not None
+    assert len(actual.all_providers) == 3
+    assert actual.current == provider1
+    assert actual.all_providers[0] == provider1
+    assert actual.all_providers[1] == provider2
+    assert actual.all_providers[2] == provider3
+
+    # Paranoid check to make sure we don't have equality issues
+    assert actual.all_providers[0] != provider2
+    assert actual.all_providers[0] != provider3
+
+
+def test_switching_with_one_provider() -> None:
+    provider = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider])
+
+    assert actual.current == provider
+    actual.shift_to_next_provider()
+    assert actual.current == provider
+
+
+def test_switching_with_three_providers() -> None:
+    provider1 = FakeRPCCaller()
+    provider2 = FakeRPCCaller()
+    provider3 = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider1, provider2, provider3])
+
+    assert actual.current == provider1
+    actual.shift_to_next_provider()
+    assert actual.current == provider2
+
+
+def test_switching_with_three_providers_circular() -> None:
+    provider1 = FakeRPCCaller()
+    provider2 = FakeRPCCaller()
+    provider3 = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider1, provider2, provider3])
+
+    assert actual.current == provider1
+
+    actual.shift_to_next_provider()
+    assert actual.current == provider2
+
+    actual.shift_to_next_provider()
+    assert actual.current == provider3
+
+    actual.shift_to_next_provider()
+    assert actual.current == provider1
+
+
+def test_successful_calling_does_not_call_second_provider() -> None:
+    provider1 = FakeRPCCaller()
+    provider2 = FakeRPCCaller()
+    provider3 = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider1, provider2, provider3])
+
+    assert not provider1.called
+    assert not provider2.called
+    assert not provider3.called
+
+    actual.make_request(__FAKE_RPC_METHOD, "fake")
+
+    assert provider1.called
+    assert not provider2.called
+    assert not provider3.called
+
+
+def test_failed_calling_calls_second_provider() -> None:
+    provider1 = RaisingRPCCaller()
+    provider2 = FakeRPCCaller()
+    provider3 = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider1, provider2, provider3])
+
+    assert not provider1.called
+    assert not provider2.called
+    assert not provider3.called
+
+    actual.make_request(__FAKE_RPC_METHOD, "fake")
+
+    assert provider1.called
+    assert provider2.called
+    assert not provider3.called
+
+
+def test_failed_calling_updates_current_to_second_provider() -> None:
+    provider1 = RaisingRPCCaller()
+    provider2 = FakeRPCCaller()
+    provider3 = FakeRPCCaller()
+    actual = mango.CompoundRPCCaller([provider1, provider2, provider3])
+
+    assert actual.current == provider1
+
+    actual.make_request(__FAKE_RPC_METHOD, "fake")
+
+    assert actual.current == provider2
+
+
+def test_all_failing_raises_exception() -> None:
+    provider1 = RaisingRPCCaller()
+    provider2 = RaisingRPCCaller()
+    provider3 = RaisingRPCCaller()
+    actual = mango.CompoundRPCCaller([provider1, provider2, provider3])
+
+    assert actual.current == provider1
+
+    with pytest.raises(mango.TooManyRequestsRateLimitException):
+        actual.make_request(__FAKE_RPC_METHOD, "fake")
+
+    assert actual.current == provider1