Skip to content

Commit

Permalink
Custom AWS credentials provider should fall back to default credentia…
Browse files Browse the repository at this point in the history
…ls chain (close #42)
  • Loading branch information
istreeter committed Dec 17, 2023
1 parent fef63bd commit 4d1b431
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 27 deletions.
5 changes: 4 additions & 1 deletion config/config.aws.reference.hocon
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"good": {
# -- URI of the bucket where the data lake will be written (required)
# -- For an S3 bucket, the uri should start with `s3a://`
"location": "gs://my-bucket/events
"location": "s3a://my-bucket/events
# -- Atomic columns which should be brought to the "left-hand-side" of the events table, to
# -- enable Delta's Data Skipping feature.
Expand Down Expand Up @@ -92,6 +92,9 @@
"conf": {
# -- E.g. to enable the spark ui for debugging:
"spark.ui.enabled": true

# -- E.g. to change credentials provider
"fs.s3a.aws.credentials.provider": "com.amazonaws.auth.InstanceProfileCredentialsProvider"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,34 @@ import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest
import software.amazon.awssdk.services.sts.StsClient
import software.amazon.awssdk.auth.credentials.{AwsCredentials, AwsCredentialsProvider}
import software.amazon.awssdk.auth.credentials.{AwsCredentials, AwsCredentialsProvider, DefaultCredentialsProvider}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.s3a.{Constants => S3aConstants}

import java.net.URI
import java.util.concurrent.TimeUnit

/**
* A credentials provider that uses STS to assume a role
* A credentials provider that can use STS to assume a role
*
* Similar to hadoop's in-built `AssumedRoleCredentialsProvider` but with support for an external id
* Similar to hadoop's in-built `AssumedRoleCredentialsProvider` but with support for an external
* id.
*
* To enable STS, the spark configuration must contain four parameters:
* - "fs.s3a.assumed.role.session.name": The AWS IAM session name. A default is provided in
* application.conf.
* - "fs.s3a.assumed.role.session.duration": The AWS IAM session duration. A default is provided
* in application.conf.
* - "fs.s3a.assumed.role.arn": ARN of the AWS role to assume.
* - "fs.s3a.assumed.role.session.external.id": External ID to provide when assuming the role.
*
* If any required parameter is missing, we fall back to using the default AWS credentials chain,
* e.g. environment variables, instance profile, or any other source supported by that chain.
*
* @param delegate
* The configured credentials provider to which we delegate requests for credentials
*/
class AssumedRoleCredentialsProvider(delegate: StsAssumeRoleCredentialsProvider) extends AwsCredentialsProvider {
class AssumedRoleCredentialsProvider(delegate: AwsCredentialsProvider) extends AwsCredentialsProvider {

/**
* Standard constructor invoked by hadoop
Expand All @@ -37,23 +49,35 @@ class AssumedRoleCredentialsProvider(delegate: StsAssumeRoleCredentialsProvider)
* The hadoop configuration, provided via spark configuration
*/
def this(fsUri: URI, conf: Configuration) =
this(
StsAssumeRoleCredentialsProvider.builder
.stsClient {
StsClient.builder.defaultsMode(DefaultsMode.AUTO).build
}
.refreshRequest { (req: AssumeRoleRequest.Builder) =>
req
.roleArn(conf.getTrimmed(S3aConstants.ASSUMED_ROLE_ARN))
.roleSessionName(conf.getTrimmed(S3aConstants.ASSUMED_ROLE_SESSION_NAME))
.durationSeconds(conf.getTimeDuration(S3aConstants.ASSUMED_ROLE_SESSION_DURATION, 0L, TimeUnit.SECONDS).toInt)
.externalId(conf.getTrimmed("fs.s3a.assumed.role.session.external.id"))
()
}
.build
)
this(AssumedRoleCredentialsProvider.getDelegate(conf))

/** Resolves credentials by forwarding the call to the `delegate` provider given at construction. */
override def resolveCredentials(): AwsCredentials =
delegate.resolveCredentials()

}

object AssumedRoleCredentialsProvider {

  /**
   * Chooses the credentials provider that the class delegates to.
   *
   * An STS assume-role provider is returned only when the role ARN, session name, session duration
   * and external id are all set to non-blank values in `conf`. Otherwise the default AWS
   * credentials chain is returned.
   *
   * @param conf
   *   The hadoop configuration, provided via spark configuration
   * @return
   *   The STS-backed provider when fully configured, else the default credentials provider
   */
  private def getDelegate(conf: Configuration): AwsCredentialsProvider = {
    // `getTrimmed` gives null for an unset key; a blank value trims to "". Both count as missing.
    def setting(key: String): Option[String] =
      Option(conf.getTrimmed(key)).filter(_.nonEmpty)

    val stsOpt = for {
      roleArn <- setting(S3aConstants.ASSUMED_ROLE_ARN)
      roleSessionName <- setting(S3aConstants.ASSUMED_ROLE_SESSION_NAME)
      // Check that the key is present before parsing it: wrapping the parsed Int in `Option` can
      // never produce `None`, so an unset duration would be sent to STS as 0 seconds instead of
      // triggering the fallback.
      durationSeconds <- setting(S3aConstants.ASSUMED_ROLE_SESSION_DURATION).map { _ =>
                           conf.getTimeDuration(S3aConstants.ASSUMED_ROLE_SESSION_DURATION, 0L, TimeUnit.SECONDS).toInt
                         }
      externalId <- setting("fs.s3a.assumed.role.session.external.id")
    } yield StsAssumeRoleCredentialsProvider.builder
      .stsClient {
        StsClient.builder.defaultsMode(DefaultsMode.AUTO).build
      }
      .refreshRequest { (req: AssumeRoleRequest.Builder) =>
        req
          .roleArn(roleArn)
          .roleSessionName(roleSessionName)
          .durationSeconds(durationSeconds)
          .externalId(externalId)
        // Discard the builder so the lambda returns Unit.
        ()
      }
      .build

    stsOpt.getOrElse(DefaultCredentialsProvider.create)
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,20 @@ class AssumedRoleCredentialsProviderV1(delegate: AssumedRoleCredentialsProvider)
/**
* Secondary constructor: builds an `AssumedRoleCredentialsProvider` from the same `fsUri` and
* `conf` and uses it as the delegate.
*/
def this(fsUri: URI, conf: Configuration) =
this(new AssumedRoleCredentialsProvider(fsUri, conf))

override def getCredentials(): AWSCredentials = {
val v2 = delegate.resolveCredentials().asInstanceOf[AwsSessionCredentialsV2]
new AWSSessionCredentials {
def getSessionToken() = v2.sessionToken()
def getAWSAccessKeyId() = v2.accessKeyId()
def getAWSSecretKey() = v2.secretAccessKey()
/**
 * Resolves credentials through the delegate and wraps them as `AWSCredentials`.
 *
 * When the resolved credentials are `AwsSessionCredentialsV2`, the result is an
 * `AWSSessionCredentials` that also exposes the session token. Any other credentials are exposed
 * with access key id and secret key only.
 */
override def getCredentials(): AWSCredentials = {
  val resolved = delegate.resolveCredentials()
  resolved match {
    case session: AwsSessionCredentialsV2 =>
      new AWSSessionCredentials {
        override def getAWSAccessKeyId(): String = session.accessKeyId()
        override def getAWSSecretKey(): String   = session.secretAccessKey()
        override def getSessionToken(): String   = session.sessionToken()
      }
    case _ =>
      // No session token available: expose only the key pair.
      new AWSCredentials {
        override def getAWSAccessKeyId(): String = resolved.accessKeyId()
        override def getAWSSecretKey(): String   = resolved.secretAccessKey()
      }
  }
}
}

/** Does nothing: the body is the unit value, so no refresh work is performed here. */
override def refresh(): Unit = ()

Expand Down

0 comments on commit 4d1b431

Please sign in to comment.