Add Pubsub getVersion, and support programSubscribe filter mapping (#26482)

* Add pubsub getVersion api

* Generalize maybe_map_filters

* Add filter mapping to blocking PubsubClient

* Add version tracking to nonblocking PubsubClient

* Add filter mapping to nonblocking PubsubClient
This commit is contained in:
Tyera Eulberg 2022-07-07 20:55:18 -06:00 committed by GitHub
parent 312748721d
commit b8b521535c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 202 additions and 26 deletions

View File

@ -6,9 +6,10 @@ use {
RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig, RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
RpcTransactionLogsFilter, RpcTransactionLogsFilter,
}, },
rpc_filter::maybe_map_filters,
rpc_response::{ rpc_response::{
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse, Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate, RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
}, },
}, },
futures_util::{ futures_util::{
@ -25,7 +26,7 @@ use {
thiserror::Error, thiserror::Error,
tokio::{ tokio::{
net::TcpStream, net::TcpStream,
sync::{mpsc, oneshot}, sync::{mpsc, oneshot, RwLock},
task::JoinHandle, task::JoinHandle,
time::{sleep, Duration}, time::{sleep, Duration},
}, },
@ -62,6 +63,9 @@ pub enum PubsubClientError {
#[error("subscribe failed: {reason}")] #[error("subscribe failed: {reason}")]
SubscribeFailed { reason: String, message: String }, SubscribeFailed { reason: String, message: String },
#[error("request failed: {reason}")]
RequestFailed { reason: String, message: String },
} }
type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>; type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>;
@ -69,11 +73,18 @@ type SubscribeResponseMsg =
Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>; Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>;
type SubscribeRequestMsg = (String, Value, oneshot::Sender<SubscribeResponseMsg>); type SubscribeRequestMsg = (String, Value, oneshot::Sender<SubscribeResponseMsg>);
type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>; type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>;
type RequestMsg = (
String,
Value,
oneshot::Sender<Result<Value, PubsubClientError>>,
);
#[derive(Debug)] #[derive(Debug)]
pub struct PubsubClient { pub struct PubsubClient {
subscribe_tx: mpsc::UnboundedSender<SubscribeRequestMsg>, subscribe_tx: mpsc::UnboundedSender<SubscribeRequestMsg>,
request_tx: mpsc::UnboundedSender<RequestMsg>,
shutdown_tx: oneshot::Sender<()>, shutdown_tx: oneshot::Sender<()>,
node_version: RwLock<Option<semver::Version>>,
ws: JoinHandle<PubsubClientResult>, ws: JoinHandle<PubsubClientResult>,
} }
@ -85,12 +96,20 @@ impl PubsubClient {
.map_err(PubsubClientError::ConnectionError)?; .map_err(PubsubClientError::ConnectionError)?;
let (subscribe_tx, subscribe_rx) = mpsc::unbounded_channel(); let (subscribe_tx, subscribe_rx) = mpsc::unbounded_channel();
let (request_tx, request_rx) = mpsc::unbounded_channel();
let (shutdown_tx, shutdown_rx) = oneshot::channel(); let (shutdown_tx, shutdown_rx) = oneshot::channel();
Ok(Self { Ok(Self {
subscribe_tx, subscribe_tx,
request_tx,
shutdown_tx, shutdown_tx,
ws: tokio::spawn(PubsubClient::run_ws(ws, subscribe_rx, shutdown_rx)), node_version: RwLock::new(None),
ws: tokio::spawn(PubsubClient::run_ws(
ws,
subscribe_rx,
request_rx,
shutdown_rx,
)),
}) })
} }
@ -99,6 +118,37 @@ impl PubsubClient {
self.ws.await.unwrap() // WS future should not be cancelled or panicked self.ws.await.unwrap() // WS future should not be cancelled or panicked
} }
async fn get_node_version(&self) -> PubsubClientResult<semver::Version> {
let r_node_version = self.node_version.read().await;
if let Some(version) = &*r_node_version {
Ok(version.clone())
} else {
drop(r_node_version);
let mut w_node_version = self.node_version.write().await;
let node_version = self.get_version().await?;
*w_node_version = Some(node_version.clone());
Ok(node_version)
}
}
async fn get_version(&self) -> PubsubClientResult<semver::Version> {
let (response_tx, response_rx) = oneshot::channel();
self.request_tx
.send(("getVersion".to_string(), Value::Null, response_tx))
.map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
let result = response_rx
.await
.map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
let node_version: RpcVersionInfo = serde_json::from_value(result)?;
let node_version = semver::Version::parse(&node_version.solana_core).map_err(|e| {
PubsubClientError::RequestFailed {
reason: format!("failed to parse cluster version: {}", e),
message: "getVersion".to_string(),
}
})?;
Ok(node_version)
}
async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T> async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T>
where where
T: DeserializeOwned + Send + 'a, T: DeserializeOwned + Send + 'a,
@ -147,8 +197,22 @@ impl PubsubClient {
pub async fn program_subscribe( pub async fn program_subscribe(
&self, &self,
pubkey: &Pubkey, pubkey: &Pubkey,
config: Option<RpcProgramAccountsConfig>, mut config: Option<RpcProgramAccountsConfig>,
) -> SubscribeResult<'_, RpcResponse<RpcKeyedAccount>> { ) -> SubscribeResult<'_, RpcResponse<RpcKeyedAccount>> {
if let Some(ref mut config) = config {
if let Some(ref mut filters) = config.filters {
let node_version = self.get_node_version().await.ok();
// If node does not support the pubsub `getVersion` method, assume version is old
// and filters should be mapped (node_version.is_none()).
maybe_map_filters(node_version, filters).map_err(|e| {
PubsubClientError::RequestFailed {
reason: e,
message: "maybe_map_filters".to_string(),
}
})?;
}
}
let params = json!([pubkey.to_string(), config]); let params = json!([pubkey.to_string(), config]);
self.subscribe("program", params).await self.subscribe("program", params).await
} }
@ -181,12 +245,14 @@ impl PubsubClient {
async fn run_ws( async fn run_ws(
mut ws: WebSocketStream<MaybeTlsStream<TcpStream>>, mut ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
mut subscribe_rx: mpsc::UnboundedReceiver<SubscribeRequestMsg>, mut subscribe_rx: mpsc::UnboundedReceiver<SubscribeRequestMsg>,
mut request_rx: mpsc::UnboundedReceiver<RequestMsg>,
mut shutdown_rx: oneshot::Receiver<()>, mut shutdown_rx: oneshot::Receiver<()>,
) -> PubsubClientResult { ) -> PubsubClientResult {
let mut request_id: u64 = 0; let mut request_id: u64 = 0;
let mut requests_subscribe = BTreeMap::new(); let mut requests_subscribe = BTreeMap::new();
let mut requests_unsubscribe = BTreeMap::<u64, oneshot::Sender<()>>::new(); let mut requests_unsubscribe = BTreeMap::<u64, oneshot::Sender<()>>::new();
let mut other_requests = BTreeMap::new();
let mut subscriptions = BTreeMap::new(); let mut subscriptions = BTreeMap::new();
let (unsubscribe_tx, mut unsubscribe_rx) = mpsc::unbounded_channel(); let (unsubscribe_tx, mut unsubscribe_rx) = mpsc::unbounded_channel();
@ -220,6 +286,13 @@ impl PubsubClient {
ws.send(Message::Text(text)).await?; ws.send(Message::Text(text)).await?;
requests_unsubscribe.insert(request_id, response_tx); requests_unsubscribe.insert(request_id, response_tx);
}, },
// Read message for other requests
Some((method, params, response_tx)) = request_rx.recv() => {
request_id += 1;
let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
ws.send(Message::Text(text)).await?;
other_requests.insert(request_id, response_tx);
}
// Read incoming WebSocket message // Read incoming WebSocket message
next_msg = ws.next() => { next_msg = ws.next() => {
let msg = match next_msg { let msg = match next_msg {
@ -264,7 +337,21 @@ impl PubsubClient {
} }
}); });
if let Some(response_tx) = requests_unsubscribe.remove(&id) { if let Some(response_tx) = other_requests.remove(&id) {
match err {
Some(reason) => {
let _ = response_tx.send(Err(PubsubClientError::RequestFailed { reason, message: text.clone()}));
},
None => {
let json_result = json.get("result").ok_or_else(|| {
PubsubClientError::RequestFailed { reason: "missing `result` field".into(), message: text.clone() }
})?;
if response_tx.send(Ok(json_result.clone())).is_err() {
break;
}
}
}
} else if let Some(response_tx) = requests_unsubscribe.remove(&id) {
let _ = response_tx.send(()); // do not care if receiver is closed let _ = response_tx.send(()); // do not care if receiver is closed
} else if let Some((operation, response_tx)) = requests_subscribe.remove(&id) { } else if let Some((operation, response_tx)) = requests_subscribe.remove(&id) {
match err { match err {

View File

@ -19,7 +19,7 @@ use {
mock_sender::MockSender, mock_sender::MockSender,
rpc_client::{GetConfirmedSignaturesForAddress2Config, RpcClientConfig}, rpc_client::{GetConfirmedSignaturesForAddress2Config, RpcClientConfig},
rpc_config::{RpcAccountInfoConfig, *}, rpc_config::{RpcAccountInfoConfig, *},
rpc_filter::{MemcmpEncodedBytes, RpcFilterType}, rpc_filter::{self, RpcFilterType},
rpc_request::{RpcError, RpcRequest, RpcResponseErrorData, TokenAccountsFilter}, rpc_request::{RpcError, RpcRequest, RpcResponseErrorData, TokenAccountsFilter},
rpc_response::*, rpc_response::*,
rpc_sender::*, rpc_sender::*,
@ -587,24 +587,8 @@ impl RpcClient {
mut filters: Vec<RpcFilterType>, mut filters: Vec<RpcFilterType>,
) -> Result<Vec<RpcFilterType>, RpcError> { ) -> Result<Vec<RpcFilterType>, RpcError> {
let node_version = self.get_node_version().await?; let node_version = self.get_node_version().await?;
if node_version < semver::Version::new(1, 11, 2) { rpc_filter::maybe_map_filters(Some(node_version), &mut filters)
for filter in filters.iter_mut() { .map_err(RpcError::RpcRequestError)?;
if let RpcFilterType::Memcmp(memcmp) = filter {
match &memcmp.bytes {
MemcmpEncodedBytes::Base58(string) => {
memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone());
}
MemcmpEncodedBytes::Base64(_) => {
return Err(RpcError::RpcRequestError(format!(
"RPC node on old version {} does not support base64 encoding for memcmp filters",
node_version
)));
}
_ => {}
}
}
}
}
Ok(filters) Ok(filters)
} }

View File

@ -5,6 +5,7 @@ use {
RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig, RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
RpcTransactionLogsFilter, RpcTransactionLogsFilter,
}, },
rpc_filter,
rpc_response::{ rpc_response::{
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse, Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate, RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
@ -48,6 +49,9 @@ pub enum PubsubClientError {
#[error("unexpected message format: {0}")] #[error("unexpected message format: {0}")]
UnexpectedMessageError(String), UnexpectedMessageError(String),
#[error("request error: {0}")]
RequestError(String),
} }
pub struct PubsubClientSubscription<T> pub struct PubsubClientSubscription<T>
@ -123,6 +127,43 @@ where
.map_err(|err| err.into()) .map_err(|err| err.into())
} }
fn get_version(
writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>,
) -> Result<semver::Version, PubsubClientError> {
writable_socket
.write()
.unwrap()
.write_message(Message::Text(
json!({
"jsonrpc":"2.0","id":1,"method":"getVersion",
})
.to_string(),
))?;
let message = writable_socket.write().unwrap().read_message()?;
let message_text = &message.into_text()?;
let json_msg: Map<String, Value> = serde_json::from_str(message_text)?;
if let Some(Object(version_map)) = json_msg.get("result") {
if let Some(node_version) = version_map.get("solana-core") {
let node_version = semver::Version::parse(
node_version.as_str().unwrap_or_default(),
)
.map_err(|e| {
PubsubClientError::RequestError(format!(
"failed to parse cluster version: {}",
e
))
})?;
return Ok(node_version);
}
}
// TODO: Add proper JSON RPC response/error handling...
Err(PubsubClientError::UnexpectedMessageError(format!(
"{:?}",
json_msg
)))
}
fn read_message( fn read_message(
writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>, writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>,
) -> Result<T, PubsubClientError> { ) -> Result<T, PubsubClientError> {
@ -357,7 +398,7 @@ impl PubsubClient {
pub fn program_subscribe( pub fn program_subscribe(
url: &str, url: &str,
pubkey: &Pubkey, pubkey: &Pubkey,
config: Option<RpcProgramAccountsConfig>, mut config: Option<RpcProgramAccountsConfig>,
) -> Result<ProgramSubscription, PubsubClientError> { ) -> Result<ProgramSubscription, PubsubClientError> {
let url = Url::parse(url)?; let url = Url::parse(url)?;
let socket = connect_with_retry(url)?; let socket = connect_with_retry(url)?;
@ -367,6 +408,17 @@ impl PubsubClient {
let socket_clone = socket.clone(); let socket_clone = socket.clone();
let exit = Arc::new(AtomicBool::new(false)); let exit = Arc::new(AtomicBool::new(false));
let exit_clone = exit.clone(); let exit_clone = exit.clone();
if let Some(ref mut config) = config {
if let Some(ref mut filters) = config.filters {
let node_version = PubsubProgramClientSubscription::get_version(&socket_clone).ok();
// If node does not support the pubsub `getVersion` method, assume version is old
// and filters should be mapped (node_version.is_none()).
rpc_filter::maybe_map_filters(node_version, filters)
.map_err(PubsubClientError::RequestError)?;
}
}
let body = json!({ let body = json!({
"jsonrpc":"2.0", "jsonrpc":"2.0",
"id":1, "id":1,

View File

@ -259,6 +259,30 @@ impl From<RpcMemcmp> for Memcmp {
} }
} }
pub(crate) fn maybe_map_filters(
node_version: Option<semver::Version>,
filters: &mut [RpcFilterType],
) -> Result<(), String> {
if node_version.is_none() || node_version.unwrap() < semver::Version::new(1, 11, 2) {
for filter in filters.iter_mut() {
if let RpcFilterType::Memcmp(memcmp) = filter {
match &memcmp.bytes {
MemcmpEncodedBytes::Base58(string) => {
memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone());
}
MemcmpEncodedBytes::Base64(_) => {
return Err("RPC node on old version does not support base64 \
encoding for memcmp filters"
.to_string());
}
_ => {}
}
}
}
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -24,7 +24,7 @@ use {
}, },
rpc_response::{ rpc_response::{
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse, Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate, RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
}, },
}, },
solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature}, solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature},
@ -348,6 +348,10 @@ mod internal {
// Unsubscribe from slot notification subscription. // Unsubscribe from slot notification subscription.
#[rpc(name = "rootUnsubscribe")] #[rpc(name = "rootUnsubscribe")]
fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool>; fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool>;
// Get the current solana version running on the node
#[rpc(name = "getVersion")]
fn get_version(&self) -> Result<RpcVersionInfo>;
} }
} }
@ -576,6 +580,14 @@ impl RpcSolPubSubInternal for RpcSolPubSubImpl {
fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool> { fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool> {
self.unsubscribe(id) self.unsubscribe(id)
} }
fn get_version(&self) -> Result<RpcVersionInfo> {
let version = solana_version::Version::default();
Ok(RpcVersionInfo {
solana_core: version.to_string(),
feature_set: Some(version.feature_set),
})
}
} }
#[cfg(test)] #[cfg(test)]
@ -1370,4 +1382,21 @@ mod tests {
assert!(rpc.vote_unsubscribe(42.into()).is_err()); assert!(rpc.vote_unsubscribe(42.into()).is_err());
assert!(rpc.vote_unsubscribe(sub_id).is_ok()); assert!(rpc.vote_unsubscribe(sub_id).is_ok());
} }
#[test]
fn test_get_version() {
let GenesisConfigInfo { genesis_config, .. } = create_genesis_config(10_000);
let bank = Bank::new_for_tests(&genesis_config);
let bank_forks = Arc::new(RwLock::new(BankForks::new(bank)));
let max_complete_transaction_status_slot = Arc::new(AtomicU64::default());
let rpc_subscriptions = Arc::new(RpcSubscriptions::default_with_bank_forks(
max_complete_transaction_status_slot,
bank_forks,
));
let (rpc, _receiver) = rpc_pubsub_service::test_connection(&rpc_subscriptions);
let version = rpc.get_version().unwrap();
let expected_version = solana_version::Version::default();
assert_eq!(version.to_string(), expected_version.to_string());
assert_eq!(version.feature_set.unwrap(), expected_version.feature_set);
}
} }