From 0c2afb84b3a1a228b008d23e1186073a572d6c21 Mon Sep 17 00:00:00 2001 From: Simon Fell Date: Mon, 8 Jan 2024 09:24:57 -0800 Subject: [PATCH] perform background refresh of tokens --- src/authentication_manager.rs | 114 +++++++++++++++++++++++++++++---- src/custom_service_account.rs | 6 +- src/default_authorized_user.rs | 6 +- src/default_service_account.rs | 6 +- src/gcloud_authorized_user.rs | 6 +- 5 files changed, 120 insertions(+), 18 deletions(-) diff --git a/src/authentication_manager.rs b/src/authentication_manager.rs index 108ff98..3071498 100644 --- a/src/authentication_manager.rs +++ b/src/authentication_manager.rs @@ -1,5 +1,11 @@ +use std::collections::hash_map::Entry::{Occupied, Vacant}; +use std::collections::HashMap; +use std::sync::Arc; + use async_trait::async_trait; -use tokio::sync::Mutex; +use chrono::{Duration, Utc}; +use tokio::sync::{Mutex, OwnedMutexGuard}; +use tracing::{debug, info, warn}; use crate::custom_service_account::CustomServiceAccount; use crate::default_authorized_user::ConfigDefaultCredentials; @@ -13,6 +19,12 @@ pub(crate) trait ServiceAccount: Send + Sync { async fn project_id(&self, client: &HyperClient) -> Result; fn get_token(&self, scopes: &[&str]) -> Option; async fn refresh_token(&self, client: &HyperClient, scopes: &[&str]) -> Result; + fn get_style(&self) -> TokenStyle; +} + +pub(crate) enum TokenStyle { + Account, + AccountAndScopes, } /// Authentication manager is responsible for caching and obtaining credentials for the required @@ -21,10 +33,13 @@ pub(crate) trait ServiceAccount: Send + Sync { /// Construct the authentication manager with [`AuthenticationManager::new()`] or by creating /// a [`CustomServiceAccount`], then converting it into an `AuthenticationManager` using the `From` /// impl. -pub struct AuthenticationManager { - pub(crate) client: HyperClient, - pub(crate) service_account: Box, - refresh_mutex: Mutex<()>, +#[derive(Clone)] +pub struct AuthenticationManager(Arc); + +struct AuthManagerInner { + client: HyperClient, + service_account: Box, + refresh_lock: RefreshLock, } impl AuthenticationManager { @@ -80,40 +95,79 @@ impl AuthenticationManager { } fn build(client: HyperClient, service_account: impl ServiceAccount + 'static) -> Self { - Self { + let refresh_lock = RefreshLock::new(service_account.get_style()); + Self(Arc::new(AuthManagerInner { client, service_account: Box::new(service_account), - refresh_mutex: Mutex::new(()), - } + refresh_lock, + })) } /// Requests Bearer token for the provided scope /// /// Token can be used in the request authorization header in format "Bearer {token}" pub async fn get_token(&self, scopes: &[&str]) -> Result { - let token = self.service_account.get_token(scopes); + let token = self.0.service_account.get_token(scopes); + if let Some(token) = token.filter(|token| !token.has_expired()) { + let valid_for = token.expires_at().signed_duration_since(Utc::now()); + if valid_for < Duration::seconds(60) { + debug!(?valid_for, "gcp_auth token expires soon!"); + + let lock = self.0.refresh_lock.lock_for_scopes(scopes).await; + match lock.try_lock_owned() { + Err(_) => { + // already being refreshed. + } + Ok(guard) => { + let inner = self.clone(); + let scopes: Vec = scopes.iter().map(|s| s.to_string()).collect(); + tokio::spawn(async move { + inner.background_refresh(scopes, guard).await; + }); + } + } + } return Ok(token); } - let _guard = self.refresh_mutex.lock().await; + warn!("starting inline refresh of gcp auth token"); + let lock = self.0.refresh_lock.lock_for_scopes(scopes).await; + let _guard = lock.lock().await; // Check if refresh happened while we were waiting. - let token = self.service_account.get_token(scopes); + let token = self.0.service_account.get_token(scopes); if let Some(token) = token.filter(|token| !token.has_expired()) { return Ok(token); } - self.service_account - .refresh_token(&self.client, scopes) + self.0 + .service_account + .refresh_token(&self.0.client, scopes) .await } + async fn background_refresh(&self, scopes: Vec, _lock: OwnedMutexGuard<()>) { + info!("gcp_auth starting background refresh of auth token"); + let scope_refs: Vec<&str> = scopes.iter().map(|s| s.as_str()).collect(); + match self + .0 + .service_account + .refresh_token(&self.0.client, &scope_refs) + .await + { + Ok(t) => { + info!(valid_for=?t.expires_at().signed_duration_since(Utc::now()), "gcp auth completed background token refresh") + } + Err(err) => warn!(?err, "gcp_auth background token refresh failed"), + } + } + /// Request the project ID for the authenticating account /// /// This is only available for service account-based authentication methods. pub async fn project_id(&self) -> Result { - self.service_account.project_id(&self.client).await + self.0.service_account.project_id(&self.0.client).await } } @@ -122,3 +176,35 @@ impl From for AuthenticationManager { Self::build(types::client(), service_account) } } + +enum RefreshLock { + One(Arc>), + ByScopes(Mutex, Arc>>>), +} + +impl RefreshLock { + fn new(style: TokenStyle) -> Self { + match style { + TokenStyle::Account => RefreshLock::One(Arc::new(Mutex::new(()))), + TokenStyle::AccountAndScopes => RefreshLock::ByScopes(Mutex::new(HashMap::new())), + } + } + + async fn lock_for_scopes(&self, scopes: &[&str]) -> Arc> { + match self { + RefreshLock::One(mutex) => mutex.clone(), + RefreshLock::ByScopes(mutexes) => { + let scopes_key: Vec<_> = scopes.iter().map(|s| s.to_string()).collect(); + let mut scope_locks = mutexes.lock().await; + match scope_locks.entry(scopes_key) { + Occupied(e) => e.get().clone(), + Vacant(v) => { + let lock = Arc::new(Mutex::new(())); + v.insert(lock.clone()); + lock + } + } + } + } + } +} diff --git a/src/custom_service_account.rs b/src/custom_service_account.rs index 43edf44..9dbf1b9 100644 --- a/src/custom_service_account.rs +++ b/src/custom_service_account.rs @@ -6,7 +6,7 @@ use std::sync::RwLock; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use crate::authentication_manager::ServiceAccount; +use crate::authentication_manager::{ServiceAccount, TokenStyle}; use crate::error::Error; use crate::types::{HyperClient, Signer, Token}; use crate::util::HyperExt; @@ -80,6 +80,10 @@ impl CustomServiceAccount { #[async_trait] impl ServiceAccount for CustomServiceAccount { + fn get_style(&self) -> TokenStyle { + TokenStyle::AccountAndScopes + } + async fn project_id(&self, _: &HyperClient) -> Result { match &self.credentials.project_id { Some(pid) => Ok(pid.clone()), diff --git a/src/default_authorized_user.rs b/src/default_authorized_user.rs index 81e050c..984d9a2 100644 --- a/src/default_authorized_user.rs +++ b/src/default_authorized_user.rs @@ -6,7 +6,7 @@ use hyper::body::Body; use hyper::{Method, Request}; use serde::{Deserialize, Serialize}; -use crate::authentication_manager::ServiceAccount; +use crate::authentication_manager::{ServiceAccount, TokenStyle}; use crate::error::Error; use crate::types::{HyperClient, Token}; use crate::util::HyperExt; @@ -78,6 +78,10 @@ impl ConfigDefaultCredentials { #[async_trait] impl ServiceAccount for ConfigDefaultCredentials { + fn get_style(&self) -> TokenStyle { + TokenStyle::Account + } + async fn project_id(&self, _: &HyperClient) -> Result { self.credentials .quota_project_id diff --git a/src/default_service_account.rs b/src/default_service_account.rs index c1cf32a..33beac3 100644 --- a/src/default_service_account.rs +++ b/src/default_service_account.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use hyper::body::Body; use hyper::{Method, Request}; -use crate::authentication_manager::ServiceAccount; +use crate::authentication_manager::{ServiceAccount, TokenStyle}; use crate::error::Error; use crate::types::{HyperClient, Token}; use crate::util::HyperExt; @@ -62,6 +62,10 @@ impl MetadataServiceAccount { #[async_trait] impl ServiceAccount for MetadataServiceAccount { + fn get_style(&self) -> TokenStyle { + TokenStyle::Account + } + async fn project_id(&self, client: &HyperClient) -> Result { tracing::debug!("Getting project ID from GCP instance metadata server"); let req = Self::build_token_request(Self::DEFAULT_PROJECT_ID_GCP_URI); diff --git a/src/gcloud_authorized_user.rs b/src/gcloud_authorized_user.rs index e2c58cb..e84f670 100644 --- a/src/gcloud_authorized_user.rs +++ b/src/gcloud_authorized_user.rs @@ -6,7 +6,7 @@ use std::time::Duration; use async_trait::async_trait; use which::which; -use crate::authentication_manager::ServiceAccount; +use crate::authentication_manager::{ServiceAccount, TokenStyle}; use crate::error::Error; use crate::error::Error::{GCloudError, GCloudNotFound, GCloudParseError}; use crate::types::HyperClient; @@ -46,6 +46,10 @@ impl GCloudAuthorizedUser { #[async_trait] impl ServiceAccount for GCloudAuthorizedUser { + fn get_style(&self) -> TokenStyle { + TokenStyle::Account + } + async fn project_id(&self, _: &HyperClient) -> Result { self.project_id.clone().ok_or(Error::NoProjectId) }