Simplify access token refreshing

This commit is contained in:
Michael Vines 2020-07-24 13:53:02 -07:00
parent e56ea138c7
commit 1f7af14386
2 changed files with 96 additions and 74 deletions

View File

@ -2,11 +2,16 @@
use goauth::{ use goauth::{
auth::{JwtClaims, Token}, auth::{JwtClaims, Token},
credentials::Credentials, credentials::Credentials,
get_token,
}; };
use log::*; use log::*;
use smpl_jwt::Jwt; use smpl_jwt::Jwt;
use std::time::Instant; use std::{
sync::{
atomic::{AtomicBool, Ordering},
{Arc, RwLock},
},
time::Instant,
};
pub use goauth::scopes::Scope; pub use goauth::scopes::Scope;
@ -23,36 +28,29 @@ fn load_credentials() -> Result<Credentials, String> {
}) })
} }
#[derive(Clone)]
pub struct AccessToken { pub struct AccessToken {
credentials: Credentials, credentials: Credentials,
jwt: Jwt<JwtClaims>, scope: Scope,
token: Option<(Token, Instant)>, refresh_active: Arc<AtomicBool>,
token: Arc<RwLock<(Token, Instant)>>,
} }
impl AccessToken { impl AccessToken {
pub fn new(scope: &Scope) -> Result<Self, String> { pub async fn new(scope: Scope) -> Result<Self, String> {
let credentials = load_credentials()?; let credentials = load_credentials()?;
if let Err(err) = credentials.rsa_key() {
let claims = JwtClaims::new( Err(format!("Invalid rsa key: {}", err))
credentials.iss(), } else {
&scope, let token = Arc::new(RwLock::new(Self::get_token(&credentials, &scope).await?));
credentials.token_uri(), let access_token = Self {
None,
None,
);
let jwt = Jwt::new(
claims,
credentials
.rsa_key()
.map_err(|err| format!("Invalid rsa key: {}", err))?,
None,
);
Ok(Self {
credentials, credentials,
jwt, scope,
token: None, token,
}) refresh_active: Arc::new(AtomicBool::new(false)),
};
Ok(access_token)
}
} }
/// The project that this token grants access to /// The project that this token grants access to
@ -60,32 +58,61 @@ impl AccessToken {
self.credentials.project() self.credentials.project()
} }
/// Call this function regularly, and before calling `access_token()` async fn get_token(
pub async fn refresh(&mut self) { credentials: &Credentials,
if let Some((token, last_refresh)) = self.token.as_ref() { scope: &Scope,
if last_refresh.elapsed().as_secs() < token.expires_in() as u64 / 2 { ) -> Result<(Token, Instant), String> {
info!("Requesting token for {:?} scope", scope);
let claims = JwtClaims::new(
credentials.iss(),
scope,
credentials.token_uri(),
None,
None,
);
let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
let token = goauth::get_token(&jwt, credentials)
.await
.map_err(|err| format!("Failed to refresh access token: {}", err))?;
info!("Token expires in {} seconds", token.expires_in());
Ok((token, Instant::now()))
}
/// Call this function regularly to ensure the access token does not expire
pub async fn refresh(&self) {
// Check if it's time to try a token refresh
{
let token_r = self.token.read().unwrap();
if token_r.1.elapsed().as_secs() < token_r.0.expires_in() as u64 / 2 {
return;
}
if self
.refresh_active
.compare_and_swap(false, true, Ordering::Relaxed)
{
// Refresh already pending
return; return;
} }
} }
info!("Refreshing token"); info!("Refreshing token");
match get_token(&self.jwt, &self.credentials).await { let new_token = Self::get_token(&self.credentials, &self.scope).await;
Ok(new_token) => { {
info!("Token expires in {} seconds", new_token.expires_in()); let mut token_w = self.token.write().unwrap();
self.token = Some((new_token, Instant::now())); match new_token {
} Ok(new_token) => *token_w = new_token,
Err(err) => { Err(err) => warn!("{}", err),
warn!("Failed to get new token: {}", err);
} }
self.refresh_active.store(false, Ordering::Relaxed);
} }
} }
/// Return an access token suitable for use in an HTTP authorization header /// Return an access token suitable for use in an HTTP authorization header
pub fn get(&self) -> Result<String, String> { pub fn get(&self) -> String {
if let Some((token, _)) = self.token.as_ref() { let token_r = self.token.read().unwrap();
Ok(format!("{} {}", token.token_type(), token.access_token())) format!("{} {}", token_r.0.token_type(), token_r.0.access_token())
} else {
Err("Access token not available".into())
}
} }
} }

View File

@ -4,7 +4,6 @@ use crate::access_token::{AccessToken, Scope};
use crate::compression::{compress_best, decompress}; use crate::compression::{compress_best, decompress};
use crate::root_ca_certificate; use crate::root_ca_certificate;
use log::*; use log::*;
use std::sync::{Arc, RwLock};
use thiserror::Error; use thiserror::Error;
use tonic::{metadata::MetadataValue, transport::ClientTlsConfig, Request}; use tonic::{metadata::MetadataValue, transport::ClientTlsConfig, Request};
@ -86,7 +85,7 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Clone)] #[derive(Clone)]
pub struct BigTableConnection { pub struct BigTableConnection {
access_token: Option<Arc<RwLock<AccessToken>>>, access_token: Option<AccessToken>,
channel: tonic::transport::Channel, channel: tonic::transport::Channel,
table_prefix: String, table_prefix: String,
} }
@ -115,14 +114,14 @@ impl BigTableConnection {
} }
Err(_) => { Err(_) => {
let mut access_token = AccessToken::new(if read_only { let access_token = AccessToken::new(if read_only {
&Scope::BigTableDataReadOnly Scope::BigTableDataReadOnly
} else { } else {
&Scope::BigTableData Scope::BigTableData
}) })
.await
.map_err(Error::AccessTokenError)?; .map_err(Error::AccessTokenError)?;
access_token.refresh().await;
let table_prefix = format!( let table_prefix = format!(
"projects/{}/instances/{}/tables/", "projects/{}/instances/{}/tables/",
access_token.project(), access_token.project(),
@ -130,7 +129,7 @@ impl BigTableConnection {
); );
Ok(Self { Ok(Self {
access_token: Some(Arc::new(RwLock::new(access_token))), access_token: Some(access_token),
channel: tonic::transport::Channel::from_static( channel: tonic::transport::Channel::from_static(
"https://bigtable.googleapis.com", "https://bigtable.googleapis.com",
) )
@ -153,28 +152,25 @@ impl BigTableConnection {
/// Clients require `&mut self`, due to `Tonic::transport::Channel` limitations, however /// Clients require `&mut self`, due to `Tonic::transport::Channel` limitations, however
/// creating new clients is cheap and thus can be used as a work around for ease of use. /// creating new clients is cheap and thus can be used as a work around for ease of use.
pub fn client(&self) -> BigTable { pub fn client(&self) -> BigTable {
let client = { let client = if let Some(access_token) = &self.access_token {
if let Some(ref access_token) = self.access_token {
let access_token = access_token.clone(); let access_token = access_token.clone();
bigtable_client::BigtableClient::with_interceptor( bigtable_client::BigtableClient::with_interceptor(
self.channel.clone(), self.channel.clone(),
move |mut req: Request<()>| { move |mut req: Request<()>| {
match access_token.read().unwrap().get() { match MetadataValue::from_str(&access_token.get()) {
Ok(access_token) => match MetadataValue::from_str(&access_token) {
Ok(authorization_header) => { Ok(authorization_header) => {
req.metadata_mut() req.metadata_mut()
.insert("authorization", authorization_header); .insert("authorization", authorization_header);
} }
Err(err) => warn!("Failed to set authorization header: {}", err), Err(err) => {
}, warn!("Failed to set authorization header: {}", err);
Err(err) => warn!("{}", err), }
} }
Ok(req) Ok(req)
}, },
) )
} else { } else {
bigtable_client::BigtableClient::new(self.channel.clone()) bigtable_client::BigtableClient::new(self.channel.clone())
}
}; };
BigTable { BigTable {
access_token: self.access_token.clone(), access_token: self.access_token.clone(),
@ -202,7 +198,7 @@ impl BigTableConnection {
} }
pub struct BigTable { pub struct BigTable {
access_token: Option<Arc<RwLock<AccessToken>>>, access_token: Option<AccessToken>,
client: bigtable_client::BigtableClient<tonic::transport::Channel>, client: bigtable_client::BigtableClient<tonic::transport::Channel>,
table_prefix: String, table_prefix: String,
} }
@ -283,7 +279,7 @@ impl BigTable {
async fn refresh_access_token(&self) { async fn refresh_access_token(&self) {
if let Some(ref access_token) = self.access_token { if let Some(ref access_token) = self.access_token {
access_token.write().unwrap().refresh().await; access_token.refresh().await;
} }
} }
@ -298,7 +294,6 @@ impl BigTable {
rows_limit: i64, rows_limit: i64,
) -> Result<Vec<RowKey>> { ) -> Result<Vec<RowKey>> {
self.refresh_access_token().await; self.refresh_access_token().await;
let response = self let response = self
.client .client
.read_rows(ReadRowsRequest { .read_rows(ReadRowsRequest {