diff --git a/README.md b/README.md
index fd73728..ba27808 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,23 @@
 shouldSortFiles|true|whether to sort files based on timestamp while listing them
 useInstanceProfileCredentials|false|Whether to use EC2 instance profile credentials for connecting to Amazon SQS
 maxFilesPerTrigger|no default value|maximum number of files to process in a microbatch
 maxFileAge|7d|Maximum age of a file that can be found in this directory
+basePath|no default value|Base path in case of partitioned S3 data, e.g. for `s3://bucket/baseDir/part1=10/part2=20/file.json` the basePath is `s3://bucket/baseDir/`
+
+## Using a Partitioned S3 Bucket
+
+If your S3 bucket is partitioned, your schema must contain both the data columns and the partition
+columns. In addition, each partition column needs `isPartitioned` set to `true` in its metadata.
+
+Example:
+```
+val metaData = (new MetadataBuilder).putBoolean("isPartitioned", true).build()
+
+val partitionedSchema = new StructType().add(StructField(
+  "col1", IntegerType, true, metaData))
+```
+
+Also, `basePath` must be specified in the options when the S3 bucket is partitioned.
+Specifying partition columns without specifying `basePath` will throw an error.
 
 ## Example
diff --git a/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSource.scala b/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSource.scala
index 4f301cc..140111d 100644
--- a/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSource.scala
+++ b/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSource.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.streaming.sqs
 
 import java.net.URI
 
+import scala.util.Try
+
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
@@ -34,6 +36,8 @@ class SqsSource(sparkSession: SparkSession,
     options: Map[String, String],
     override val schema: StructType) extends Source with Logging {
 
+  import SqsSource._
+
   private val sourceOptions = new SqsSourceOptions(options)
 
   private val hadoopConf = sparkSession.sessionState.newHadoopConf()
@@ -50,6 +54,22 @@ class SqsSource(sparkSession: SparkSession,
 
   private val shouldSortFiles = sourceOptions.shouldSortFiles
 
+  private val partitionColumnNames = {
+    schema.fields.filter(field => field.metadata.contains(IS_PARTITIONED) &&
+      Try(field.metadata.getBoolean(IS_PARTITIONED)).toOption.getOrElse(throw new
+        IllegalArgumentException(s"$IS_PARTITIONED for column ${field.name} must be true or " +
+        s"false"))
+    ).map(_.name)
+  }
+
+  private val optionsWithBasePath = if (partitionColumnNames.nonEmpty) {
+    val basePartitionsPath = sourceOptions.basePath.getOrElse(throw new IllegalArgumentException(
+      s"$BASE_PATH is mandatory if schema contains partitionColumns"))
+    options + (BASE_PATH -> basePartitionsPath)
+  } else {
+    options
+  }
+
   private val sqsClient = new SqsClient(sourceOptions, hadoopConf)
 
   metadataLog.allFiles().foreach { entry =>
@@ -75,8 +95,9 @@ class SqsSource(sparkSession: SparkSession,
       sparkSession,
       paths = files.map(f => new Path(new URI(f.path)).toString),
       userSpecifiedSchema = Some(schema),
+      partitionColumns = partitionColumnNames,
       className = fileFormatClassName,
-      options = options)
+      options = optionsWithBasePath)
     Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation(
       checkFilesExist = false), isStreaming = true))
   }
@@ -138,3 +159,9 @@ class SqsSource(sparkSession: SparkSession,
 
 }
 
+object SqsSource {
+
+  val IS_PARTITIONED = "isPartitioned"
+  val BASE_PATH = "basePath"
+}
+
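As a reviewer's aid, here is a self-contained sketch of the rule the `partitionColumnNames` hunk above implements. The `PartitionColumnsSketch` wrapper and the sample schema are illustrative only, not part of the patch:

```scala
import scala.util.Try

import org.apache.spark.sql.types._

object PartitionColumnsSketch {

  val IS_PARTITIONED = "isPartitioned"

  // Keep a column when its metadata carries isPartitioned = true; a value
  // that is present but not a boolean fails fast, mirroring the patch.
  def partitionColumnNames(schema: StructType): Seq[String] =
    schema.fields.filter { field =>
      field.metadata.contains(IS_PARTITIONED) &&
        Try(field.metadata.getBoolean(IS_PARTITIONED)).toOption.getOrElse(
          throw new IllegalArgumentException(
            s"$IS_PARTITIONED for column ${field.name} must be true or false"))
    }.map(_.name)

  def main(args: Array[String]): Unit = {
    val partMeta = (new MetadataBuilder).putBoolean(IS_PARTITIONED, true).build()
    val schema = new StructType()
      .add(StructField("value", StringType, nullable = true))
      .add(StructField("part1", IntegerType, nullable = true, partMeta))
    // Prints only the partition column: part1
    println(partitionColumnNames(schema).mkString(", "))
  }
}
```

The `contains` guard simply skips columns without the flag, while the `Try` turns a non-boolean value into an immediate `IllegalArgumentException` rather than silently dropping the column.
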
diff --git a/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSourceOptions.scala b/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSourceOptions.scala
index a4c0cc1..4704819 100644
--- a/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSourceOptions.scala
+++ b/src/main/scala/org/apache/spark/sql/streaming/sqs/SqsSourceOptions.scala
@@ -37,6 +37,8 @@ class SqsSourceOptions(parameters: CaseInsensitiveMap[String]) extends Logging {
     }
   }
 
+  val basePath: Option[String] = parameters.get("basePath")
+
   /**
    * Maximum age of a file that can be found in this directory, before it is ignored. For the
    * first batch all files will be considered valid.
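A quick note on the lookup above: `parameters` is Spark's internal `CaseInsensitiveMap`, so the casing of the option key does not matter. A minimal sketch, assuming the catalyst utility class is on the classpath:

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// Keys resolve case-insensitively, so "BASEPATH" and "basePath" are the same option.
val params = CaseInsensitiveMap(Map("BASEPATH" -> "s3://bucket/baseDir/"))
val basePath: Option[String] = params.get("basePath")
assert(basePath.contains("s3://bucket/baseDir/"))
```
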
+ .format("s3-sqs") + .option("sqsUrl", "https://DUMMY_URL") + .option("fileFormat", "json") + .option("region", "us-east-1") + .schema(partitionedSchema) + .load() + + query = reader.writeStream + .format("memory") + .queryName("testQuery") + .start() + + query.processAllAvailable() + }.getMessage + assert(errorMessage.toLowerCase(Locale.ROOT).contains(expectedMsg.toLowerCase(Locale.ROOT))) + } finally { + if (query != null) { + // terminating streaming query if necessary + query.stop() + } + } + + } + +} +