From 4d1b431884afc7729f2d251353714c08b4ba459f Mon Sep 17 00:00:00 2001 From: Ian Streeter Date: Fri, 15 Dec 2023 22:55:28 +0000 Subject: [PATCH] Custom AWS credentials provider should fall back to default credentials chain (close #42) --- config/config.aws.reference.hocon | 5 +- .../AssumedRoleCredentialsProvider.scala | 62 +++++++++++++------ .../AssumedRoleCredentialsProviderV1.scala | 20 +++--- 3 files changed, 60 insertions(+), 27 deletions(-) diff --git a/config/config.aws.reference.hocon b/config/config.aws.reference.hocon index f2e36ba6..91ff928b 100644 --- a/config/config.aws.reference.hocon +++ b/config/config.aws.reference.hocon @@ -37,7 +37,7 @@ "good": { # -- URI of the bucket where the data lake will be written (required) # -- For a GCS bucket, the uri should start with `gs://` - "location": "gs://my-bucket/events + "location": "s3a://my-bucket/events" # -- Atomic columns which should be brought to the "left-hand-side" of the events table, to # -- enable Delta's Data Skipping feature. @@ -92,6 +92,9 @@ "conf": { # -- E.g. to enable the spark ui for debugging: "spark.ui.enabled": true + + # -- E.g. 
to change credentials provider + "fs.s3a.aws.credentials.provider": "com.amazonaws.auth.InstanceProfileCredentialsProvider" } } diff --git a/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProvider.scala b/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProvider.scala index a965c20e..64abe285 100644 --- a/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProvider.scala +++ b/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProvider.scala @@ -11,7 +11,7 @@ import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider import software.amazon.awssdk.services.sts.model.AssumeRoleRequest import software.amazon.awssdk.services.sts.StsClient -import software.amazon.awssdk.auth.credentials.{AwsCredentials, AwsCredentialsProvider} +import software.amazon.awssdk.auth.credentials.{AwsCredentials, AwsCredentialsProvider, DefaultCredentialsProvider} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.s3a.{Constants => S3aConstants} @@ -19,14 +19,26 @@ import java.net.URI import java.util.concurrent.TimeUnit /** - * A credentials provider that uses STS to assume a role + * A credentials provider that can use STS to assume a role * - * Similar to hadoop's in-built `AssumedRoleCredentialsProvider` but with support for an external id + * Similar to hadoop's in-built `AssumedRoleCredentialsProvider` but with support for an external + * id. + * + * To enable STS, the spark configuration must contain four parameters: + * - "fs.s3a.assumed.role.session.name": The AWS IAM session name. A default is provided in + * application.conf. + * - "fs.s3a.assumed.role.session.duration": The AWS IAM session duration. A default is provided + * in application.conf. + * - "fs.s3a.assumed.role.arn": ARN of the AWS role to assume. 
+ * - "fs.s3a.assumed.role.session.external.id": External ID to provide when assuming the role. + * + * If the role ARN, session name or external ID is missing, we fall back to using the default AWS + * credentials chain, e.g. environment variables or an instance profile. A missing session duration does not trigger the fallback; it is read with a default of 0. * * @param delegate * The configured credentials provider to which we delegate requests for credentials */ -class AssumedRoleCredentialsProvider(delegate: StsAssumeRoleCredentialsProvider) extends AwsCredentialsProvider { +class AssumedRoleCredentialsProvider(delegate: AwsCredentialsProvider) extends AwsCredentialsProvider { /** * Standard constructor invoked by hadoop @@ -37,23 +49,35 @@ class AssumedRoleCredentialsProvider(delegate: StsAssumeRoleCredentialsProvider) * The hadoop configuration, provided via spark configuration */ def this(fsUri: URI, conf: Configuration) = - this( - StsAssumeRoleCredentialsProvider.builder - .stsClient { - StsClient.builder.defaultsMode(DefaultsMode.AUTO).build - } - .refreshRequest { (req: AssumeRoleRequest.Builder) => - req - .roleArn(conf.getTrimmed(S3aConstants.ASSUMED_ROLE_ARN)) - .roleSessionName(conf.getTrimmed(S3aConstants.ASSUMED_ROLE_SESSION_NAME)) - .durationSeconds(conf.getTimeDuration(S3aConstants.ASSUMED_ROLE_SESSION_DURATION, 0L, TimeUnit.SECONDS).toInt) - .externalId(conf.getTrimmed("fs.s3a.assumed.role.session.external.id")) - () - } - .build - ) + this(AssumedRoleCredentialsProvider.getDelegate(conf)) override def resolveCredentials(): AwsCredentials = delegate.resolveCredentials() } + +object AssumedRoleCredentialsProvider { + + private def getDelegate(conf: Configuration): AwsCredentialsProvider = { + val stsOpt = for { + roleArn <- Option(conf.getTrimmed(S3aConstants.ASSUMED_ROLE_ARN)) + roleSessionName <- Option(conf.getTrimmed(S3aConstants.ASSUMED_ROLE_SESSION_NAME)) + durationSeconds <- Option(conf.getTimeDuration(S3aConstants.ASSUMED_ROLE_SESSION_DURATION, 0L, TimeUnit.SECONDS).toInt) + externalId <- 
Option(conf.getTrimmed("fs.s3a.assumed.role.session.external.id")) + } yield StsAssumeRoleCredentialsProvider.builder + .stsClient { + StsClient.builder.defaultsMode(DefaultsMode.AUTO).build + } + .refreshRequest { (req: AssumeRoleRequest.Builder) => + req + .roleArn(roleArn) + .roleSessionName(roleSessionName) + .durationSeconds(durationSeconds) + .externalId(externalId) + () + } + .build + stsOpt.getOrElse(DefaultCredentialsProvider.create) + } + +} diff --git a/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProviderV1.scala b/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProviderV1.scala index cbb6982f..bb9496f2 100644 --- a/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProviderV1.scala +++ b/modules/aws/src/main/scala/com.snowplowanalytics.snowplow.lakes/AssumedRoleCredentialsProviderV1.scala @@ -24,14 +24,20 @@ class AssumedRoleCredentialsProviderV1(delegate: AssumedRoleCredentialsProvider) def this(fsUri: URI, conf: Configuration) = this(new AssumedRoleCredentialsProvider(fsUri, conf)) - override def getCredentials(): AWSCredentials = { - val v2 = delegate.resolveCredentials().asInstanceOf[AwsSessionCredentialsV2] - new AWSSessionCredentials { - def getSessionToken() = v2.sessionToken() - def getAWSAccessKeyId() = v2.accessKeyId() - def getAWSSecretKey() = v2.secretAccessKey() + override def getCredentials(): AWSCredentials = + delegate.resolveCredentials() match { + case v2: AwsSessionCredentialsV2 => + new AWSSessionCredentials { + def getAWSAccessKeyId() = v2.accessKeyId() + def getAWSSecretKey() = v2.secretAccessKey() + def getSessionToken() = v2.sessionToken() + } + case v2 => + new AWSCredentials { + def getAWSAccessKeyId() = v2.accessKeyId() + def getAWSSecretKey() = v2.secretAccessKey() + } } - } override def refresh(): Unit = ()