Skip to content

Commit

Permalink
Explicitly shut down thread pools for pubsub (close #42)
Browse files Browse the repository at this point in the history
  • Loading branch information
istreeter committed Dec 4, 2023
1 parent ec72d56 commit a1706f7
Showing 1 changed file with 61 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/
package com.snowplowanalytics.snowplow.sources.pubsub

import cats.effect.{Async, Sync, Resource}
import cats.effect.{Async, Resource, Sync}
import cats.effect.implicits._
import cats.implicits._
import fs2.{Chunk, Stream}
Expand All @@ -20,9 +20,9 @@ import java.time.Instant
// pubsub
import com.google.api.core.ApiService
import com.google.api.gax.batching.FlowControlSettings
import com.google.api.gax.core.ExecutorProvider
import com.google.api.gax.core.FixedExecutorProvider
import com.google.cloud.pubsub.v1.{AckReplyConsumer, MessageReceiver, Subscriber}
import com.google.common.util.concurrent.{ForwardingListeningExecutorService, MoreExecutors}
import com.google.common.util.concurrent.{ForwardingExecutorService, ListeningExecutorService, MoreExecutors}
import com.google.pubsub.v1.{ProjectSubscriptionName, PubsubMessage}
import org.threeten.bp.{Duration => ThreetenDuration}

Expand All @@ -33,7 +33,7 @@ import com.snowplowanalytics.snowplow.sources.internal.{Checkpointer, LowLevelEv

import scala.concurrent.duration.FiniteDuration

import java.util.concurrent.{Callable, Phaser, ScheduledExecutorService, ScheduledFuture, ScheduledThreadPoolExecutor, Semaphore, TimeUnit}
import java.util.concurrent.{Callable, ExecutorService, Executors, Phaser, ScheduledExecutorService, ScheduledFuture, Semaphore, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

object PubsubSource {
Expand Down Expand Up @@ -176,7 +176,8 @@ object PubsubSource {

private def runSubscriber[F[_]: Async](config: PubsubSourceConfig, control: Control): Resource[F, Unit] =
for {
executor <- Resource.make(Sync[F].delay(scheduledExecutorService))(s => Sync[F].delay(s.shutdown()))
executor <- executorResource(Sync[F].delay(Executors.newScheduledThreadPool(2 * config.parallelPullCount)))
direct <- executorResource(Sync[F].delay(MoreExecutors.newDirectExecutorService()))
receiver = messageReceiver(config, control)
name = ProjectSubscriptionName.of(config.subscription.projectId, config.subscription.subscriptionId)
subscriber <- Resource.eval(Sync[F].delay {
Expand All @@ -186,12 +187,8 @@ object PubsubSource {
.setMaxDurationPerAckExtension(convertDuration(config.maxDurationPerAckExtension))
.setMinDurationPerAckExtension(convertDuration(config.minDurationPerAckExtension))
.setParallelPullCount(config.parallelPullCount)
.setExecutorProvider {
new ExecutorProvider {
def shouldAutoClose: Boolean = true
def getExecutor: ScheduledExecutorService = executor
}
}
.setExecutorProvider(FixedExecutorProvider.create(executorForEventCallbacks(direct, executor)))
.setSystemExecutorProvider(FixedExecutorProvider.create(executor))
.setFlowControlSettings {
// Switch off any flow control, because we handle it ourselves with the semaphore
FlowControlSettings.getDefaultInstance
Expand Down Expand Up @@ -247,40 +244,60 @@ object PubsubSource {
}
}

private def scheduledExecutorService: ScheduledExecutorService = new ForwardingListeningExecutorService with ScheduledExecutorService {
val delegate = MoreExecutors.newDirectExecutorService
lazy val scheduler = new ScheduledThreadPoolExecutor(1) // I think this scheduler is never used, but I implement it here for safety
override def schedule[V](
callable: Callable[V],
delay: Long,
unit: TimeUnit
): ScheduledFuture[V] =
scheduler.schedule(callable, delay, unit)
override def schedule(
runnable: Runnable,
delay: Long,
unit: TimeUnit
): ScheduledFuture[_] =
scheduler.schedule(runnable, delay, unit)
override def scheduleAtFixedRate(
runnable: Runnable,
initialDelay: Long,
period: Long,
unit: TimeUnit
): ScheduledFuture[_] =
scheduler.scheduleAtFixedRate(runnable, initialDelay, period, unit)
override def scheduleWithFixedDelay(
runnable: Runnable,
initialDelay: Long,
delay: Long,
unit: TimeUnit
): ScheduledFuture[_] =
scheduler.scheduleWithFixedDelay(runnable, initialDelay, delay, unit)
override def shutdown(): Unit = {
delegate.shutdown()
scheduler.shutdown()
/**
 * Wraps an effectfully-created executor service in a `Resource`.
 *
 * Acquisition runs `make`. Release calls `shutdown()` on the acquired executor, wrapped in
 * `Sync[F].blocking`.
 *
 * @param make
 *   effect that creates the executor service
 * @return
 *   a `Resource` yielding the executor created by `make`
 */
private def executorResource[F[_]: Sync, E <: ExecutorService](make: F[E]): Resource[F, E] = {
  val release: E => F[Unit] = pool => Sync[F].blocking(pool.shutdown())
  Resource.make(make)(release)
}

/**
 * Builds the `ScheduledExecutorService` on which event callbacks are processed.
 *
 * Non-scheduled submissions are forwarded to `directExecutor`; every `schedule*` call is handed to
 * `systemExecutor`.
 *
 * Because the callback is executed on a direct executor, the underlying Subscriber runs it
 * directly on its system executor. When the queue is full we therefore deliberately block the
 * system executor. This trick is needed because FlowControl is disabled: it is our own version of
 * flow control.
 *
 * @param directExecutor
 *   executor that receives all non-scheduled tasks
 * @param systemExecutor
 *   executor that receives all scheduled tasks
 * @return
 *   an executor combining the two
 */
private def executorForEventCallbacks(
  directExecutor: ListeningExecutorService,
  systemExecutor: ScheduledExecutorService
): ScheduledExecutorService =
  new ForwardingExecutorService with ScheduledExecutorService {

    // Non-scheduled tasks (e.g. when a message is received) run directly, without jumping to
    // another thread pool.
    override val delegate: ListeningExecutorService = directExecutor

    // Scheduled tasks (if any exist) go to the same thread pool shared by the system executor.
    // As far as the original author knows, these schedule methods never get called.
    override def schedule[V](task: Callable[V], delay: Long, unit: TimeUnit): ScheduledFuture[V] =
      systemExecutor.schedule(task, delay, unit)

    override def schedule(command: Runnable, delay: Long, unit: TimeUnit): ScheduledFuture[_] =
      systemExecutor.schedule(command, delay, unit)

    override def scheduleAtFixedRate(command: Runnable, initialDelay: Long, period: Long, unit: TimeUnit): ScheduledFuture[_] =
      systemExecutor.scheduleAtFixedRate(command, initialDelay, period, unit)

    override def scheduleWithFixedDelay(command: Runnable, initialDelay: Long, delay: Long, unit: TimeUnit): ScheduledFuture[_] =
      systemExecutor.scheduleWithFixedDelay(command, initialDelay, delay, unit)
  }
}

private def convertDuration(d: FiniteDuration): ThreetenDuration =
ThreetenDuration.ofMillis(d.toMillis)
Expand Down

0 comments on commit a1706f7

Please sign in to comment.