Skip to content

Commit

Permalink
Add cache to AWS token provider
Browse files Browse the repository at this point in the history
  • Loading branch information
mwylde committed Jan 7, 2025
1 parent 30bdba9 commit 79ac889
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ datafusion-functions-window = {git = 'https://github.com/ArroyoSystems/arrow-dat

datafusion-functions-json = {git = 'https://github.com/ArroyoSystems/datafusion-functions-json', branch = 'datafusion_43'}

object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'object_store_0.11.1/arroyo' }
# object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'object_store_0.11.1/arroyo' }
object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'public_token_cache' }

cornucopia_async = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" }
cornucopia = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use bincode::config;
use prost::Message;
use std::fmt::Debug;
use std::{collections::HashMap, time::SystemTime};
use tracing::info;
use tracing::debug;

pub struct TwoPhaseCommitterOperator<TPC: TwoPhaseCommitter> {
committer: TPC,
Expand Down Expand Up @@ -86,7 +86,7 @@ impl<TPC: TwoPhaseCommitter> TwoPhaseCommitterOperator<TPC> {
mut commit_data: HashMap<String, HashMap<u32, Vec<u8>>>,
ctx: &mut OperatorContext,
) {
info!("received commit message");
debug!("received commit message");
let pre_commits = match self.committer.commit_strategy() {
CommitStrategy::PerSubtask => std::mem::take(&mut self.pre_commits),
CommitStrategy::PerOperator => {
Expand Down
49 changes: 37 additions & 12 deletions crates/arroyo-storage/src/aws.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::StorageError;
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
use object_store::{aws::AwsCredential, CredentialProvider};
use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
use object_store::{aws::AwsCredential, CredentialProvider, TemporaryToken, TokenCache};
use std::error::Error;
use std::sync::Arc;
use std::time::{Duration, Instant};

pub struct ArroyoCredentialProvider {
cache: TokenCache<Arc<AwsCredential>>,
provider: aws_credential_types::provider::SharedCredentialsProvider,
}

Expand All @@ -28,6 +31,7 @@ impl ArroyoCredentialProvider {
.clone();

Ok(Self {
cache: TokenCache::default().with_min_ttl(Duration::from_secs(60)),
provider: credentials,
})
}
Expand All @@ -41,21 +45,42 @@ impl ArroyoCredentialProvider {
}
}

async fn get_token(
provider: &SharedCredentialsProvider,
) -> Result<TemporaryToken<Arc<AwsCredential>>, Box<dyn Error + Send + Sync>> {
let creds = provider
.provide_credentials()
.await
.map_err(|e| object_store::Error::Generic {
store: "S3",
source: Box::new(e),
})?;

let expiry = creds
.expiry()
.map(|exp| Instant::now() + exp.elapsed().unwrap_or_default());

Ok(TemporaryToken {
token: Arc::new(AwsCredential {
key_id: creds.access_key_id().to_string(),
secret_key: creds.secret_access_key().to_string(),
token: creds.session_token().map(ToString::to_string),
}),
expiry,
})
}

#[async_trait::async_trait]
impl CredentialProvider for ArroyoCredentialProvider {
type Credential = AwsCredential;

async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
let creds = self.provider.provide_credentials().await.map_err(|e| {
object_store::Error::Generic {
self.cache
.get_or_insert_with(|| get_token(&self.provider))
.await
.map_err(|e| object_store::Error::Generic {
store: "S3",
source: Box::new(e),
}
})?;
Ok(Arc::new(AwsCredential {
key_id: creds.access_key_id().to_string(),
secret_key: creds.secret_access_key().to_string(),
token: creds.session_token().map(ToString::to_string),
}))
source: e,
})
}
}

0 comments on commit 79ac889

Please sign in to comment.