diff --git a/client/src/nonblocking/pubsub_client.rs b/client/src/nonblocking/pubsub_client.rs index 56a7817153..4f252e4fd8 100644 --- a/client/src/nonblocking/pubsub_client.rs +++ b/client/src/nonblocking/pubsub_client.rs @@ -6,9 +6,10 @@ use { RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig, RpcTransactionLogsFilter, }, + rpc_filter::maybe_map_filters, rpc_response::{ Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse, - RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate, + RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate, }, }, futures_util::{ @@ -25,7 +26,7 @@ use { thiserror::Error, tokio::{ net::TcpStream, - sync::{mpsc, oneshot}, + sync::{mpsc, oneshot, RwLock}, task::JoinHandle, time::{sleep, Duration}, }, @@ -62,6 +63,9 @@ pub enum PubsubClientError { #[error("subscribe failed: {reason}")] SubscribeFailed { reason: String, message: String }, + + #[error("request failed: {reason}")] + RequestFailed { reason: String, message: String }, } type UnsubscribeFn = Box BoxFuture<'static, ()> + Send>; @@ -69,11 +73,18 @@ type SubscribeResponseMsg = Result<(mpsc::UnboundedReceiver, UnsubscribeFn), PubsubClientError>; type SubscribeRequestMsg = (String, Value, oneshot::Sender); type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>; +type RequestMsg = ( + String, + Value, + oneshot::Sender>, +); #[derive(Debug)] pub struct PubsubClient { subscribe_tx: mpsc::UnboundedSender, + request_tx: mpsc::UnboundedSender, shutdown_tx: oneshot::Sender<()>, + node_version: RwLock>, ws: JoinHandle, } @@ -85,12 +96,20 @@ impl PubsubClient { .map_err(PubsubClientError::ConnectionError)?; let (subscribe_tx, subscribe_rx) = mpsc::unbounded_channel(); + let (request_tx, request_rx) = mpsc::unbounded_channel(); let (shutdown_tx, shutdown_rx) = oneshot::channel(); Ok(Self { subscribe_tx, + request_tx, shutdown_tx, - ws: tokio::spawn(PubsubClient::run_ws(ws, subscribe_rx, shutdown_rx)), + node_version: RwLock::new(None), + ws: tokio::spawn(PubsubClient::run_ws( + ws, + subscribe_rx, + request_rx, + shutdown_rx, + )), }) } @@ -99,6 +118,37 @@ impl PubsubClient { self.ws.await.unwrap() // WS future should not be cancelled or panicked } + async fn get_node_version(&self) -> PubsubClientResult { + let r_node_version = self.node_version.read().await; + if let Some(version) = &*r_node_version { + Ok(version.clone()) + } else { + drop(r_node_version); + let mut w_node_version = self.node_version.write().await; + let node_version = self.get_version().await?; + *w_node_version = Some(node_version.clone()); + Ok(node_version) + } + } + + async fn get_version(&self) -> PubsubClientResult { + let (response_tx, response_rx) = oneshot::channel(); + self.request_tx + .send(("getVersion".to_string(), Value::Null, response_tx)) + .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?; + let result = response_rx + .await + .map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??; + let node_version: RpcVersionInfo = serde_json::from_value(result)?; + let node_version = semver::Version::parse(&node_version.solana_core).map_err(|e| { + PubsubClientError::RequestFailed { + reason: format!("failed to parse cluster version: {}", e), + message: "getVersion".to_string(), + } + })?; + Ok(node_version) + } + async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T> where T: DeserializeOwned + Send + 'a, @@ -147,8 +197,22 @@ impl PubsubClient { pub async fn program_subscribe( &self, pubkey: &Pubkey, - config: Option, + mut config: Option, ) -> SubscribeResult<'_, RpcResponse> { + if let Some(ref mut config) = config { + if let Some(ref mut filters) = config.filters { + let node_version = self.get_node_version().await.ok(); + // If node does not support the pubsub `getVersion` method, assume version is old + // and filters should be mapped (node_version.is_none()). + maybe_map_filters(node_version, filters).map_err(|e| { + PubsubClientError::RequestFailed { + reason: e, + message: "maybe_map_filters".to_string(), + } + })?; + } + } + let params = json!([pubkey.to_string(), config]); self.subscribe("program", params).await } @@ -181,12 +245,14 @@ impl PubsubClient { async fn run_ws( mut ws: WebSocketStream>, mut subscribe_rx: mpsc::UnboundedReceiver, + mut request_rx: mpsc::UnboundedReceiver, mut shutdown_rx: oneshot::Receiver<()>, ) -> PubsubClientResult { let mut request_id: u64 = 0; let mut requests_subscribe = BTreeMap::new(); let mut requests_unsubscribe = BTreeMap::>::new(); + let mut other_requests = BTreeMap::new(); let mut subscriptions = BTreeMap::new(); let (unsubscribe_tx, mut unsubscribe_rx) = mpsc::unbounded_channel(); @@ -220,6 +286,13 @@ impl PubsubClient { ws.send(Message::Text(text)).await?; requests_unsubscribe.insert(request_id, response_tx); }, + // Read message for other requests + Some((method, params, response_tx)) = request_rx.recv() => { + request_id += 1; + let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string(); + ws.send(Message::Text(text)).await?; + other_requests.insert(request_id, response_tx); + } // Read incoming WebSocket message next_msg = ws.next() => { let msg = match next_msg { @@ -264,7 +337,21 @@ impl PubsubClient { } }); - if let Some(response_tx) = requests_unsubscribe.remove(&id) { + if let Some(response_tx) = other_requests.remove(&id) { + match err { + Some(reason) => { + let _ = response_tx.send(Err(PubsubClientError::RequestFailed { reason, message: text.clone()})); + }, + None => { + let json_result = json.get("result").ok_or_else(|| { + PubsubClientError::RequestFailed { reason: "missing `result` field".into(), message: text.clone() } + })?; + if response_tx.send(Ok(json_result.clone())).is_err() { + break; + } + } + } + } else if let Some(response_tx) = requests_unsubscribe.remove(&id) { let _ = response_tx.send(()); // do not care if receiver is closed } else if let Some((operation, response_tx)) = requests_subscribe.remove(&id) { match err { diff --git a/client/src/nonblocking/rpc_client.rs b/client/src/nonblocking/rpc_client.rs index e855addaab..e4afd7dae9 100644 --- a/client/src/nonblocking/rpc_client.rs +++ b/client/src/nonblocking/rpc_client.rs @@ -19,7 +19,7 @@ use { mock_sender::MockSender, rpc_client::{GetConfirmedSignaturesForAddress2Config, RpcClientConfig}, rpc_config::{RpcAccountInfoConfig, *}, - rpc_filter::{MemcmpEncodedBytes, RpcFilterType}, + rpc_filter::{self, RpcFilterType}, rpc_request::{RpcError, RpcRequest, RpcResponseErrorData, TokenAccountsFilter}, rpc_response::*, rpc_sender::*, @@ -587,24 +587,8 @@ impl RpcClient { mut filters: Vec, ) -> Result, RpcError> { let node_version = self.get_node_version().await?; - if node_version < semver::Version::new(1, 11, 2) { - for filter in filters.iter_mut() { - if let RpcFilterType::Memcmp(memcmp) = filter { - match &memcmp.bytes { - MemcmpEncodedBytes::Base58(string) => { - memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone()); - } - MemcmpEncodedBytes::Base64(_) => { - return Err(RpcError::RpcRequestError(format!( - "RPC node on old version {} does not support base64 encoding for memcmp filters", - node_version - ))); - } - _ => {} - } - } - } - } + rpc_filter::maybe_map_filters(Some(node_version), &mut filters) + .map_err(RpcError::RpcRequestError)?; Ok(filters) } diff --git a/client/src/pubsub_client.rs b/client/src/pubsub_client.rs index 22d5182ae4..22bf49b6ec 100644 --- a/client/src/pubsub_client.rs +++ b/client/src/pubsub_client.rs @@ -5,6 +5,7 @@ use { RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig, RpcTransactionLogsFilter, }, + rpc_filter, rpc_response::{ Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse, RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate, @@ -48,6 +49,9 @@ pub enum PubsubClientError { #[error("unexpected message format: {0}")] UnexpectedMessageError(String), + + #[error("request error: {0}")] + RequestError(String), } pub struct PubsubClientSubscription @@ -123,6 +127,43 @@ where .map_err(|err| err.into()) } + fn get_version( + writable_socket: &Arc>>>, + ) -> Result { + writable_socket + .write() + .unwrap() + .write_message(Message::Text( + json!({ + "jsonrpc":"2.0","id":1,"method":"getVersion", + }) + .to_string(), + ))?; + let message = writable_socket.write().unwrap().read_message()?; + let message_text = &message.into_text()?; + let json_msg: Map = serde_json::from_str(message_text)?; + + if let Some(Object(version_map)) = json_msg.get("result") { + if let Some(node_version) = version_map.get("solana-core") { + let node_version = semver::Version::parse( + node_version.as_str().unwrap_or_default(), + ) + .map_err(|e| { + PubsubClientError::RequestError(format!( + "failed to parse cluster version: {}", + e + )) + })?; + return Ok(node_version); + } + } + // TODO: Add proper JSON RPC response/error handling... + Err(PubsubClientError::UnexpectedMessageError(format!( + "{:?}", + json_msg + ))) + } + fn read_message( writable_socket: &Arc>>>, ) -> Result { @@ -357,7 +398,7 @@ impl PubsubClient { pub fn program_subscribe( url: &str, pubkey: &Pubkey, - config: Option, + mut config: Option, ) -> Result { let url = Url::parse(url)?; let socket = connect_with_retry(url)?; @@ -367,6 +408,17 @@ impl PubsubClient { let socket_clone = socket.clone(); let exit = Arc::new(AtomicBool::new(false)); let exit_clone = exit.clone(); + + if let Some(ref mut config) = config { + if let Some(ref mut filters) = config.filters { + let node_version = PubsubProgramClientSubscription::get_version(&socket_clone).ok(); + // If node does not support the pubsub `getVersion` method, assume version is old + // and filters should be mapped (node_version.is_none()). + rpc_filter::maybe_map_filters(node_version, filters) + .map_err(PubsubClientError::RequestError)?; + } + } + let body = json!({ "jsonrpc":"2.0", "id":1, diff --git a/client/src/rpc_filter.rs b/client/src/rpc_filter.rs index 483fba8028..1f6548c80a 100644 --- a/client/src/rpc_filter.rs +++ b/client/src/rpc_filter.rs @@ -259,6 +259,30 @@ impl From for Memcmp { } } +pub(crate) fn maybe_map_filters( + node_version: Option, + filters: &mut [RpcFilterType], +) -> Result<(), String> { + if node_version.is_none() || node_version.unwrap() < semver::Version::new(1, 11, 2) { + for filter in filters.iter_mut() { + if let RpcFilterType::Memcmp(memcmp) = filter { + match &memcmp.bytes { + MemcmpEncodedBytes::Base58(string) => { + memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone()); + } + MemcmpEncodedBytes::Base64(_) => { + return Err("RPC node on old version does not support base64 \ + encoding for memcmp filters" + .to_string()); + } + _ => {} + } + } + } + } + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rpc/src/rpc_pubsub.rs b/rpc/src/rpc_pubsub.rs index 0b14f641cb..162a8a06ff 100644 --- a/rpc/src/rpc_pubsub.rs +++ b/rpc/src/rpc_pubsub.rs @@ -24,7 +24,7 @@ use { }, rpc_response::{ Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse, - RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate, + RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate, }, }, solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature}, @@ -348,6 +348,10 @@ mod internal { // Unsubscribe from slot notification subscription. #[rpc(name = "rootUnsubscribe")] fn root_unsubscribe(&self, id: SubscriptionId) -> Result; + + // Get the current solana version running on the node + #[rpc(name = "getVersion")] + fn get_version(&self) -> Result; } } @@ -576,6 +580,14 @@ impl RpcSolPubSubInternal for RpcSolPubSubImpl { fn root_unsubscribe(&self, id: SubscriptionId) -> Result { self.unsubscribe(id) } + + fn get_version(&self) -> Result { + let version = solana_version::Version::default(); + Ok(RpcVersionInfo { + solana_core: version.to_string(), + feature_set: Some(version.feature_set), + }) + } } #[cfg(test)] @@ -1370,4 +1382,21 @@ mod tests { assert!(rpc.vote_unsubscribe(42.into()).is_err()); assert!(rpc.vote_unsubscribe(sub_id).is_ok()); } + + #[test] + fn test_get_version() { + let GenesisConfigInfo { genesis_config, .. } = create_genesis_config(10_000); + let bank = Bank::new_for_tests(&genesis_config); + let bank_forks = Arc::new(RwLock::new(BankForks::new(bank))); + let max_complete_transaction_status_slot = Arc::new(AtomicU64::default()); + let rpc_subscriptions = Arc::new(RpcSubscriptions::default_with_bank_forks( + max_complete_transaction_status_slot, + bank_forks, + )); + let (rpc, _receiver) = rpc_pubsub_service::test_connection(&rpc_subscriptions); + let version = rpc.get_version().unwrap(); + let expected_version = solana_version::Version::default(); + assert_eq!(version.to_string(), expected_version.to_string()); + assert_eq!(version.feature_set.unwrap(), expected_version.feature_set); + } }