diff --git a/kafka/src/main/scala/ox/kafka/KafkaDrain.scala b/kafka/src/main/scala/ox/kafka/KafkaDrain.scala
new file mode 100644
index 00000000..f9aee0f0
--- /dev/null
+++ b/kafka/src/main/scala/ox/kafka/KafkaDrain.scala
@@ -0,0 +1,41 @@
+package ox.kafka
+
+import org.apache.kafka.clients.consumer.{ConsumerRecord, KafkaConsumer, OffsetAndMetadata}
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
+import org.apache.kafka.common.TopicPartition
+import org.slf4j.LoggerFactory
+import ox.*
+import ox.channels.*
+
+import java.util.concurrent.atomic.AtomicInteger
+import scala.annotation.tailrec
+import scala.collection.mutable
+import scala.concurrent.duration.*
+import scala.jdk.CollectionConverters.*
+
+object KafkaDrain:
+  def publish[K, V](settings: ProducerSettings[K, V]): Source[ProducerRecord[K, V]] => Unit = source =>
+    publish(settings.toProducer, closeWhenComplete = true)(source)
+
+  def publish[K, V](producer: KafkaProducer[K, V], closeWhenComplete: Boolean): Source[ProducerRecord[K, V]] => Unit = source =>
+    // if sending multiple records ends in an exception, we'll receive at most one anyway; we don't want to block the
+    // producers, hence creating an unbounded channel
+    val producerExceptions = Channel[Exception](Int.MaxValue)
+
+    try
+      repeatWhile {
+        select(producerExceptions.receiveClause, source.receiveOrDoneClause) match // bias on exceptions
+          case e: ChannelClosed.Error         => throw e.toThrowable
+          case ChannelClosed.Done             => false // source must be done, as producerExceptions is never done
+          case producerExceptions.Received(e) => throw e
+          case source.Received(record) =>
+            producer.send(
+              record,
+              (_: RecordMetadata, exception: Exception) => {
+                if exception != null then producerExceptions.send(exception)
+              }
+            )
+            true
+      }
+    finally
+      if closeWhenComplete then uninterruptible(producer.close())
diff --git a/kafka/src/test/scala/ox/kafka/KafkaTest.scala b/kafka/src/test/scala/ox/kafka/KafkaTest.scala
index 7cf54202..82e67be3 100644
--- a/kafka/src/test/scala/ox/kafka/KafkaTest.scala
+++ b/kafka/src/test/scala/ox/kafka/KafkaTest.scala
@@ -21,7 +21,7 @@ class KafkaTest extends AnyFlatSpec with Matchers with EmbeddedKafka with Before
 
   override def afterAll(): Unit = EmbeddedKafka.stop()
 
-  it should "receive messages from a topic" in {
+  "source" should "receive messages from a topic" in {
     // given
     val topic = "t1"
     val group = "g1"
@@ -49,7 +49,7 @@ class KafkaTest extends AnyFlatSpec with Matchers with EmbeddedKafka with Before
     }
   }
 
-  it should "send messages to topics" in {
+  "sink" should "send messages to topics" in {
     // given
     val topic = "t2"
 
@@ -116,4 +116,22 @@ class KafkaTest extends AnyFlatSpec with Matchers with EmbeddedKafka with Before
       inSource2.receive().orThrow.value shouldBe "10"
     }
   }
+
+  "drain" should "send messages to topics" in {
+    // given
+    val topic = "t4"
+
+    // when
+    scoped {
+      val settings = ProducerSettings.default.bootstrapServers(bootstrapServer)
+      Source
+        .fromIterable(List("a", "b", "c"))
+        .mapAsView(msg => ProducerRecord[String, String](topic, msg))
+        .applied(KafkaDrain.publish(settings))
+    }
+
+    // then
+    given Deserializer[String] = new StringDeserializer()
+    consumeNumberMessagesFrom[String](topic, 3, timeout = 30.seconds) shouldBe List("a", "b", "c")
+  }
 }
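
For reference, a minimal usage sketch of the new KafkaDrain.publish, mirroring the "drain" test above. The "localhost:9092" address and "my-topic" name are placeholders; the scoped/Source/applied API is assumed to work as used in the diff.

  import org.apache.kafka.clients.producer.ProducerRecord
  import ox.*
  import ox.channels.*
  import ox.kafka.*

  scoped {
    val settings = ProducerSettings.default.bootstrapServers("localhost:9092")
    Source
      .fromIterable(List("a", "b", "c"))
      .mapAsView(msg => ProducerRecord[String, String]("my-topic", msg))
      // blocks until the source is done, or rethrows the first failed send;
      // the producer created from `settings` is closed once draining completes
      .applied(KafkaDrain.publish(settings))
  }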