Simplify access token refreshing
This commit is contained in:
parent
e56ea138c7
commit
1f7af14386
|
@ -2,11 +2,16 @@
|
||||||
use goauth::{
|
use goauth::{
|
||||||
auth::{JwtClaims, Token},
|
auth::{JwtClaims, Token},
|
||||||
credentials::Credentials,
|
credentials::Credentials,
|
||||||
get_token,
|
|
||||||
};
|
};
|
||||||
use log::*;
|
use log::*;
|
||||||
use smpl_jwt::Jwt;
|
use smpl_jwt::Jwt;
|
||||||
use std::time::Instant;
|
use std::{
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
{Arc, RwLock},
|
||||||
|
},
|
||||||
|
time::Instant,
|
||||||
|
};
|
||||||
|
|
||||||
pub use goauth::scopes::Scope;
|
pub use goauth::scopes::Scope;
|
||||||
|
|
||||||
|
@ -23,36 +28,29 @@ fn load_credentials() -> Result<Credentials, String> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct AccessToken {
|
pub struct AccessToken {
|
||||||
credentials: Credentials,
|
credentials: Credentials,
|
||||||
jwt: Jwt<JwtClaims>,
|
scope: Scope,
|
||||||
token: Option<(Token, Instant)>,
|
refresh_active: Arc<AtomicBool>,
|
||||||
|
token: Arc<RwLock<(Token, Instant)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AccessToken {
|
impl AccessToken {
|
||||||
pub fn new(scope: &Scope) -> Result<Self, String> {
|
pub async fn new(scope: Scope) -> Result<Self, String> {
|
||||||
let credentials = load_credentials()?;
|
let credentials = load_credentials()?;
|
||||||
|
if let Err(err) = credentials.rsa_key() {
|
||||||
let claims = JwtClaims::new(
|
Err(format!("Invalid rsa key: {}", err))
|
||||||
credentials.iss(),
|
} else {
|
||||||
&scope,
|
let token = Arc::new(RwLock::new(Self::get_token(&credentials, &scope).await?));
|
||||||
credentials.token_uri(),
|
let access_token = Self {
|
||||||
None,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
let jwt = Jwt::new(
|
|
||||||
claims,
|
|
||||||
credentials
|
|
||||||
.rsa_key()
|
|
||||||
.map_err(|err| format!("Invalid rsa key: {}", err))?,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
credentials,
|
credentials,
|
||||||
jwt,
|
scope,
|
||||||
token: None,
|
token,
|
||||||
})
|
refresh_active: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
Ok(access_token)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The project that this token grants access to
|
/// The project that this token grants access to
|
||||||
|
@ -60,32 +58,61 @@ impl AccessToken {
|
||||||
self.credentials.project()
|
self.credentials.project()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Call this function regularly, and before calling `access_token()`
|
async fn get_token(
|
||||||
pub async fn refresh(&mut self) {
|
credentials: &Credentials,
|
||||||
if let Some((token, last_refresh)) = self.token.as_ref() {
|
scope: &Scope,
|
||||||
if last_refresh.elapsed().as_secs() < token.expires_in() as u64 / 2 {
|
) -> Result<(Token, Instant), String> {
|
||||||
|
info!("Requesting token for {:?} scope", scope);
|
||||||
|
let claims = JwtClaims::new(
|
||||||
|
credentials.iss(),
|
||||||
|
scope,
|
||||||
|
credentials.token_uri(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
|
||||||
|
|
||||||
|
let token = goauth::get_token(&jwt, credentials)
|
||||||
|
.await
|
||||||
|
.map_err(|err| format!("Failed to refresh access token: {}", err))?;
|
||||||
|
|
||||||
|
info!("Token expires in {} seconds", token.expires_in());
|
||||||
|
Ok((token, Instant::now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call this function regularly to ensure the access token does not expire
|
||||||
|
pub async fn refresh(&self) {
|
||||||
|
// Check if it's time to try a token refresh
|
||||||
|
{
|
||||||
|
let token_r = self.token.read().unwrap();
|
||||||
|
if token_r.1.elapsed().as_secs() < token_r.0.expires_in() as u64 / 2 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self
|
||||||
|
.refresh_active
|
||||||
|
.compare_and_swap(false, true, Ordering::Relaxed)
|
||||||
|
{
|
||||||
|
// Refresh already pending
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Refreshing token");
|
info!("Refreshing token");
|
||||||
match get_token(&self.jwt, &self.credentials).await {
|
let new_token = Self::get_token(&self.credentials, &self.scope).await;
|
||||||
Ok(new_token) => {
|
{
|
||||||
info!("Token expires in {} seconds", new_token.expires_in());
|
let mut token_w = self.token.write().unwrap();
|
||||||
self.token = Some((new_token, Instant::now()));
|
match new_token {
|
||||||
}
|
Ok(new_token) => *token_w = new_token,
|
||||||
Err(err) => {
|
Err(err) => warn!("{}", err),
|
||||||
warn!("Failed to get new token: {}", err);
|
|
||||||
}
|
}
|
||||||
|
self.refresh_active.store(false, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return an access token suitable for use in an HTTP authorization header
|
/// Return an access token suitable for use in an HTTP authorization header
|
||||||
pub fn get(&self) -> Result<String, String> {
|
pub fn get(&self) -> String {
|
||||||
if let Some((token, _)) = self.token.as_ref() {
|
let token_r = self.token.read().unwrap();
|
||||||
Ok(format!("{} {}", token.token_type(), token.access_token()))
|
format!("{} {}", token_r.0.token_type(), token_r.0.access_token())
|
||||||
} else {
|
|
||||||
Err("Access token not available".into())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ use crate::access_token::{AccessToken, Scope};
|
||||||
use crate::compression::{compress_best, decompress};
|
use crate::compression::{compress_best, decompress};
|
||||||
use crate::root_ca_certificate;
|
use crate::root_ca_certificate;
|
||||||
use log::*;
|
use log::*;
|
||||||
use std::sync::{Arc, RwLock};
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tonic::{metadata::MetadataValue, transport::ClientTlsConfig, Request};
|
use tonic::{metadata::MetadataValue, transport::ClientTlsConfig, Request};
|
||||||
|
|
||||||
|
@ -86,7 +85,7 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct BigTableConnection {
|
pub struct BigTableConnection {
|
||||||
access_token: Option<Arc<RwLock<AccessToken>>>,
|
access_token: Option<AccessToken>,
|
||||||
channel: tonic::transport::Channel,
|
channel: tonic::transport::Channel,
|
||||||
table_prefix: String,
|
table_prefix: String,
|
||||||
}
|
}
|
||||||
|
@ -115,14 +114,14 @@ impl BigTableConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
let mut access_token = AccessToken::new(if read_only {
|
let access_token = AccessToken::new(if read_only {
|
||||||
&Scope::BigTableDataReadOnly
|
Scope::BigTableDataReadOnly
|
||||||
} else {
|
} else {
|
||||||
&Scope::BigTableData
|
Scope::BigTableData
|
||||||
})
|
})
|
||||||
|
.await
|
||||||
.map_err(Error::AccessTokenError)?;
|
.map_err(Error::AccessTokenError)?;
|
||||||
|
|
||||||
access_token.refresh().await;
|
|
||||||
let table_prefix = format!(
|
let table_prefix = format!(
|
||||||
"projects/{}/instances/{}/tables/",
|
"projects/{}/instances/{}/tables/",
|
||||||
access_token.project(),
|
access_token.project(),
|
||||||
|
@ -130,7 +129,7 @@ impl BigTableConnection {
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
access_token: Some(Arc::new(RwLock::new(access_token))),
|
access_token: Some(access_token),
|
||||||
channel: tonic::transport::Channel::from_static(
|
channel: tonic::transport::Channel::from_static(
|
||||||
"https://bigtable.googleapis.com",
|
"https://bigtable.googleapis.com",
|
||||||
)
|
)
|
||||||
|
@ -153,28 +152,25 @@ impl BigTableConnection {
|
||||||
/// Clients require `&mut self`, due to `Tonic::transport::Channel` limitations, however
|
/// Clients require `&mut self`, due to `Tonic::transport::Channel` limitations, however
|
||||||
/// creating new clients is cheap and thus can be used as a work around for ease of use.
|
/// creating new clients is cheap and thus can be used as a work around for ease of use.
|
||||||
pub fn client(&self) -> BigTable {
|
pub fn client(&self) -> BigTable {
|
||||||
let client = {
|
let client = if let Some(access_token) = &self.access_token {
|
||||||
if let Some(ref access_token) = self.access_token {
|
|
||||||
let access_token = access_token.clone();
|
let access_token = access_token.clone();
|
||||||
bigtable_client::BigtableClient::with_interceptor(
|
bigtable_client::BigtableClient::with_interceptor(
|
||||||
self.channel.clone(),
|
self.channel.clone(),
|
||||||
move |mut req: Request<()>| {
|
move |mut req: Request<()>| {
|
||||||
match access_token.read().unwrap().get() {
|
match MetadataValue::from_str(&access_token.get()) {
|
||||||
Ok(access_token) => match MetadataValue::from_str(&access_token) {
|
|
||||||
Ok(authorization_header) => {
|
Ok(authorization_header) => {
|
||||||
req.metadata_mut()
|
req.metadata_mut()
|
||||||
.insert("authorization", authorization_header);
|
.insert("authorization", authorization_header);
|
||||||
}
|
}
|
||||||
Err(err) => warn!("Failed to set authorization header: {}", err),
|
Err(err) => {
|
||||||
},
|
warn!("Failed to set authorization header: {}", err);
|
||||||
Err(err) => warn!("{}", err),
|
}
|
||||||
}
|
}
|
||||||
Ok(req)
|
Ok(req)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
bigtable_client::BigtableClient::new(self.channel.clone())
|
bigtable_client::BigtableClient::new(self.channel.clone())
|
||||||
}
|
|
||||||
};
|
};
|
||||||
BigTable {
|
BigTable {
|
||||||
access_token: self.access_token.clone(),
|
access_token: self.access_token.clone(),
|
||||||
|
@ -202,7 +198,7 @@ impl BigTableConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct BigTable {
|
pub struct BigTable {
|
||||||
access_token: Option<Arc<RwLock<AccessToken>>>,
|
access_token: Option<AccessToken>,
|
||||||
client: bigtable_client::BigtableClient<tonic::transport::Channel>,
|
client: bigtable_client::BigtableClient<tonic::transport::Channel>,
|
||||||
table_prefix: String,
|
table_prefix: String,
|
||||||
}
|
}
|
||||||
|
@ -283,7 +279,7 @@ impl BigTable {
|
||||||
|
|
||||||
async fn refresh_access_token(&self) {
|
async fn refresh_access_token(&self) {
|
||||||
if let Some(ref access_token) = self.access_token {
|
if let Some(ref access_token) = self.access_token {
|
||||||
access_token.write().unwrap().refresh().await;
|
access_token.refresh().await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,7 +294,6 @@ impl BigTable {
|
||||||
rows_limit: i64,
|
rows_limit: i64,
|
||||||
) -> Result<Vec<RowKey>> {
|
) -> Result<Vec<RowKey>> {
|
||||||
self.refresh_access_token().await;
|
self.refresh_access_token().await;
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.read_rows(ReadRowsRequest {
|
.read_rows(ReadRowsRequest {
|
||||||
|
|
Loading…
Reference in New Issue