feat: Add ability to bundle all records from one micro-batch into PutRecords #86

Open · wants to merge 6 commits into master · Changes from 5 commits
README.md · 3 changes: 2 additions & 1 deletion
@@ -149,7 +149,8 @@ Referring $SPARK_HOME to the Spark installation directory.
| kinesis.executor.recordMaxBufferedTime | 1000 (millis) | Specify the maximum buffered time of a record |
| kinesis.executor.maxConnections | 1 | Specify the maximum connections to Kinesis |
| kinesis.executor.aggregationEnabled | true | Specify if records should be aggregated before sending them to Kinesis |
| kinesis.executor.flushwaittimemillis | 100 | Wait time while flushing records to Kinesis on Task End |
| kinesis.executor.sink.bundle.records | false | Bundle all records from one micro-batch into PutRecords |
**Contributor:** we have added "kinesis.executor.recordTtl" - can we add details about this config here?

**Contributor:** +1
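For reference, a minimal sketch of how these sink options (including the `kinesis.executor.recordTtl` discussed above) would be wired into a Structured Streaming query against this connector. The input DataFrame `df`, the stream name, the endpoint, and the checkpoint path are placeholders; the `partitionKey` and `data` columns match the projection used in KinesisWriteTask:

```scala
// Sketch: passing the executor options from the table above as sink options.
val query = df
  .selectExpr("CAST(id AS STRING) AS partitionKey", "CAST(value AS BINARY) AS data")
  .writeStream
  .format("kinesis")
  .option("streamName", "my-stream")                        // placeholder
  .option("endpointUrl", "https://kinesis.us-east-1.amazonaws.com")
  .option("kinesis.executor.recordTtl", "30000")            // default added by this PR
  .option("kinesis.executor.recordMaxBufferedTime", "1000")
  .option("kinesis.executor.aggregationEnabled", "true")
  .option("kinesis.executor.sink.bundle.records", "true")   // new option in this PR
  .option("checkpointLocation", "/tmp/kinesis-sink-checkpoint") // placeholder
  .start()
```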


## Roadmap
* We need to migrate to DataSource V2 APIs for MicroBatchExecution.
src/main/scala/org/apache/spark/sql/kinesis/CachedKinesisProducer.scala
@@ -22,16 +22,15 @@ import java.util.concurrent.{ExecutionException, TimeUnit}
import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicAWSCredentials}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging

import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.AmazonKinesis
import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration}
import com.google.common.cache._
import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}

private[kinesis] object CachedKinesisProducer extends Logging {

private type Producer = KinesisProducer
@@ -69,6 +68,11 @@ private[kinesis] object CachedKinesisProducer extends Logging {
.map { k => k.drop(8).toString -> producerConfiguration(k) }
.toMap

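// KPL record TTL: a record still buffered after this many milliseconds is
// dropped and reported through the failure callback (setRecordTtl below).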
val recordTtl = kinesisParams.getOrElse(
KinesisSourceProvider.SINK_RECORD_TTL,
KinesisSourceProvider.DEFAULT_SINK_RECORD_TTL)
.toLong

val recordMaxBufferedTime = kinesisParams.getOrElse(
KinesisSourceProvider.SINK_RECORD_MAX_BUFFERED_TIME,
KinesisSourceProvider.DEFAULT_SINK_RECORD_MAX_BUFFERED_TIME)
@@ -123,6 +127,7 @@
}

val kinesisProducer = new Producer(new KinesisProducerConfiguration()
.setRecordTtl(recordTtl)
.setRecordMaxBufferedTime(recordMaxBufferedTime)
.setMaxConnections(maxConnections)
.setAggregationEnabled(aggregation)
src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala
@@ -51,22 +51,22 @@ private[kinesis] class KinesisSourceProvider extends DataSourceRegister
*/

override def sourceSchema(
sqlContext: SQLContext,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = {
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
validateStreamOptions(caseInsensitiveParams)
require(schema.isEmpty, "Kinesis source has a fixed schema and cannot be set with a custom one")
(shortName(), KinesisReader.kinesisSchema)
}

override def createSource(
sqlContext: SQLContext,
metadataPath: String,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {

val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }

@@ -138,7 +138,7 @@ private[kinesis] class KinesisSourceProvider extends DataSourceRegister
"Sink endpoint url is a required field")
}
if (caseInsensitiveParams.contains(SINK_AGGREGATION_ENABLED) && (
caseInsensitiveParams(SINK_AGGREGATION_ENABLED).trim != "true" &&
caseInsensitiveParams(SINK_AGGREGATION_ENABLED).trim != "false"
)) {
throw new IllegalArgumentException(
@@ -235,14 +235,15 @@ private[kinesis] object KinesisSourceProvider extends Logging {
// Sink Options
private[kinesis] val SINK_STREAM_NAME_KEY = "streamname"
private[kinesis] val SINK_ENDPOINT_URL = "endpointurl"
private[kinesis] val SINK_RECORD_TTL = "kinesis.executor.recordTtl"
private[kinesis] val SINK_RECORD_MAX_BUFFERED_TIME = "kinesis.executor.recordmaxbufferedtime"
private[kinesis] val SINK_MAX_CONNECTIONS = "kinesis.executor.maxconnections"
private[kinesis] val SINK_AGGREGATION_ENABLED = "kinesis.executor.aggregationenabled"
private[kinesis] val SINK_FLUSH_WAIT_TIME_MILLIS = "kniesis.executor.flushwaittimemillis"
private[kinesis] val SINK_FLUSH_WAIT_TIME_MILLIS = "kinesis.executor.flushwaittimemillis"
private[kinesis] val SINK_BUNDLE_RECORDS = "kinesis.executor.sink.bundle.records"


private[kinesis] def getKinesisPosition(params: Map[String, String]): InitialKinesisPosition = {
val CURRENT_TIMESTAMP = System.currentTimeMillis
params.get(STARTING_POSITION_KEY).map(_.trim) match {
case Some(position) if position.toLowerCase(Locale.ROOT) == "latest" =>
@@ -262,14 +263,16 @@

private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1"

private[kinesis] val DEFAULT_SINK_RECORD_TTL: String = "30000"

private[kinesis] val DEFAULT_SINK_RECORD_MAX_BUFFERED_TIME: String = "1000"

private[kinesis] val DEFAULT_SINK_MAX_CONNECTIONS: String = "1"

private[kinesis] val DEFAULT_SINK_AGGREGATION: String = "true"

private[kinesis] val DEFAULT_FLUSH_WAIT_TIME_MILLIS: String = "100"

private[kinesis] val DEFAULT_SINK_BUNDLE_RECORDS: String = "false"

}
src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala · 76 changes: 70 additions & 6 deletions
@@ -43,11 +43,73 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, String],
s"${KinesisSourceProvider.SINK_FLUSH_WAIT_TIME_MILLIS} has to be a positive integer")
}

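// New in this PR: when true, execute() routes the micro-batch through
// bundleExecute() below instead of the per-record singleExecute() path.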
private val sinkBundleRecords = Try(producerConfiguration.getOrElse(
KinesisSourceProvider.SINK_BUNDLE_RECORDS,
KinesisSourceProvider.DEFAULT_SINK_BUNDLE_RECORDS).toBoolean).getOrElse {
throw new IllegalArgumentException(
s"${KinesisSourceProvider.SINK_BUNDLE_RECORDS} has to be a boolean value")
}

private var failedWrite: Throwable = _


def execute(iterator: Iterator[InternalRow]): Unit = {

if (sinkBundleRecords) {
bundleExecute(iterator)
} else {
singleExecute(iterator)
}

}

private def bundleExecute(iterator: Iterator[InternalRow]): Unit = {

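// Hard-coded bundle size; 490 presumably keeps each bundle under the
// Kinesis PutRecords limit of 500 records per request (see thread below).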
val groupedIterator: Iterator[Seq[InternalRow]] = iterator.grouped(490)
**Contributor:** what is 490 here? Should it be configurable?

**Contributor:** +1
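A sketch of how the bundle size could be made configurable, mirroring the way this PR parses `kinesis.executor.sink.bundle.records`. The option key `kinesis.executor.sink.bundle.size`, its default, and the helper object are hypothetical, not part of this PR:

```scala
import scala.util.Try

object BundleSizeSketch {
  // Hypothetical option key and default (490 matches the current literal).
  val SINK_BUNDLE_SIZE = "kinesis.executor.sink.bundle.size"
  val DEFAULT_SINK_BUNDLE_SIZE = "490"

  def parseBundleSize(producerConfiguration: Map[String, String]): Int =
    Try(producerConfiguration.getOrElse(SINK_BUNDLE_SIZE, DEFAULT_SINK_BUNDLE_SIZE).toInt)
      .filter(n => n > 0 && n <= 500) // a PutRecords request accepts at most 500 records
      .getOrElse {
        throw new IllegalArgumentException(
          s"$SINK_BUNDLE_SIZE has to be a positive integer no greater than 500")
      }
}

// bundleExecute() could then use it instead of the literal 490:
//   iterator.grouped(BundleSizeSketch.parseBundleSize(producerConfiguration))
```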


while (groupedIterator.hasNext) {
val rowList = groupedIterator.next()
sendBundledData(rowList)
}

}

private def sendBundledData(rowList: Seq[InternalRow]): Unit = {
producer = CachedKinesisProducer.getOrCreate(producerConfiguration)

val kinesisCallBack = new FutureCallback[UserRecordResult]() {

override def onFailure(t: Throwable): Unit = {
if (failedWrite == null && t != null) {
failedWrite = t
logError(s"Writing to $streamName failed due to ${t.getCause}")
}
}

override def onSuccess(result: UserRecordResult): Unit = {
logDebug(s"Successfully put records: \n " +
s"sequenceNumber=${result.getSequenceNumber}, \n" +
s"shardId=${result.getShardId}, \n" +
s"attempts=${result.getAttempts.size}")
}
}

for (r <- rowList) {

val projectedRow = projection(r)
val partitionKey = projectedRow.getString(0)
val data = projectedRow.getBinary(1)

val future = producer.addUserRecord(streamName, partitionKey, ByteBuffer.wrap(data))

Futures.addCallback(future, kinesisCallBack)

}
}

private def singleExecute(iterator: Iterator[InternalRow]): Unit = {
producer = CachedKinesisProducer.getOrCreate(producerConfiguration)

while (iterator.hasNext && failedWrite == null) {
val currentRow = iterator.next()
val projectedRow = projection(currentRow)
@@ -56,11 +118,10 @@

sendData(partitionKey, data)
}
}


private def sendData(partitionKey: String, data: Array[Byte]): Unit = {
val future = producer.addUserRecord(streamName, partitionKey, ByteBuffer.wrap(data))

val kinesisCallBack = new FutureCallback[UserRecordResult]() {
Expand All @@ -73,14 +134,17 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
}

override def onSuccess(result: UserRecordResult): Unit = {
logDebug(s"Successfully put records: \n " +
s"sequenceNumber=${result.getSequenceNumber}, \n" +
s"shardId=${result.getShardId}, \n" +
s"attempts=${result.getAttempts.size}")
}

}

Futures.addCallback(future, kinesisCallBack)

producer.flushSync()
**Contributor:** @leslieyanyan @itsvikramagr The slowness is on account of this function call producer.flushSync(). Please refer to my comment here: #81 (review)

The new code in this PR is showing improved performance because method sendBundledData() doesn't have this function call producer.flushSync()

We'll need to separately evaluate how much performance impact we're getting by using GroupedIterator instead of normal iterator.
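If per-record flushSync() is the bottleneck, a middle ground would be one synchronous flush per bundle. A sketch only, not part of this PR, reusing the PR's existing producer, failedWrite, and sendBundledData members:

```scala
private def bundleExecute(iterator: Iterator[InternalRow]): Unit = {
  val groupedIterator = iterator.grouped(490)
  while (groupedIterator.hasNext && failedWrite == null) {
    sendBundledData(groupedIterator.next())
    // One flush per bundle of at most 490 records bounds buffered data
    // without paying the per-record flushSync() cost discussed above.
    producer.flushSync()
  }
}
```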

}

private def flushRecordsIfNecessary(): Unit = {