From 26c5c82b4500369bf776cf6508f99c5668b1ff13 Mon Sep 17 00:00:00 2001 From: aniketfuryrocks Date: Fri, 31 Mar 2023 00:16:24 +0530 Subject: [PATCH] optional ssl --- src/workers/postgres.rs | 73 +++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 25 deletions(-) diff --git a/src/workers/postgres.rs b/src/workers/postgres.rs index 4edef5e1..9539cafb 100644 --- a/src/workers/postgres.rs +++ b/src/workers/postgres.rs @@ -12,7 +12,10 @@ use tokio::{ }, task::JoinHandle, }; -use tokio_postgres::{types::ToSql, Client, Statement}; +use tokio_postgres::{ + config::SslMode, tls::MakeTlsConnect, types::ToSql, Client, NoTls, Socket, + Statement, +}; use native_tls::{Certificate, Identity, TlsConnector}; @@ -71,36 +74,37 @@ pub struct PostgresSession { impl PostgresSession { pub async fn new() -> anyhow::Result { - let ca_pem_b64 = std::env::var("CA_PEM_B64").context("env CA_PEM_B64 not found")?; - let client_pks_b64 = - std::env::var("CLIENT_PKS_B64").context("env CLIENT_PKS_B64 not found")?; - let client_pks_password = - std::env::var("CLIENT_PKS_PASS").context("env CLIENT_PKS_PASS not found")?; let pg_config = std::env::var("PG_CONFIG").context("env PG_CONFIG not found")?; + let pg_config = pg_config.parse::()?; - let ca_pem = BinaryEncoding::Base64 - .decode(ca_pem_b64) - .context("ca pem decode")?; + let client = if let SslMode::Disable = pg_config.get_ssl_mode() { + Self::spawn_connection(pg_config, NoTls).await? + } else { + let ca_pem_b64 = std::env::var("CA_PEM_B64").context("env CA_PEM_B64 not found")?; + let client_pks_b64 = + std::env::var("CLIENT_PKS_B64").context("env CLIENT_PKS_B64 not found")?; + let client_pks_password = + std::env::var("CLIENT_PKS_PASS").context("env CLIENT_PKS_PASS not found")?; - let client_pks = BinaryEncoding::Base64 - .decode(client_pks_b64) - .context("client pks decode")?; + let ca_pem = BinaryEncoding::Base64 + .decode(ca_pem_b64) + .context("ca pem decode")?; - let connector = TlsConnector::builder() - .add_root_certificate(Certificate::from_pem(&ca_pem)?) - .identity(Identity::from_pkcs12(&client_pks, &client_pks_password).context("Identity")?) - .danger_accept_invalid_hostnames(true) - .danger_accept_invalid_certs(true) - .build()?; + let client_pks = BinaryEncoding::Base64 + .decode(client_pks_b64) + .context("client pks decode")?; - let connector = MakeTlsConnector::new(connector); - let (client, connection) = tokio_postgres::connect(&pg_config, connector.clone()).await?; + let connector = TlsConnector::builder() + .add_root_certificate(Certificate::from_pem(&ca_pem)?) + .identity( + Identity::from_pkcs12(&client_pks, &client_pks_password).context("Identity")?, + ) + .danger_accept_invalid_hostnames(true) + .danger_accept_invalid_certs(true) + .build()?; - tokio::spawn(async move { - if let Err(err) = connection.await { - log::error!("Connection to Postgres broke {err:?}"); - }; - }); + Self::spawn_connection(pg_config, MakeTlsConnector::new(connector)).await? + }; let update_tx_statement = client .prepare( @@ -118,6 +122,25 @@ impl PostgresSession { }) } + async fn spawn_connection( + pg_config: tokio_postgres::Config, + connector: T, + ) -> anyhow::Result + where + T: MakeTlsConnect + Send + 'static, + >::Stream: Send, + { + let (client, connection) = pg_config.connect(connector).await.context("a")?; + + tokio::spawn(async move { + if let Err(err) = connection.await { + log::error!("Connection to Postgres broke {err:?}"); + }; + }); + + Ok(client) + } + pub fn multiline_query(query: &mut String, args: usize, rows: usize) { let mut arg_index = 1usize; for _ in 0..rows {