diff --git a/Cargo.lock b/Cargo.lock index 4ff070c..eff9760 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5904,6 +5904,7 @@ dependencies = [ "prost 0.9.0", "prost-derive 0.9.0", "tokio", + "tokio-rustls", "tokio-stream", "tokio-util", "tower", diff --git a/connector-mango/example-config.toml b/connector-mango/example-config.toml index dc25b3b..f092d4d 100644 --- a/connector-mango/example-config.toml +++ b/connector-mango/example-config.toml @@ -5,6 +5,11 @@ name = "server" connection_string = "http://[::1]:10000" retry_connection_sleep_secs = 30 +#[grpc_sources.tls] +#ca_cert_path = "ca.pem" +#client_cert_path = "client.pem" +#client_key_path = "client.pem" + [snapshot_source] rpc_http_url = "" program_id = "mv3ekLzLbnVPNxjSKvqBpU3ZeZXPQdEC3bp5MDEBG68" diff --git a/connector-raw/example-config.toml b/connector-raw/example-config.toml index da53158..d5da7a5 100644 --- a/connector-raw/example-config.toml +++ b/connector-raw/example-config.toml @@ -5,6 +5,11 @@ name = "server" connection_string = "http://[::1]:10000" retry_connection_sleep_secs = 30 +#[grpc_sources.tls] +#ca_cert_path = "ca.pem" +#client_cert_path = "client.pem" +#client_key_path = "client.pem" + [snapshot_source] rpc_http_url = "" program_id = "" diff --git a/lib/Cargo.toml b/lib/Cargo.toml index b4599cb..3699fd5 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -30,7 +30,7 @@ serde = "1.0.130" serde_derive = "1.0.130" serde_json = "1.0.68" -tonic = "0.6" +tonic = { version = "0.6", features = ["tls"] } prost = "0.9" bs58 = "0.3.1" diff --git a/lib/src/grpc_plugin_source.rs b/lib/src/grpc_plugin_source.rs index 53132fc..9203bfe 100644 --- a/lib/src/grpc_plugin_source.rs +++ b/lib/src/grpc_plugin_source.rs @@ -8,7 +8,7 @@ use solana_rpc::{rpc::rpc_full::FullClient, rpc::OptionalContext}; use solana_sdk::{account::Account, commitment_config::CommitmentConfig, pubkey::Pubkey}; use futures::{future, future::FutureExt}; -use tonic::transport::Endpoint; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint, Identity}; use log::*; use std::{collections::HashMap, str::FromStr, time::Duration}; @@ -20,7 +20,7 @@ use accountsdb_proto::accounts_db_client::AccountsDbClient; use crate::{ metrics, AccountWrite, AnyhowWrap, Config, GrpcSourceConfig, SlotStatus, SlotUpdate, - SnapshotSourceConfig, + SnapshotSourceConfig, TlsConfig, }; type SnapshotData = Response>; @@ -75,13 +75,21 @@ async fn get_snapshot( async fn feed_data_accountsdb( grpc_config: &GrpcSourceConfig, + tls_config: Option, snapshot_config: &SnapshotSourceConfig, sender: async_channel::Sender, ) -> anyhow::Result<()> { let program_id = Pubkey::from_str(&snapshot_config.program_id)?; - let mut client = - AccountsDbClient::connect(Endpoint::from_str(&grpc_config.connection_string)?).await?; + let endpoint = Endpoint::from_str(&grpc_config.connection_string)?; + let channel = if let Some(tls) = tls_config { + endpoint.tls_config(tls)? + } else { + endpoint + } + .connect() + .await?; + let mut client = AccountsDbClient::new(channel); let mut update_stream = client .subscribe(accountsdb_proto::SubscribeRequest {}) @@ -141,6 +149,18 @@ async fn feed_data_accountsdb( } } +fn make_tls_config(config: &TlsConfig) -> ClientTlsConfig { + let server_root_ca_cert = + std::fs::read(&config.ca_cert_path).expect("reading server root ca cert"); + let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert); + let client_cert = std::fs::read(&config.client_cert_path).expect("reading client cert"); + let client_key = std::fs::read(&config.client_key_path).expect("reading client key"); + let client_identity = Identity::from_pem(client_cert, client_key); + ClientTlsConfig::new() + .ca_certificate(server_root_ca_cert) + .identity(client_identity) +} + pub async fn process_events( config: Config, account_write_queue_sender: async_channel::Sender, @@ -153,6 +173,10 @@ pub async fn process_events( let msg_sender = msg_sender.clone(); let snapshot_source = config.snapshot_source.clone(); let metrics_sender = metrics_sender.clone(); + + // Make TLS config if configured + let tls_config = grpc_source.tls.as_ref().map(make_tls_config); + tokio::spawn(async move { let mut metric_retries = metrics_sender.register_u64(format!( "grpc_source_{}_connection_retries", @@ -164,7 +188,12 @@ pub async fn process_events( // Continuously reconnect on failure loop { metric_status.set("connected".into()); - let out = feed_data_accountsdb(&grpc_source, &snapshot_source, msg_sender.clone()); + let out = feed_data_accountsdb( + &grpc_source, + tls_config.clone(), + &snapshot_source, + msg_sender.clone(), + ); let result = out.await; assert!(result.is_err()); if let Err(err) = result { diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 5041c60..732324f 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -84,11 +84,19 @@ pub struct PostgresConfig { pub allow_invalid_certs: bool, } +#[derive(Clone, Debug, Deserialize)] +pub struct TlsConfig { + pub ca_cert_path: String, + pub client_cert_path: String, + pub client_key_path: String, +} + #[derive(Clone, Debug, Deserialize)] pub struct GrpcSourceConfig { pub name: String, pub connection_string: String, pub retry_connection_sleep_secs: u64, + pub tls: Option, } #[derive(Clone, Debug, Deserialize)]