Add Pubsub getVersion, and support programSubscribe filter mapping (#26482)
* Add pubsub getVersion api * Generalize maybe_map_filters * Add filter mapping to blocking PubsubClient * Add version tracking to nonblocking PubsubClient * Add filter mapping to nonblocking PubsubClient
This commit is contained in:
parent
312748721d
commit
b8b521535c
|
@ -6,9 +6,10 @@ use {
|
|||
RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
|
||||
RpcTransactionLogsFilter,
|
||||
},
|
||||
rpc_filter::maybe_map_filters,
|
||||
rpc_response::{
|
||||
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
|
||||
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
|
||||
RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
|
||||
},
|
||||
},
|
||||
futures_util::{
|
||||
|
@ -25,7 +26,7 @@ use {
|
|||
thiserror::Error,
|
||||
tokio::{
|
||||
net::TcpStream,
|
||||
sync::{mpsc, oneshot},
|
||||
sync::{mpsc, oneshot, RwLock},
|
||||
task::JoinHandle,
|
||||
time::{sleep, Duration},
|
||||
},
|
||||
|
@ -62,6 +63,9 @@ pub enum PubsubClientError {
|
|||
|
||||
#[error("subscribe failed: {reason}")]
|
||||
SubscribeFailed { reason: String, message: String },
|
||||
|
||||
#[error("request failed: {reason}")]
|
||||
RequestFailed { reason: String, message: String },
|
||||
}
|
||||
|
||||
type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>;
|
||||
|
@ -69,11 +73,18 @@ type SubscribeResponseMsg =
|
|||
Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>;
|
||||
type SubscribeRequestMsg = (String, Value, oneshot::Sender<SubscribeResponseMsg>);
|
||||
type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>;
|
||||
type RequestMsg = (
|
||||
String,
|
||||
Value,
|
||||
oneshot::Sender<Result<Value, PubsubClientError>>,
|
||||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PubsubClient {
|
||||
subscribe_tx: mpsc::UnboundedSender<SubscribeRequestMsg>,
|
||||
request_tx: mpsc::UnboundedSender<RequestMsg>,
|
||||
shutdown_tx: oneshot::Sender<()>,
|
||||
node_version: RwLock<Option<semver::Version>>,
|
||||
ws: JoinHandle<PubsubClientResult>,
|
||||
}
|
||||
|
||||
|
@ -85,12 +96,20 @@ impl PubsubClient {
|
|||
.map_err(PubsubClientError::ConnectionError)?;
|
||||
|
||||
let (subscribe_tx, subscribe_rx) = mpsc::unbounded_channel();
|
||||
let (request_tx, request_rx) = mpsc::unbounded_channel();
|
||||
let (shutdown_tx, shutdown_rx) = oneshot::channel();
|
||||
|
||||
Ok(Self {
|
||||
subscribe_tx,
|
||||
request_tx,
|
||||
shutdown_tx,
|
||||
ws: tokio::spawn(PubsubClient::run_ws(ws, subscribe_rx, shutdown_rx)),
|
||||
node_version: RwLock::new(None),
|
||||
ws: tokio::spawn(PubsubClient::run_ws(
|
||||
ws,
|
||||
subscribe_rx,
|
||||
request_rx,
|
||||
shutdown_rx,
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -99,6 +118,37 @@ impl PubsubClient {
|
|||
self.ws.await.unwrap() // WS future should not be cancelled or panicked
|
||||
}
|
||||
|
||||
async fn get_node_version(&self) -> PubsubClientResult<semver::Version> {
|
||||
let r_node_version = self.node_version.read().await;
|
||||
if let Some(version) = &*r_node_version {
|
||||
Ok(version.clone())
|
||||
} else {
|
||||
drop(r_node_version);
|
||||
let mut w_node_version = self.node_version.write().await;
|
||||
let node_version = self.get_version().await?;
|
||||
*w_node_version = Some(node_version.clone());
|
||||
Ok(node_version)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_version(&self) -> PubsubClientResult<semver::Version> {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
self.request_tx
|
||||
.send(("getVersion".to_string(), Value::Null, response_tx))
|
||||
.map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
|
||||
let result = response_rx
|
||||
.await
|
||||
.map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
|
||||
let node_version: RpcVersionInfo = serde_json::from_value(result)?;
|
||||
let node_version = semver::Version::parse(&node_version.solana_core).map_err(|e| {
|
||||
PubsubClientError::RequestFailed {
|
||||
reason: format!("failed to parse cluster version: {}", e),
|
||||
message: "getVersion".to_string(),
|
||||
}
|
||||
})?;
|
||||
Ok(node_version)
|
||||
}
|
||||
|
||||
async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T>
|
||||
where
|
||||
T: DeserializeOwned + Send + 'a,
|
||||
|
@ -147,8 +197,22 @@ impl PubsubClient {
|
|||
pub async fn program_subscribe(
|
||||
&self,
|
||||
pubkey: &Pubkey,
|
||||
config: Option<RpcProgramAccountsConfig>,
|
||||
mut config: Option<RpcProgramAccountsConfig>,
|
||||
) -> SubscribeResult<'_, RpcResponse<RpcKeyedAccount>> {
|
||||
if let Some(ref mut config) = config {
|
||||
if let Some(ref mut filters) = config.filters {
|
||||
let node_version = self.get_node_version().await.ok();
|
||||
// If node does not support the pubsub `getVersion` method, assume version is old
|
||||
// and filters should be mapped (node_version.is_none()).
|
||||
maybe_map_filters(node_version, filters).map_err(|e| {
|
||||
PubsubClientError::RequestFailed {
|
||||
reason: e,
|
||||
message: "maybe_map_filters".to_string(),
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let params = json!([pubkey.to_string(), config]);
|
||||
self.subscribe("program", params).await
|
||||
}
|
||||
|
@ -181,12 +245,14 @@ impl PubsubClient {
|
|||
async fn run_ws(
|
||||
mut ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
mut subscribe_rx: mpsc::UnboundedReceiver<SubscribeRequestMsg>,
|
||||
mut request_rx: mpsc::UnboundedReceiver<RequestMsg>,
|
||||
mut shutdown_rx: oneshot::Receiver<()>,
|
||||
) -> PubsubClientResult {
|
||||
let mut request_id: u64 = 0;
|
||||
|
||||
let mut requests_subscribe = BTreeMap::new();
|
||||
let mut requests_unsubscribe = BTreeMap::<u64, oneshot::Sender<()>>::new();
|
||||
let mut other_requests = BTreeMap::new();
|
||||
let mut subscriptions = BTreeMap::new();
|
||||
let (unsubscribe_tx, mut unsubscribe_rx) = mpsc::unbounded_channel();
|
||||
|
||||
|
@ -220,6 +286,13 @@ impl PubsubClient {
|
|||
ws.send(Message::Text(text)).await?;
|
||||
requests_unsubscribe.insert(request_id, response_tx);
|
||||
},
|
||||
// Read message for other requests
|
||||
Some((method, params, response_tx)) = request_rx.recv() => {
|
||||
request_id += 1;
|
||||
let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
|
||||
ws.send(Message::Text(text)).await?;
|
||||
other_requests.insert(request_id, response_tx);
|
||||
}
|
||||
// Read incoming WebSocket message
|
||||
next_msg = ws.next() => {
|
||||
let msg = match next_msg {
|
||||
|
@ -264,7 +337,21 @@ impl PubsubClient {
|
|||
}
|
||||
});
|
||||
|
||||
if let Some(response_tx) = requests_unsubscribe.remove(&id) {
|
||||
if let Some(response_tx) = other_requests.remove(&id) {
|
||||
match err {
|
||||
Some(reason) => {
|
||||
let _ = response_tx.send(Err(PubsubClientError::RequestFailed { reason, message: text.clone()}));
|
||||
},
|
||||
None => {
|
||||
let json_result = json.get("result").ok_or_else(|| {
|
||||
PubsubClientError::RequestFailed { reason: "missing `result` field".into(), message: text.clone() }
|
||||
})?;
|
||||
if response_tx.send(Ok(json_result.clone())).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(response_tx) = requests_unsubscribe.remove(&id) {
|
||||
let _ = response_tx.send(()); // do not care if receiver is closed
|
||||
} else if let Some((operation, response_tx)) = requests_subscribe.remove(&id) {
|
||||
match err {
|
||||
|
|
|
@ -19,7 +19,7 @@ use {
|
|||
mock_sender::MockSender,
|
||||
rpc_client::{GetConfirmedSignaturesForAddress2Config, RpcClientConfig},
|
||||
rpc_config::{RpcAccountInfoConfig, *},
|
||||
rpc_filter::{MemcmpEncodedBytes, RpcFilterType},
|
||||
rpc_filter::{self, RpcFilterType},
|
||||
rpc_request::{RpcError, RpcRequest, RpcResponseErrorData, TokenAccountsFilter},
|
||||
rpc_response::*,
|
||||
rpc_sender::*,
|
||||
|
@ -587,24 +587,8 @@ impl RpcClient {
|
|||
mut filters: Vec<RpcFilterType>,
|
||||
) -> Result<Vec<RpcFilterType>, RpcError> {
|
||||
let node_version = self.get_node_version().await?;
|
||||
if node_version < semver::Version::new(1, 11, 2) {
|
||||
for filter in filters.iter_mut() {
|
||||
if let RpcFilterType::Memcmp(memcmp) = filter {
|
||||
match &memcmp.bytes {
|
||||
MemcmpEncodedBytes::Base58(string) => {
|
||||
memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone());
|
||||
}
|
||||
MemcmpEncodedBytes::Base64(_) => {
|
||||
return Err(RpcError::RpcRequestError(format!(
|
||||
"RPC node on old version {} does not support base64 encoding for memcmp filters",
|
||||
node_version
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rpc_filter::maybe_map_filters(Some(node_version), &mut filters)
|
||||
.map_err(RpcError::RpcRequestError)?;
|
||||
Ok(filters)
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ use {
|
|||
RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
|
||||
RpcTransactionLogsFilter,
|
||||
},
|
||||
rpc_filter,
|
||||
rpc_response::{
|
||||
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
|
||||
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
|
||||
|
@ -48,6 +49,9 @@ pub enum PubsubClientError {
|
|||
|
||||
#[error("unexpected message format: {0}")]
|
||||
UnexpectedMessageError(String),
|
||||
|
||||
#[error("request error: {0}")]
|
||||
RequestError(String),
|
||||
}
|
||||
|
||||
pub struct PubsubClientSubscription<T>
|
||||
|
@ -123,6 +127,43 @@ where
|
|||
.map_err(|err| err.into())
|
||||
}
|
||||
|
||||
fn get_version(
|
||||
writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>,
|
||||
) -> Result<semver::Version, PubsubClientError> {
|
||||
writable_socket
|
||||
.write()
|
||||
.unwrap()
|
||||
.write_message(Message::Text(
|
||||
json!({
|
||||
"jsonrpc":"2.0","id":1,"method":"getVersion",
|
||||
})
|
||||
.to_string(),
|
||||
))?;
|
||||
let message = writable_socket.write().unwrap().read_message()?;
|
||||
let message_text = &message.into_text()?;
|
||||
let json_msg: Map<String, Value> = serde_json::from_str(message_text)?;
|
||||
|
||||
if let Some(Object(version_map)) = json_msg.get("result") {
|
||||
if let Some(node_version) = version_map.get("solana-core") {
|
||||
let node_version = semver::Version::parse(
|
||||
node_version.as_str().unwrap_or_default(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
PubsubClientError::RequestError(format!(
|
||||
"failed to parse cluster version: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
return Ok(node_version);
|
||||
}
|
||||
}
|
||||
// TODO: Add proper JSON RPC response/error handling...
|
||||
Err(PubsubClientError::UnexpectedMessageError(format!(
|
||||
"{:?}",
|
||||
json_msg
|
||||
)))
|
||||
}
|
||||
|
||||
fn read_message(
|
||||
writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>,
|
||||
) -> Result<T, PubsubClientError> {
|
||||
|
@ -357,7 +398,7 @@ impl PubsubClient {
|
|||
pub fn program_subscribe(
|
||||
url: &str,
|
||||
pubkey: &Pubkey,
|
||||
config: Option<RpcProgramAccountsConfig>,
|
||||
mut config: Option<RpcProgramAccountsConfig>,
|
||||
) -> Result<ProgramSubscription, PubsubClientError> {
|
||||
let url = Url::parse(url)?;
|
||||
let socket = connect_with_retry(url)?;
|
||||
|
@ -367,6 +408,17 @@ impl PubsubClient {
|
|||
let socket_clone = socket.clone();
|
||||
let exit = Arc::new(AtomicBool::new(false));
|
||||
let exit_clone = exit.clone();
|
||||
|
||||
if let Some(ref mut config) = config {
|
||||
if let Some(ref mut filters) = config.filters {
|
||||
let node_version = PubsubProgramClientSubscription::get_version(&socket_clone).ok();
|
||||
// If node does not support the pubsub `getVersion` method, assume version is old
|
||||
// and filters should be mapped (node_version.is_none()).
|
||||
rpc_filter::maybe_map_filters(node_version, filters)
|
||||
.map_err(PubsubClientError::RequestError)?;
|
||||
}
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"jsonrpc":"2.0",
|
||||
"id":1,
|
||||
|
|
|
@ -259,6 +259,30 @@ impl From<RpcMemcmp> for Memcmp {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_map_filters(
|
||||
node_version: Option<semver::Version>,
|
||||
filters: &mut [RpcFilterType],
|
||||
) -> Result<(), String> {
|
||||
if node_version.is_none() || node_version.unwrap() < semver::Version::new(1, 11, 2) {
|
||||
for filter in filters.iter_mut() {
|
||||
if let RpcFilterType::Memcmp(memcmp) = filter {
|
||||
match &memcmp.bytes {
|
||||
MemcmpEncodedBytes::Base58(string) => {
|
||||
memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone());
|
||||
}
|
||||
MemcmpEncodedBytes::Base64(_) => {
|
||||
return Err("RPC node on old version does not support base64 \
|
||||
encoding for memcmp filters"
|
||||
.to_string());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
@ -24,7 +24,7 @@ use {
|
|||
},
|
||||
rpc_response::{
|
||||
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
|
||||
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
|
||||
RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
|
||||
},
|
||||
},
|
||||
solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature},
|
||||
|
@ -348,6 +348,10 @@ mod internal {
|
|||
// Unsubscribe from slot notification subscription.
|
||||
#[rpc(name = "rootUnsubscribe")]
|
||||
fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool>;
|
||||
|
||||
// Get the current solana version running on the node
|
||||
#[rpc(name = "getVersion")]
|
||||
fn get_version(&self) -> Result<RpcVersionInfo>;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -576,6 +580,14 @@ impl RpcSolPubSubInternal for RpcSolPubSubImpl {
|
|||
fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool> {
|
||||
self.unsubscribe(id)
|
||||
}
|
||||
|
||||
fn get_version(&self) -> Result<RpcVersionInfo> {
|
||||
let version = solana_version::Version::default();
|
||||
Ok(RpcVersionInfo {
|
||||
solana_core: version.to_string(),
|
||||
feature_set: Some(version.feature_set),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1370,4 +1382,21 @@ mod tests {
|
|||
assert!(rpc.vote_unsubscribe(42.into()).is_err());
|
||||
assert!(rpc.vote_unsubscribe(sub_id).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_version() {
|
||||
let GenesisConfigInfo { genesis_config, .. } = create_genesis_config(10_000);
|
||||
let bank = Bank::new_for_tests(&genesis_config);
|
||||
let bank_forks = Arc::new(RwLock::new(BankForks::new(bank)));
|
||||
let max_complete_transaction_status_slot = Arc::new(AtomicU64::default());
|
||||
let rpc_subscriptions = Arc::new(RpcSubscriptions::default_with_bank_forks(
|
||||
max_complete_transaction_status_slot,
|
||||
bank_forks,
|
||||
));
|
||||
let (rpc, _receiver) = rpc_pubsub_service::test_connection(&rpc_subscriptions);
|
||||
let version = rpc.get_version().unwrap();
|
||||
let expected_version = solana_version::Version::default();
|
||||
assert_eq!(version.to_string(), expected_version.to_string());
|
||||
assert_eq!(version.feature_set.unwrap(), expected_version.feature_set);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue