Kafka sink to open fewer threads #431

Merged: 1 commit, merged on Nov 19, 2024
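For context: the diff below replaces the fs2-kafka producer with a plain kafka-clients KafkaProducer and bridges its callback-based send into cats-effect, in line with the PR's goal of opening fewer threads. The sketch that follows shows that bridging pattern in isolation; it is a simplified illustration rather than the PR code itself, and the helper name sendF is hypothetical.

// Minimal sketch (illustrative, not the PR code): suspend KafkaProducer.send
// as an async effect. The producer's own I/O thread invokes the callback,
// which completes the effect, so no extra thread is spawned per record.
import cats.effect.Async
import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

def sendF[F[_]: Async](
  producer: KafkaProducer[String, Array[Byte]],
  topic: String,
  key: String,
  value: Array[Byte]
): F[Unit] =
  Async[F].async_[Unit] { cb =>
    val record = new ProducerRecord(topic, key, value)
    producer.send(
      record,
      new Callback {
        def onCompletion(metadata: RecordMetadata, exception: Exception): Unit =
          if (exception == null) cb(Right(())) else cb(Left(exception))
      }
    )
    () // discard the returned java.util.concurrent.Future; completion arrives via the callback
  }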
@@ -13,25 +13,24 @@ package sinks
import cats.implicits._
import cats.effect._

-import org.slf4j.LoggerFactory
-
-import fs2.kafka._
+import org.typelevel.log4cats.Logger
+import org.typelevel.log4cats.slf4j.Slf4jLogger
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

import com.snowplowanalytics.snowplow.collector.core.{Config, Sink}

+import scala.jdk.CollectionConverters._
+
/**
 * Kafka Sink for the Scala Stream Collector
 */
-class KafkaSink[F[_]: Async](
+class KafkaSink[F[_]: Async: Logger](
  val maxBytes: Int,
  isHealthyState: Ref[F, Boolean],
-  kafkaProducer: KafkaProducer[F, String, Array[Byte]],
+  kafkaProducer: KafkaProducer[String, Array[Byte]],
  topicName: String
) extends Sink[F] {

-  private lazy val log = LoggerFactory.getLogger(getClass())
-
  override def isHealthy: F[Boolean] = isHealthyState.get

  /**
@@ -40,26 +39,53 @@ class KafkaSink[F[_]: Async](
   * @param events The list of events to send
   * @param key The partition key to use
   */
-  override def storeRawEvents(events: List[Array[Byte]], key: String): F[Unit] = {
-    log.debug(s"Writing ${events.size} Thrift records to Kafka topic $topicName at key $key")
-    val records = ProducerRecords(events.map(e => (ProducerRecord(topicName, key, e))))
-    kafkaProducer.produce(records).onError { case _: Throwable => isHealthyState.set(false) } *> isHealthyState.set(
-      true
-    )
-  }
+  override def storeRawEvents(events: List[Array[Byte]], key: String): F[Unit] =
+    Logger[F].debug(s"Writing ${events.size} Thrift records to Kafka topic $topicName at key $key") *>
+      events.traverse_ { e =>
+        def go: F[Unit] =
+          Async[F]
+            .async_[Unit] { cb =>
+              val record = new ProducerRecord(topicName, key, e)
+              kafkaProducer.send(record, callback(cb))
+              ()
+            }
+            .handleErrorWith { e =>
+              handlePublishError(e) >> go
+            }
+        go
+      } *> isHealthyState.set(true)
+
+  private def callback(asyncCallback: Either[Throwable, Unit] => Unit): Callback =
+    new Callback {
+      def onCompletion(metadata: RecordMetadata, exception: Exception): Unit =
+        Option(exception) match {
+          case Some(e) => asyncCallback(Left(e))
+          case None    => asyncCallback(Right(()))
+        }
+    }
+
+  private def handlePublishError(error: Throwable): F[Unit] =
+    isHealthyState.set(false) *> Logger[F].error(s"Publishing to Kafka failed with message ${error.getMessage}")
}

object KafkaSink {

+  implicit private def unsafeLogger[F[_]: Sync]: Logger[F] =
+    Slf4jLogger.getLogger[F]
+
  def create[F[_]: Async](
    sinkConfig: Config.Sink[KafkaSinkConfig],
    authCallbackClass: String
  ): Resource[F, KafkaSink[F]] =
    for {
      isHealthyState <- Resource.eval(Ref.of[F, Boolean](false))
      kafkaProducer <- createProducer(sinkConfig.config, sinkConfig.buffer, authCallbackClass)
-      kafkaSink = new KafkaSink(sinkConfig.config.maxBytes, isHealthyState, kafkaProducer, sinkConfig.name)
-    } yield kafkaSink
+    } yield new KafkaSink(
+      sinkConfig.config.maxBytes,
+      isHealthyState,
+      kafkaProducer,
+      sinkConfig.name
+    )

  /**
   * Creates a new Kafka Producer with the given
@@ -71,20 +97,20 @@
    kafkaConfig: KafkaSinkConfig,
    bufferConfig: Config.Buffer,
    authCallbackClass: String
-  ): Resource[F, KafkaProducer[F, String, Array[Byte]]] = {
+  ): Resource[F, KafkaProducer[String, Array[Byte]]] = {
    val props = Map(
      "bootstrap.servers" -> kafkaConfig.brokers,
      "acks" -> "all",
      "retries" -> kafkaConfig.retries.toString,
      "buffer.memory" -> bufferConfig.byteLimit.toString,
      "linger.ms" -> bufferConfig.timeLimit.toString,
      "key.serializer" -> "org.apache.kafka.common.serialization.StringSerializer",
      "value.serializer" -> "org.apache.kafka.common.serialization.ByteArraySerializer",
      "sasl.login.callback.handler.class" -> authCallbackClass
-    ) ++ kafkaConfig.producerConf.getOrElse(Map.empty)
+    ) ++ kafkaConfig.producerConf.getOrElse(Map.empty) + ("buffer.memory" -> Long.MaxValue.toString)

-    val producerSettings =
-      ProducerSettings[F, String, Array[Byte]].withBootstrapServers(kafkaConfig.brokers).withProperties(props)
-
-    KafkaProducer.resource(producerSettings)
+    val make = Sync[F].delay {
+      new KafkaProducer[String, Array[Byte]]((props: Map[String, AnyRef]).asJava)
+    }
+    Resource.make(make)(p => Sync[F].blocking(p.close))
  }
}
project/BuildSettings.scala (1 addition, 1 deletion)
@@ -87,7 +87,7 @@ object BuildSettings {
    moduleName := "snowplow-stream-collector-kafka",
    Docker / packageName := "scala-stream-collector-kafka",
    libraryDependencies ++= Seq(
-      Dependencies.Libraries.fs2Kafka,
+      Dependencies.Libraries.kafka,
      Dependencies.Libraries.mskAuth,
      Dependencies.Libraries.azureIdentity,

project/Dependencies.scala (2 additions, 2 deletions)
@@ -25,7 +25,7 @@ object Dependencies {
    val fs2PubSub = "0.22.0"
    val http4s = "0.23.23"
    val jackson = "2.12.7" // force this version to mitigate security vulnerabilities
-    val fs2Kafka = "2.6.1"
+    val kafka = "3.8.1"
    val log4cats = "2.6.0"
    val log4j = "2.17.2" // CVE-2021-44228
    val mskAuth = "1.1.1"
@@ -68,7 +68,7 @@
    //sinks
    val fs2PubSub = "com.permutive" %% "fs2-google-pubsub-grpc" % V.fs2PubSub
    val jackson = "com.fasterxml.jackson.core" % "jackson-databind" % V.jackson
-    val fs2Kafka = "com.github.fd4s" %% "fs2-kafka" % V.fs2Kafka
+    val kafka = "org.apache.kafka" % "kafka-clients" % V.kafka
    val kinesis = "com.amazonaws" % "aws-java-sdk-kinesis" % V.awsSdk
    val log4j = "org.apache.logging.log4j" % "log4j-core" % V.log4j
    val mskAuth = "software.amazon.msk" % "aws-msk-iam-auth" % V.mskAuth % Runtime // Enables AWS MSK IAM authentication https://github.com/snowplow/stream-collector/pull/214