Merge pull request #10 from pinterest/update_consumer
Make TieredStorageConsumer implement Kafka's Consumer<K, V> interface
jeffxiang authored Sep 11, 2024
2 parents 6935a54 + b793cd4 commit 46482ca
Showing 14 changed files with 536 additions and 130 deletions.
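Since the point of the change is that TieredStorageConsumer can now be used wherever Kafka's Consumer<K, V> interface is expected, a minimal usage sketch might look like the following. The TieredStorageConsumer constructor, its package path, and any tiered-storage-specific properties are assumptions here; they are not part of the diffs shown below.

    import java.time.Duration;
    import java.util.Collections;
    import java.util.Properties;

    import org.apache.kafka.clients.consumer.Consumer;
    import org.apache.kafka.clients.consumer.ConsumerConfig;
    import org.apache.kafka.clients.consumer.ConsumerRecord;
    import org.apache.kafka.clients.consumer.ConsumerRecords;

    import com.pinterest.kafka.tieredstorage.consumer.TieredStorageConsumer; // package path assumed

    public class TieredStorageConsumerExample {
        public static void main(String[] args) {
            Properties props = new Properties();
            props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
            props.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group");
            props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
                    "org.apache.kafka.common.serialization.StringDeserializer");
            props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
                    "org.apache.kafka.common.serialization.StringDeserializer");

            // Hypothetical construction; the actual constructor and required
            // tiered-storage properties are not shown in this commit.
            Consumer<String, String> consumer = new TieredStorageConsumer<>(props);
            consumer.subscribe(Collections.singleton("example-topic"));
            try {
                ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(500));
                for (ConsumerRecord<String, String> record : records) {
                    System.out.printf("offset=%d key=%s value=%s%n",
                            record.offset(), record.key(), record.value());
                }
            } finally {
                consumer.close();
            }
        }
    }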
@@ -31,6 +31,7 @@ public class AssignmentAwareConsumerRebalanceListener implements ConsumerRebalanceListener {
private final Map<TopicPartition, Long> committed;
private final TieredStorageConsumer.OffsetReset offsetReset;
private final AtomicBoolean isPartitionAssignmentComplete = new AtomicBoolean(true);
private ConsumerRebalanceListener customListener = null;

public AssignmentAwareConsumerRebalanceListener(
KafkaConsumer kafkaConsumer, String consumerGroup, Properties properties,
@@ -45,13 +46,19 @@ public AssignmentAwareConsumerRebalanceListener(
this.offsetReset = offsetReset;
}

protected void setCustomRebalanceListener(ConsumerRebalanceListener customListener) {
this.customListener = customListener;
}

@Override
public void onPartitionsRevoked(Collection<TopicPartition> collection) {
isPartitionAssignmentComplete.set(false);
LOG.info(String.format("Partitions revoked: " + collection));
this.assignment.removeAll(collection);
collection.forEach(position::remove);
isPartitionAssignmentComplete.set(true);
if (customListener != null)
customListener.onPartitionsRevoked(collection);
}

@Override
@@ -88,6 +95,8 @@ public void onPartitionsAssigned(Collection<TopicPartition> collection) {
adminClient.close();
}
isPartitionAssignmentComplete.set(true);
if (customListener != null)
customListener.onPartitionsAssigned(collection);
LOG.info("Completed onPartitionsAssigned.");
}

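The new customListener field above lets an application-supplied ConsumerRebalanceListener run after the internal bookkeeping (assignment and position updates) completes. A sketch of such a listener follows; how it gets wired through to setCustomRebalanceListener (for example via Consumer.subscribe(topics, listener)) is an assumption and is not shown in this hunk.

    import java.util.Collection;

    import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
    import org.apache.kafka.common.TopicPartition;

    // Application-supplied listener; AssignmentAwareConsumerRebalanceListener delegates to it
    // after updating its own assignment and position maps.
    public class LoggingRebalanceListener implements ConsumerRebalanceListener {
        @Override
        public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
            System.out.println("Application hook: partitions revoked " + partitions);
        }

        @Override
        public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
            System.out.println("Application hook: partitions assigned " + partitions);
        }
    }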
@@ -2,10 +2,12 @@

import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.OffsetCommitCallback;
import org.apache.kafka.common.TopicPartition;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;

import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -18,12 +20,22 @@
public class KafkaConsumerUtils {
private static final Logger LOG = LogManager.getLogger(KafkaConsumerUtils.class.getName());

public static void commitSync(@SuppressWarnings("rawtypes") KafkaConsumer kafkaConsumer, Map<TopicPartition, Long> offsets) {
public static Map<TopicPartition, OffsetAndMetadata> getOffsetsAndMetadata(Map<TopicPartition, Long> offsets) {
Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new HashMap<>();
offsets.forEach((key, value) -> offsetsToCommit.put(key, new OffsetAndMetadata(value)));
kafkaConsumer.commitSync(offsetsToCommit);
offsets.forEach((key, value) -> kafkaConsumer.seek(key, value + 1));
LOG.info("Committed offsets: " + offsetsToCommit);
return offsetsToCommit;
}

public static void commitSync(@SuppressWarnings("rawtypes") KafkaConsumer kafkaConsumer, Map<TopicPartition, OffsetAndMetadata> offsets, Duration timeout) {
kafkaConsumer.commitSync(offsets, timeout);
offsets.forEach((key, value) -> kafkaConsumer.seek(key, value.offset() + 1));
LOG.info("Committed offsets: " + offsets);
}

public static void commitAsync(@SuppressWarnings("rawtypes") KafkaConsumer kafkaConsumer, Map<TopicPartition, OffsetAndMetadata> offsets, OffsetCommitCallback callback) {
kafkaConsumer.commitAsync(offsets, callback);
offsets.forEach((key, value) -> kafkaConsumer.seek(key, value.offset() + 1));
LOG.info("Committed offsets: " + offsets);
}
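The helpers above commit the supplied offsets and then seek each partition to the committed offset plus one, so the next poll resumes immediately after the committed position. A hedged usage sketch of the asynchronous variant (the surrounding consumer setup and the KafkaConsumerUtils package import are assumed):

    import java.util.HashMap;
    import java.util.Map;

    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.apache.kafka.clients.consumer.OffsetAndMetadata;
    import org.apache.kafka.common.TopicPartition;

    public class CommitExample {
        // kafkaConsumer is an already-configured, assigned consumer (assumed).
        static void commitProcessedOffset(KafkaConsumer<String, String> kafkaConsumer) {
            Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
            offsets.put(new TopicPartition("example-topic", 0), new OffsetAndMetadata(42L));

            // Commits asynchronously; on completion the callback receives the committed
            // offsets or the exception that caused the commit to fail.
            KafkaConsumerUtils.commitAsync(kafkaConsumer, offsets, (committed, exception) -> {
                if (exception != null) {
                    System.err.println("Commit failed: " + exception.getMessage());
                }
            });
        }
    }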

public static void resetOffsetToLatest(@SuppressWarnings("rawtypes") KafkaConsumer kafkaConsumer, TopicPartition topicPartition) {
@@ -4,6 +4,7 @@
import com.pinterest.kafka.tieredstorage.common.metrics.MetricsConfiguration;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;

@@ -18,15 +19,15 @@
/**
* S3Consumer is a consumer that reads data from S3.
*/
public class S3Consumer {
public class S3Consumer<K, V> {
private static final Logger LOG = LogManager.getLogger(S3Consumer.class.getName());
private final Map<String, S3StorageServiceEndpoint.Builder> s3Location = new HashMap<>();
private final Set<TopicPartition> assignment = new HashSet<>();
private final S3PartitionsConsumer s3PartitionsConsumer;
private final S3PartitionsConsumer<K, V> s3PartitionsConsumer;
private Map<TopicPartition, Long> positions;

public S3Consumer(String consumerGroup, Properties properties, MetricsConfiguration metricsConfiguration) {
this.s3PartitionsConsumer = new S3PartitionsConsumer(consumerGroup, properties, metricsConfiguration);
public S3Consumer(String consumerGroup, Properties properties, MetricsConfiguration metricsConfiguration, Deserializer<K> keyDeserializer, Deserializer<V> valueDeserializer) {
this.s3PartitionsConsumer = new S3PartitionsConsumer<>(consumerGroup, properties, metricsConfiguration, keyDeserializer, valueDeserializer);
}

/**
@@ -102,8 +103,8 @@ public Set<TopicPartition> resetOffsets(Set<TopicPartition> partitions) {
* @param maxRecordsToConsume the maximum number of records to consume
* @return the {@link ConsumerRecords} consumed from S3
*/
public ConsumerRecords<byte[], byte[]> poll(int maxRecordsToConsume) {
ConsumerRecords<byte[], byte[]> records = s3PartitionsConsumer.poll(maxRecordsToConsume);
public ConsumerRecords<K, V> poll(int maxRecordsToConsume) {
ConsumerRecords<K, V> records = s3PartitionsConsumer.poll(maxRecordsToConsume);
setPositions(s3PartitionsConsumer.getPositions());
return records;
}
@@ -112,6 +113,14 @@ public Map<TopicPartition, Long> getPositions() {
return this.positions;
}

public void pause(Collection<TopicPartition> partitions) {
s3PartitionsConsumer.pause(partitions);
}

public void resume(Collection<TopicPartition> partitions) {
s3PartitionsConsumer.resume(partitions);
}

/**
* Unsubscribes the S3 consumer
*/
@@ -21,6 +21,23 @@ public class S3OffsetIndexHandler {
private ByteBuffer indexFileByteBuffer = null;
private static S3Client s3Client = S3Client.builder().region(S3Utils.REGION).build();

/**
* Returns the Kafka offset at the given byte position in the log segment.
* @param s3Path S3 path of the log segment
* @param position byte position in the log segment
* @return Kafka offset at the given byte position in the log segment
*/
public static long getOffsetAtPosition(Triple<String, String, Long> s3Path, int position) {
String logFileKey = s3Path.getMiddle();
String noExtensionFileKey = logFileKey.substring(0, logFileKey.lastIndexOf("."));
long baseOffset = Long.parseLong(noExtensionFileKey.substring(noExtensionFileKey.lastIndexOf("/") + 1));
String indexFileKey = noExtensionFileKey + ".index";
GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(s3Path.getLeft()).key(indexFileKey).build();
byte[] indexFileBytes = s3Client.getObject(getObjectRequest, ResponseTransformer.toBytes()).asByteArray();
ByteBuffer indexFileByteBuffer = ByteBuffer.wrap(indexFileBytes);
return baseOffset + indexFileByteBuffer.getInt(position);
}

/**
* Returns the minimum byte position in the log segment given the Kafka offset of interest.
* @param s3Path S3 path of the log segment
@@ -5,13 +5,16 @@
import org.apache.commons.lang3.tuple.Triple;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.annotation.InterfaceStability;
import org.apache.kafka.common.header.Headers;
import org.apache.kafka.common.header.internals.RecordHeaders;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.record.S3ChannelRecordBatch;
import org.apache.kafka.common.record.S3Records;
import org.apache.kafka.common.record.TimestampType;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.utils.Utils;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
@@ -29,10 +32,9 @@
/**
* Consumes Kafka records in a given Kafka partition from S3
*/
public class S3PartitionConsumer {
public class S3PartitionConsumer<K, V> {
private static final Logger LOG = LogManager.getLogger(S3PartitionConsumer.class.getName());
private static final long DEFAULT_S3_METADATA_RELOAD_INTERVAL_MS = 3600000; // 1 hour
private static final long DEFAULT_MAX_PARTITION_FETCH_SIZE_BYTES = 1048576;
private String location;
private final TopicPartition topicPartition;
private long position;
@@ -49,17 +51,44 @@ public class S3PartitionConsumer {
private String latestS3Object = null;
private final S3OffsetIndexHandler s3OffsetIndexHandler = new S3OffsetIndexHandler();
private final MetricsConfiguration metricsConfiguration;
private Deserializer<K> keyDeserializer;
private Deserializer<V> valueDeserializer;

public S3PartitionConsumer(String location, TopicPartition topicPartition, String consumerGroup, Properties properties, MetricsConfiguration metricsConfiguration) {
this(location, topicPartition, consumerGroup, properties, metricsConfiguration, null, null);
}

public S3PartitionConsumer(String location, TopicPartition topicPartition, String consumerGroup, Properties properties, MetricsConfiguration metricsConfiguration, Deserializer<K> keyDeserializer, Deserializer<V> valueDeserializer) {
this.location = location;
this.topicPartition = topicPartition;
this.consumerGroup = consumerGroup;
maxPartitionFetchSizeBytes = Integer.parseInt(properties.getProperty(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, Long.toString(DEFAULT_MAX_PARTITION_FETCH_SIZE_BYTES)));
s3MetadataReloadIntervalMs = Long.parseLong(properties.getProperty(TieredStorageConsumerConfig.STORAGE_SERVICE_ENDPOINT_S3_METADATA_RELOAD_INTERVAL_MS_CONFIG, Long.toString(DEFAULT_S3_METADATA_RELOAD_INTERVAL_MS)));
ConsumerConfig consumerConfig = new ConsumerConfig(properties);
this.maxPartitionFetchSizeBytes = consumerConfig.getInt(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG);
this.s3MetadataReloadIntervalMs = Long.parseLong(properties.getProperty(TieredStorageConsumerConfig.STORAGE_SERVICE_ENDPOINT_S3_METADATA_RELOAD_INTERVAL_MS_CONFIG, Long.toString(DEFAULT_S3_METADATA_RELOAD_INTERVAL_MS)));
this.metricsConfiguration = metricsConfiguration;
this.keyDeserializer = keyDeserializer;
this.valueDeserializer = valueDeserializer;
initializeDeserializers(consumerConfig);
LOG.info(String.format("Created S3PartitionConsumer for %s with maxPartitionFetchSizeBytes=%s and s3MetadataReloadIntervalMs=%s", topicPartition, maxPartitionFetchSizeBytes, s3MetadataReloadIntervalMs));
}

private void initializeDeserializers(ConsumerConfig consumerConfig) {
// borrowed from KafkaConsumer
if (keyDeserializer == null) {
this.keyDeserializer = consumerConfig.getConfiguredInstance(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, Deserializer.class);
this.keyDeserializer.configure(consumerConfig.originals(), true);
} else {
consumerConfig.ignore(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG);
}
if (valueDeserializer == null) {
this.valueDeserializer = consumerConfig.getConfiguredInstance(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, Deserializer.class);
this.valueDeserializer.configure(consumerConfig.originals(), false);
} else {
consumerConfig.ignore(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG);
}
// /borrowed from KafkaConsumer
}
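The resolution above mirrors KafkaConsumer's behavior: a Deserializer instance passed to the constructor takes precedence, otherwise the class named under key.deserializer / value.deserializer in the consumer properties is instantiated and configured. A standalone sketch of that precedence rule; the helper class and method below are illustrative only, not part of this commit.

    import java.util.HashMap;
    import java.util.Map;

    import org.apache.kafka.clients.consumer.ConsumerConfig;
    import org.apache.kafka.common.serialization.Deserializer;
    import org.apache.kafka.common.serialization.StringDeserializer;

    public class DeserializerResolutionSketch {

        // Illustrative helper: an explicit instance wins, otherwise the configured class is used.
        @SuppressWarnings("unchecked")
        static <T> Deserializer<T> resolveKeyDeserializer(ConsumerConfig config, Deserializer<T> supplied) {
            if (supplied != null) {
                return supplied;
            }
            Deserializer<T> fromConfig =
                    (Deserializer<T>) config.getConfiguredInstance(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, Deserializer.class);
            fromConfig.configure(config.originals(), true); // isKey = true
            return fromConfig;
        }

        public static void main(String[] args) {
            Map<String, Object> props = new HashMap<>();
            props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
            props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
            props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());

            // No explicit instance supplied, so the configured StringDeserializer is used.
            Deserializer<String> keyDeserializer = resolveKeyDeserializer(new ConsumerConfig(props), null);
            System.out.println(keyDeserializer.deserialize("example-topic", "hello".getBytes()));
        }
    }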

/**
* Returns the first offset of the first S3 object
* @return the first offset of the first S3 object
@@ -177,7 +206,7 @@ private void maybeSetS3Records(long position) throws NoS3ObjectException, IOException {
* @param maxRecords
* @return list of {@link org.apache.kafka.clients.consumer.ConsumerRecords}
*/
public List<ConsumerRecord<byte[], byte[]>> poll(int maxRecords) {
public List<ConsumerRecord<K, V>> poll(int maxRecords) {
return poll(maxRecords, false);
}

@@ -187,7 +216,7 @@ public List<ConsumerRecord<byte[], byte[]>> poll(int maxRecords) {
* @param shouldReadWholeObject
* @return list of {@link org.apache.kafka.clients.consumer.ConsumerRecords}
*/
public List<ConsumerRecord<byte[], byte[]>> poll(int maxRecords, boolean shouldReadWholeObject) {
public List<ConsumerRecord<K, V>> poll(int maxRecords, boolean shouldReadWholeObject) {
if (shouldReadWholeObject)
LOG.debug(String.format("Trying to consume all records from each S3 object for %s", topicPartition));
else
@@ -215,7 +244,7 @@ public List<ConsumerRecord<byte[], byte[]>> poll(int maxRecords, boolean shouldReadWholeObject) {
}

long lastSeenOffset = -1;
List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
List<ConsumerRecord<K, V>> records = new ArrayList<>();
LOG.debug(String.format("For consuming from offset %s will be processing S3 object %s", position, s3Path));
while (batches.hasNext() && (shouldReadWholeObject || (recordCount < maxRecords && fetchSizeBytes < maxPartitionFetchSizeBytes))) {
S3ChannelRecordBatch batch = batches.next();
@@ -253,8 +282,8 @@ public List<ConsumerRecord<byte[], byte[]>> poll(int maxRecords, boolean shouldReadWholeObject) {
record.checksumOrNull() == null ? -1 : record.checksumOrNull(),
record.keySize(),
record.valueSize(),
record.key() == null ? null : Utils.toArray(record.key()),
record.value() == null ? null : Utils.toArray(record.value()),
record.key() == null ? null : keyDeserializer.deserialize(topicPartition.topic(), Utils.toArray(record.key())),
record.value() == null ? null : valueDeserializer.deserialize(topicPartition.topic(), Utils.toArray(record.value())),
headers
));
++recordCount;
@@ -366,6 +395,14 @@ private Triple<String, String, Long> getS3PathForPosition(long position) {
return s3Object;
}

@InterfaceStability.Evolving
public Map<TopicPartition, OffsetAndTimestamp> offsetForTime(Long timestamp, Long beginningOffset, Long endOffset) {
//TODO: This needs a delicate implementation to avoid listing the whole prefix, which is an expensive operation,
// if possible. A naive approach is to do a binary search since we have the first and last offset (log segment)
// on S3
throw new UnsupportedOperationException("offsetForTime is not implemented yet");
}
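The TODO above suggests a binary search bounded by the partition's beginning and end offsets on S3. A rough sketch of that idea, assuming record timestamps are non-decreasing with offset and assuming a hypothetical helper timestampAtOffset(...) that reads a single record's timestamp from S3 (no such helper exists in this class):

    // Hypothetical sketch of the binary-search idea from the TODO above.
    private long earliestOffsetForTimestamp(long targetTimestamp, long beginningOffset, long endOffset) {
        long low = beginningOffset;
        long high = endOffset;
        long result = -1L;
        while (low <= high) {
            long mid = low + (high - low) / 2;
            long ts = timestampAtOffset(mid); // assumed helper: fetches one record's timestamp from S3
            if (ts >= targetTimestamp) {
                result = mid;   // candidate; keep searching for an earlier qualifying offset
                high = mid - 1;
            } else {
                low = mid + 1;
            }
        }
        return result; // -1 if no record has a timestamp >= targetTimestamp
    }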

/**
* Closes the S3Records object
* @throws IOException
(Diffs for the remaining 9 changed files were not loaded and are not shown here.)
