diff --git a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java
index 099daf754f..38a4bf125d 100644
--- a/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java
+++ b/data-prepper-plugins/kafka-plugins/src/main/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumer.java
@@ -67,6 +67,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
     private static final Logger LOG = LoggerFactory.getLogger(KafkaCustomConsumer.class);
     private static final Long COMMIT_OFFSET_INTERVAL_MS = 300000L;
     private static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 1;
+    private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000;
     static final String DEFAULT_KEY = "message";
 
     private volatile long lastCommitTime;
@@ -75,6 +76,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
     private final String topicName;
     private final TopicConsumerConfig topicConfig;
     private MessageFormat schema;
+    private boolean paused;
     private final BufferAccumulator<Record<Event>> bufferAccumulator;
     private final Buffer<Record<Event>> buffer;
     private static final ObjectMapper objectMapper = new ObjectMapper();
@@ -94,6 +96,7 @@ public class KafkaCustomConsumer implements Runnable, ConsumerRebalanceListener
     private long numRecordsCommitted = 0;
     private final LogRateLimiter errLogRateLimiter;
     private final ByteDecoder byteDecoder;
+    private final long maxRetriesOnException;
 
     public KafkaCustomConsumer(final KafkaConsumer consumer,
                                final AtomicBoolean shutdownInProgress,
@@ -110,8 +113,10 @@ public KafkaCustomConsumer(final KafkaConsumer consumer,
         this.shutdownInProgress = shutdownInProgress;
         this.consumer = consumer;
         this.buffer = buffer;
+        this.paused = false;
         this.byteDecoder = byteDecoder;
         this.topicMetrics = topicMetrics;
+        this.maxRetriesOnException = topicConfig.getMaxPollInterval().toMillis() / (2 * RETRY_ON_EXCEPTION_SLEEP_MS);
         this.pauseConsumePredicate = pauseConsumePredicate;
         this.topicMetrics.register(consumer);
         this.offsetsToCommit = new HashMap<>();
@@ -170,10 +175,15 @@ private AcknowledgementSet createAcknowledgementSet(Map<TopicPartition, CommitOffsetRange> offsets) {
 
-    <T> void consumeRecords() throws Exception {
-        try {
+    <T> ConsumerRecords<String, T> doPoll() throws Exception {
         ConsumerRecords<String, T> records = consumer.poll(Duration.ofMillis(topicConfig.getThreadWaitingTime().toMillis()/2));
+        return records;
+    }
+
+    <T> void consumeRecords() throws Exception {
+        try {
+            ConsumerRecords<String, T> records = doPoll();
             if (Objects.nonNull(records) && !records.isEmpty() && records.count() > 0) {
                 Map<TopicPartition, CommitOffsetRange> offsets = new HashMap<>();
                 AcknowledgementSet acknowledgementSet = null;
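Note on the retry budget above: each failed buffer write sleeps for RETRY_ON_EXCEPTION_SLEEP_MS (one second), so maxRetriesOnException caps the worst-case time spent retrying at roughly half of max.poll.interval.ms before the consumer is paused; pausing keeps the poll loop (and its heartbeats) alive instead of letting the broker evict the consumer from the group. A minimal, self-contained sketch of that arithmetic, with an illustrative class name and Kafka's default max.poll.interval.ms of 300 seconds plugged in:

    // Hypothetical illustration; mirrors the constructor math in the hunk above.
    public class RetryBudgetSketch {
        private static final int RETRY_ON_EXCEPTION_SLEEP_MS = 1000;

        public static void main(String[] args) {
            final long maxPollIntervalMs = 300_000L; // Kafka's default max.poll.interval.ms
            final long maxRetriesOnException = maxPollIntervalMs / (2 * RETRY_ON_EXCEPTION_SLEEP_MS);
            // 300000 / (2 * 1000) = 150 retries, i.e. at most ~150 seconds spent sleeping
            System.out.println("retries before pausing: " + maxRetriesOnException);
        }
    }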
@@ -419,21 +429,45 @@ private <T> void processRecord(final AcknowledgementSet acknowledgementSet, final Record<Event> record) {
         if (acknowledgementSet != null) {
             acknowledgementSet.add(record.getData());
         }
+        long numRetries = 0;
         while (true) {
             try {
-                bufferAccumulator.add(record);
+                if (numRetries == 0) {
+                    bufferAccumulator.add(record);
+                } else {
+                    bufferAccumulator.flush();
+                }
                 break;
             } catch (Exception e) {
+                if (!paused && numRetries++ > maxRetriesOnException) {
+                    paused = true;
+                    consumer.pause(consumer.assignment());
+                }
                 if (e instanceof SizeOverflowException) {
                     topicMetrics.getNumberOfBufferSizeOverflows().increment();
                 } else {
                     LOG.debug("Error while adding record to buffer, retrying ", e);
                 }
                 try {
-                    Thread.sleep(100);
+                    Thread.sleep(RETRY_ON_EXCEPTION_SLEEP_MS);
+                    if (paused) {
+                        ConsumerRecords<String, T> records = doPoll();
+                        if (records.count() > 0) {
+                            LOG.warn("Unexpected records received while the consumer is paused. Resetting the partitions to retry from the last read pointer");
+                            synchronized (this) {
+                                partitionsToReset.addAll(consumer.assignment());
+                            }
+                            break;
+                        }
+                    }
                 } catch (Exception ex) {} // ignore the exception; it only means the thread slept for a shorter time
             }
         }
+
+        if (paused) {
+            consumer.resume(consumer.assignment());
+            paused = false;
+        }
     }
 
     private <T> void iterateRecordPartitions(ConsumerRecords<String, T> records, final AcknowledgementSet acknowledgementSet,
@@ -503,6 +537,9 @@ public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
                 LOG.info("Assigned partition {}", topicPartition);
                 ownedPartitionsEpoch.put(topicPartition, epoch);
             }
+            if (paused) {
+                consumer.pause(consumer.assignment());
+            }
         }
         dumpTopicPartitionOffsets(partitions);
     }
@@ -520,6 +557,9 @@ public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
             ownedPartitionsEpoch.remove(topicPartition);
             partitionCommitTrackerMap.remove(topicPartition.partition());
         }
+        if (paused) {
+            consumer.pause(consumer.assignment());
+        }
     }
 }
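The rebalance hooks above are needed because KafkaConsumer.pause() applies only to the partitions currently assigned, and partitions (re)gained through a rebalance come back unpaused; re-applying pause() in the callbacks keeps the broker-side fetch state consistent with the paused flag until processRecord() can resume. A minimal sketch of the underlying consumer API, assuming a local broker and placeholder group/topic names:

    import java.time.Duration;
    import java.util.List;
    import java.util.Properties;
    import org.apache.kafka.clients.consumer.ConsumerRecords;
    import org.apache.kafka.clients.consumer.KafkaConsumer;

    public class PauseResumeSketch {
        public static void main(String[] args) {
            final Properties props = new Properties();
            props.put("bootstrap.servers", "localhost:9092"); // placeholder broker
            props.put("group.id", "pause-resume-sketch");     // placeholder group id
            props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
            props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
            try (KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props)) {
                consumer.subscribe(List.of("demo-topic"));  // placeholder topic
                consumer.poll(Duration.ofMillis(500));      // first poll joins the group and gets an assignment
                // A paused consumer keeps polling (and heartbeating) but fetches
                // no records from the paused partitions.
                consumer.pause(consumer.assignment());
                final ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(500));
                System.out.println("records while paused: " + records.count()); // expected: 0
                consumer.resume(consumer.assignment());
            }
        }
    }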
diff --git a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java
index 7d3a0f3fb9..968639f674 100644
--- a/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java
+++ b/data-prepper-plugins/kafka-plugins/src/test/java/org/opensearch/dataprepper/plugins/kafka/consumer/KafkaCustomConsumerTest.java
@@ -28,6 +28,7 @@
 import org.opensearch.dataprepper.model.CheckpointState;
 import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
 import org.opensearch.dataprepper.model.buffer.Buffer;
+import org.opensearch.dataprepper.model.buffer.SizeOverflowException;
 import org.opensearch.dataprepper.model.configuration.PluginSetting;
 import org.opensearch.dataprepper.model.event.Event;
 import org.opensearch.dataprepper.model.record.Record;
@@ -54,6 +55,7 @@
 import static org.hamcrest.Matchers.hasEntry;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -70,6 +72,9 @@ public class KafkaCustomConsumerTest {
     private Buffer<Record<Event>> buffer;
 
+    @Mock
+    private Buffer<Record<Event>> mockBuffer;
+
     @Mock
     private KafkaConsumerConfig sourceConfig;
 
@@ -106,21 +111,31 @@ public class KafkaCustomConsumerTest {
     private Counter posCounter;
     @Mock
     private Counter negCounter;
+    @Mock
+    private Counter overflowCounter;
     private Duration delayTime;
     private double posCount;
     private double negCount;
+    private double overflowCount;
+    private boolean paused;
+    private boolean resumed;
 
     @BeforeEach
     public void setUp() {
         delayTime = Duration.ofMillis(10);
+        paused = false;
+        resumed = false;
         kafkaConsumer = mock(KafkaConsumer.class);
         topicMetrics = mock(KafkaTopicConsumerMetrics.class);
         counter = mock(Counter.class);
         posCounter = mock(Counter.class);
+        mockBuffer = mock(Buffer.class);
         negCounter = mock(Counter.class);
+        overflowCounter = mock(Counter.class);
         topicConfig = mock(TopicConsumerConfig.class);
         when(topicMetrics.getNumberOfPositiveAcknowledgements()).thenReturn(posCounter);
         when(topicMetrics.getNumberOfNegativeAcknowledgements()).thenReturn(negCounter);
+        when(topicMetrics.getNumberOfBufferSizeOverflows()).thenReturn(overflowCounter);
         when(topicMetrics.getNumberOfRecordsCommitted()).thenReturn(counter);
         when(topicMetrics.getNumberOfDeserializationErrors()).thenReturn(counter);
         when(topicConfig.getThreadWaitingTime()).thenReturn(Duration.ofSeconds(1));
@@ -128,6 +143,16 @@ public void setUp() {
         when(topicConfig.getAutoCommit()).thenReturn(false);
         when(kafkaConsumer.committed(any(TopicPartition.class))).thenReturn(null);
 
+        doAnswer((i)-> {
+            paused = true;
+            return null;
+        }).when(kafkaConsumer).pause(any());
+
+        doAnswer((i)-> {
+            resumed = true;
+            return null;
+        }).when(kafkaConsumer).resume(any());
+
         doAnswer((i)-> {
             posCount += 1.0;
             return null;
@@ -136,6 +161,10 @@ public void setUp() {
             negCount += 1.0;
             return null;
         }).when(negCounter).increment();
+        doAnswer((i)-> {
+            overflowCount += 1.0;
+            return null;
+        }).when(overflowCounter).increment();
         doAnswer((i)-> {return posCount;}).when(posCounter).count();
         doAnswer((i)-> {return negCount;}).when(negCounter).count();
         callbackExecutor = Executors.newScheduledThreadPool(2);
@@ -147,6 +176,11 @@ public void setUp() {
         when(topicConfig.getName()).thenReturn(TOPIC_NAME);
     }
 
+    public KafkaCustomConsumer createObjectUnderTestWithMockBuffer(String schemaType) {
+        return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, mockBuffer, sourceConfig, topicConfig, schemaType,
+                acknowledgementSetManager, null, topicMetrics, pauseConsumePredicate);
+    }
+
     public KafkaCustomConsumer createObjectUnderTest(String schemaType, boolean acknowledgementsEnabled) {
         when(sourceConfig.getAcknowledgementsEnabled()).thenReturn(acknowledgementsEnabled);
         return new KafkaCustomConsumer(kafkaConsumer, shutdownInProgress, buffer, sourceConfig, topicConfig, schemaType,
@@ -162,6 +196,56 @@ private BlockingBuffer<Record<Event>> getBuffer() {
         return new BlockingBuffer<>(pluginSetting);
     }
 
+    @Test
+    public void testBufferOverflowPauseResume() throws Exception {
+        when(topicConfig.getMaxPollInterval()).thenReturn(Duration.ofMillis(4000));
+        String topic = topicConfig.getName();
+        consumerRecords = createPlainTextRecords(topic, 0L);
+        doAnswer((i)-> {
+            if (!paused && !resumed)
+                throw new SizeOverflowException("size overflow");
+            buffer.writeAll(i.getArgument(0), i.getArgument(1));
+            return null;
+        }).when(mockBuffer).writeAll(any(), anyInt());
+
+        doAnswer((i) -> {
+            if (paused && !resumed)
+                return ConsumerRecords.empty();
+            return consumerRecords;
+        }).when(kafkaConsumer).poll(any(Duration.class));
+        consumer = createObjectUnderTestWithMockBuffer("plaintext");
+        try {
+            consumer.onPartitionsAssigned(List.of(new TopicPartition(topic, testPartition)));
+            consumer.consumeRecords();
+        } catch (Exception e){}
+        assertTrue(paused);
+        assertTrue(resumed);
+
+        final Map.Entry<Collection<Record<Event>>, CheckpointState> bufferRecords = buffer.read(1000);
+        ArrayList<Record<Event>> bufferedRecords = new ArrayList<>(bufferRecords.getKey());
+        Assertions.assertEquals(consumerRecords.count(), bufferedRecords.size());
+        Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = consumer.getOffsetsToCommit();
+        Assertions.assertEquals(offsetsToCommit.size(), 1);
+        offsetsToCommit.forEach((topicPartition, offsetAndMetadata) -> {
+            Assertions.assertEquals(topicPartition.partition(), testPartition);
+            Assertions.assertEquals(topicPartition.topic(), topic);
+            Assertions.assertEquals(offsetAndMetadata.offset(), 2L);
+        });
+        Assertions.assertEquals(consumer.getNumRecordsCommitted(), 2L);
+
+        for (Record<Event> record: bufferedRecords) {
+            Event event = record.getData();
+            String value1 = event.get(testKey1, String.class);
+            String value2 = event.get(testKey2, String.class);
+            assertTrue(value1 != null || value2 != null);
+            if (value1 != null) {
+                Assertions.assertEquals(value1, testValue1);
+            }
+            if (value2 != null) {
+                Assertions.assertEquals(value2, testValue2);
+            }
+        }
+    }
 
     @Test
     public void testPlainTextConsumeRecords() throws InterruptedException {
         String topic = topicConfig.getName();
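The test above never inspects real consumer state; it observes the pause()/resume() calls by flipping flags from the doAnswer stubs wired up in setUp(). A self-contained sketch of just that stubbing pattern, assuming Mockito and kafka-clients are on the classpath (class and variable names are illustrative):

    import static org.mockito.ArgumentMatchers.anyCollection;
    import static org.mockito.Mockito.doAnswer;
    import static org.mockito.Mockito.mock;

    import java.util.Set;
    import java.util.concurrent.atomic.AtomicBoolean;
    import org.apache.kafka.clients.consumer.KafkaConsumer;

    public class PauseTrackingSketch {
        public static void main(String[] args) {
            final AtomicBoolean paused = new AtomicBoolean(false);
            final AtomicBoolean resumed = new AtomicBoolean(false);
            @SuppressWarnings("unchecked")
            final KafkaConsumer<String, Object> consumer = mock(KafkaConsumer.class);
            // Flip a flag whenever pause()/resume() is invoked, as setUp() does above.
            doAnswer(i -> { paused.set(true); return null; }).when(consumer).pause(anyCollection());
            doAnswer(i -> { resumed.set(true); return null; }).when(consumer).resume(anyCollection());

            consumer.pause(Set.of());
            consumer.resume(Set.of());
            System.out.println(paused.get() && resumed.get()); // prints: true
        }
    }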