From 4795923c03ed4586390f334e70cbd9f9d4d3b470 Mon Sep 17 00:00:00 2001
From: "A. Sophie Blee-Goldman"
Date: Mon, 20 Nov 2023 21:31:06 -0800
Subject: [PATCH] create segments if they don't exist, else try to reserve epoch

---
 .../internal/db/CassandraKeyValueTable.java   |   4 +-
 .../internal/db/CassandraWindowedTable.java   | 132 +++++++++++-------
 .../kafka/internal/db/ColumnName.java         |   2 +-
 .../kafka/internal/db/LwtWriterFactory.java   |   7 +-
 .../internal/stores/CommitBufferTest.java     |   3 +-
 5 files changed, 93 insertions(+), 55 deletions(-)

diff --git a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraKeyValueTable.java b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraKeyValueTable.java
index 995531036..39c69106b 100644
--- a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraKeyValueTable.java
+++ b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraKeyValueTable.java
@@ -268,9 +268,9 @@ public String name() {
   public WriterFactory<Bytes, Integer> init(
       final int kafkaPartition
   ) {
-    partitioner.allTablePartitions(kafkaPartition).forEach(sub -> client.execute(
+    partitioner.allTablePartitions(kafkaPartition).forEach(tablePartition -> client.execute(
         QueryBuilder.insertInto(name)
-            .value(PARTITION_KEY.column(), PARTITION_KEY.literal(sub))
+            .value(PARTITION_KEY.column(), PARTITION_KEY.literal(tablePartition))
             .value(ROW_TYPE.column(), METADATA_ROW.literal())
             .value(DATA_KEY.column(), DATA_KEY.literal(METADATA_KEY))
             .value(TIMESTAMP.column(), TIMESTAMP.literal(METADATA_TS))
diff --git a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraWindowedTable.java b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraWindowedTable.java
index ae1adfdea..933bbe13b 100644
--- a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraWindowedTable.java
+++ b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/CassandraWindowedTable.java
@@ -32,7 +32,6 @@
 import static dev.responsive.kafka.internal.db.RowType.METADATA_ROW;
 import static dev.responsive.kafka.internal.db.partitioning.SegmentPartitioner.UNINITIALIZED_STREAM_TIME;
 import static dev.responsive.kafka.internal.stores.ResponsiveStoreRegistration.NO_COMMITTED_OFFSET;
-import static java.util.Collections.singletonList;
 
 import com.datastax.oss.driver.api.core.cql.BoundStatement;
 import com.datastax.oss.driver.api.core.cql.PreparedStatement;
@@ -53,8 +52,10 @@
 import java.nio.ByteBuffer;
 import java.time.Duration;
 import java.time.Instant;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.TimeoutException;
 import javax.annotation.CheckReturnValue;
@@ -79,7 +80,7 @@ public class CassandraWindowedTable implements
   private final CassandraClient client;
   private final SegmentPartitioner partitioner;
 
-  private final PreparedStatement initSegment;
+  private final PreparedStatement createSegment;
   private final PreparedStatement expireSegment;
   private final PreparedStatement insert;
   private final PreparedStatement delete;
@@ -102,9 +103,27 @@ public class CassandraWindowedTable implements
   // as these entities have different views of the current time and should not be unified.
   // (Specifically, this table will always lag the view of stream-time that is shared by the
   // ResponsiveWindowStore and CommitBuffer due to buffering/batching of writes)
-  private long lastFlushStreamTime = UNINITIALIZED_STREAM_TIME;
-  private long pendingFlushStreamTime = UNINITIALIZED_STREAM_TIME;
-  private SegmentRoll activeRoll;
+  private final Map<Integer, PendingFlushInfo> kafkaPartitionToPendingFlushInfo = new HashMap<>();
+
+  // TODO: move this into the LWTWriter/Factory to keep this class stateless
+  private static class PendingFlushInfo {
+    private long lastFlushStreamTime = UNINITIALIZED_STREAM_TIME;
+    private long pendingFlushStreamTime = UNINITIALIZED_STREAM_TIME;
+    private SegmentRoll pendingSegmentRoll;
+
+    void maybeUpdatePendingStreamTime(final long recordTimestamp) {
+      this.pendingFlushStreamTime = Math.max(pendingFlushStreamTime, recordTimestamp);
+    }
+
+    void initSegmentRoll(final SegmentRoll pendingSegmentRoll) {
+      this.pendingSegmentRoll = pendingSegmentRoll;
+    }
+
+    void finalizeFlush() {
+      pendingSegmentRoll = null;
+      lastFlushStreamTime = pendingFlushStreamTime;
+    }
+  }
 
   public static CassandraWindowedTable create(
       final CassandraTableSpec spec,
@@ -139,16 +158,16 @@ public static CassandraWindowedTable create(
     client.execute(createTable.build());
     client.awaitTable(name).await(Duration.ofSeconds(60));
 
-    final var initSegment = client.prepare(
+    final var createSegment = client.prepare(
         QueryBuilder
-            .update(name)
-            .setColumn(EPOCH.column(), bindMarker(EPOCH.bind()))
-            .where(PARTITION_KEY.relation().isEqualTo(bindMarker(PARTITION_KEY.bind())))
-            .where(SEGMENT_ID.relation().isEqualTo(bindMarker(SEGMENT_ID.bind())))
-            .where(ROW_TYPE.relation().isEqualTo(METADATA_ROW.literal()))
-            .where(DATA_KEY.relation().isEqualTo(DATA_KEY.literal(METADATA_KEY)))
-            .where(WINDOW_START.relation().isEqualTo(WINDOW_START.literal(METADATA_TS)))
-            .ifColumn(EPOCH.column()).isLessThan(bindMarker(EPOCH.bind()))
+            .insertInto(name)
+            .value(PARTITION_KEY.column(), bindMarker(PARTITION_KEY.bind()))
+            .value(SEGMENT_ID.column(), bindMarker(SEGMENT_ID.bind()))
+            .value(ROW_TYPE.column(), METADATA_ROW.literal())
+            .value(DATA_KEY.column(), DATA_KEY.literal(METADATA_KEY))
+            .value(WINDOW_START.column(), WINDOW_START.literal(METADATA_TS))
+            .value(EPOCH.column(), bindMarker(EPOCH.bind()))
+            .ifNotExists()
             .build()
     );
 
@@ -371,7 +390,7 @@ public static CassandraWindowedTable create(
         name,
         client,
         partitioner,
-        initSegment,
+        createSegment,
         expireSegment,
         insert,
         delete,
@@ -412,7 +431,7 @@ public CassandraWindowedTable(
       final String name,
       final CassandraClient client,
       final SegmentPartitioner partitioner,
-      final PreparedStatement initSegment,
+      final PreparedStatement createSegment,
       final PreparedStatement expireSegment,
       final PreparedStatement insert,
       final PreparedStatement delete,
@@ -434,7 +453,7 @@ public CassandraWindowedTable(
     this.name = name;
     this.client = client;
     this.partitioner = partitioner;
-    this.initSegment = initSegment;
+    this.createSegment = createSegment;
     this.expireSegment = expireSegment;
     this.insert = insert;
     this.delete = delete;
@@ -463,6 +482,7 @@ public String name() {
   public WriterFactory<Stamped<Bytes>, SegmentPartition> init(
       final int kafkaPartition
   ) {
+    kafkaPartitionToPendingFlushInfo.put(kafkaPartition, new PendingFlushInfo());
     final SegmentPartition metadataPartition = partitioner.metadataTablePartition(kafkaPartition);
     client.execute(
@@ -538,14 +558,26 @@ public void preCommit(
       final int kafkaPartition,
       final long epoch
   ) {
-    if (pendingFlushStreamTime > lastFlushStreamTime) {
-      activeRoll =
partitioner.rolledSegments(name, lastFlushStreamTime, pendingFlushStreamTime); - - for (final long segmentId : activeRoll.segmentsToCreate) { - initSegment(new SegmentPartition(kafkaPartition, segmentId), epoch); + final PendingFlushInfo pendingFlush = kafkaPartitionToPendingFlushInfo.get(kafkaPartition); + final SegmentRoll pendingRoll = partitioner.rolledSegments( + name, pendingFlush.lastFlushStreamTime, pendingFlush.pendingFlushStreamTime + ); + pendingFlush.initSegmentRoll(pendingRoll); + for (final long segmentId : pendingRoll.segmentsToCreate) { + final SegmentPartition segment = new SegmentPartition(kafkaPartition, segmentId); + final var createSegment = createSegment(segment, epoch); + + // If the segment creation failed because the table partition already exists, attempt to + // update the epoch in case we are fencing an older writer -- if that fails it means we're + // the ones being fenced + // TODO: what if the segment creation failed for a reason besides already existing? + if (!createSegment.wasApplied()) { + final var reserveEpoch = client.execute(reserveEpoch(segment, epoch)); + + if (!reserveEpoch.wasApplied()) { + handleEpochFencing(kafkaPartition, segment, epoch); + } } - - lastFlushStreamTime = pendingFlushStreamTime; } } @@ -554,17 +586,16 @@ public void postCommit( final int kafkaPartition, final long epoch ) { - if (activeRoll != null) { - for (final long segmentId : activeRoll.segmentsToExpire) { - expireSegment(new SegmentPartition(kafkaPartition, segmentId)); - } - activeRoll = null; + final PendingFlushInfo pendingFlush = kafkaPartitionToPendingFlushInfo.get(kafkaPartition); + for (final long segmentId : pendingFlush.pendingSegmentRoll.segmentsToExpire) { + expireSegment(new SegmentPartition(kafkaPartition, segmentId)); } + pendingFlush.finalizeFlush(); } - private void initSegment(final SegmentPartition segmentToCreate, final long epoch) { - client.execute( - initSegment + private ResultSet createSegment(final SegmentPartition segmentToCreate, final long epoch) { + return client.execute( + createSegment .bind() .setInt(PARTITION_KEY.bind(), segmentToCreate.tablePartition) .setLong(SEGMENT_ID.bind(), segmentToCreate.segmentId) @@ -613,26 +644,12 @@ public BoundStatement setOffset( .setLong(OFFSET.bind(), offset); } - // TODO: combine with reserve epoch? - public BoundStatement setStreamTime( - final int kafkaPartition, - final long streamTime - ) { - final SegmentPartition metadataPartition = partitioner.metadataTablePartition(kafkaPartition); - return setStreamTime - .bind() - .setInt(PARTITION_KEY.bind(), metadataPartition.partitionKey) - .setLong(SEGMENT_ID.bind(), metadataPartition.segmentId) - .setLong(STREAM_TIME.bind(), streamTime); - } - - // TODO: combine epoch and streamTime into single row in the metadata segment? 
public long fetchStreamTime(final int kafkaPartition) { final SegmentPartition metadataPartition = partitioner.metadataTablePartition(kafkaPartition); final List result = client.execute( fetchStreamTime .bind() - .setInt(PARTITION_KEY.bind(), metadataPartition.partitionKey) + .setInt(PARTITION_KEY.bind(), metadataPartition.tablePartition) .setLong(SEGMENT_ID.bind(), metadataPartition.segmentId)) .all(); @@ -645,6 +662,21 @@ public long fetchStreamTime(final int kafkaPartition) { } } + public BoundStatement setStreamTime( + final int kafkaPartition, + final long epoch + ) { + final PendingFlushInfo pendingFlush = kafkaPartitionToPendingFlushInfo.get(kafkaPartition); + + final SegmentPartition metadataPartition = partitioner.metadataTablePartition(kafkaPartition); + return setStreamTime + .bind() + .setInt(PARTITION_KEY.bind(), metadataPartition.tablePartition) + .setLong(SEGMENT_ID.bind(), metadataPartition.segmentId) + .setLong(STREAM_TIME.bind(), pendingFlush.pendingFlushStreamTime) + .setLong(EPOCH.bind(), epoch); + } + @Override public long fetchEpoch(final SegmentPartition segmentPartition) { final List result = client.execute( @@ -703,7 +735,8 @@ public BoundStatement insert( final byte[] value, final long epochMillis ) { - pendingFlushStreamTime = Math.max(pendingFlushStreamTime, key.stamp); + kafkaPartitionToPendingFlushInfo.get(kafkaPartition).maybeUpdatePendingStreamTime(key.stamp); + final SegmentPartition remotePartition = partitioner.tablePartition(kafkaPartition, key); return insert .bind() @@ -730,7 +763,8 @@ public BoundStatement delete( final int kafkaPartition, final Stamped key ) { - pendingFlushStreamTime = Math.max(pendingFlushStreamTime, key.stamp); + kafkaPartitionToPendingFlushInfo.get(kafkaPartition).maybeUpdatePendingStreamTime(key.stamp); + final SegmentPartition segmentPartition = partitioner.tablePartition(kafkaPartition, key); return delete .bind() diff --git a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/ColumnName.java b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/ColumnName.java index ac0d70d91..1c1c835eb 100644 --- a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/ColumnName.java +++ b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/ColumnName.java @@ -38,7 +38,7 @@ public enum ColumnName { DATA_VALUE("value", "value", b -> bytes((byte[]) b)), OFFSET("offset", "offset"), EPOCH("epoch", "epoch"), - STREAM_TIME("streamTime", "streamtime", ts -> timestamp((long) ts)), + STREAM_TIME("streamTime", "streamtime"), WINDOW_START("windowStart", "windowstart", ts -> timestamp((long) ts)), TIMESTAMP("ts", "ts", ts -> timestamp((long) ts)); diff --git a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/LwtWriterFactory.java b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/LwtWriterFactory.java index 866ca8a70..781ea6b7c 100644 --- a/kafka-client/src/main/java/dev/responsive/kafka/internal/db/LwtWriterFactory.java +++ b/kafka-client/src/main/java/dev/responsive/kafka/internal/db/LwtWriterFactory.java @@ -85,6 +85,12 @@ public RemoteWriteResult
<P>
setOffset(final long offset) { builder.addStatement(fencingStatement(tablePartition)); builder.addStatement(table.setOffset(kafkaPartition, offset)); + // TODO(sophie): clean up this hack, perhaps by combining the offset and stream-time into + // a single metadata row update + if (table instanceof CassandraWindowedTable) { + builder.addStatement(((CassandraWindowedTable) table).setStreamTime(kafkaPartition, epoch)); + } + final var result = client.execute(builder.build()); return result.wasApplied() ? RemoteWriteResult.success(tablePartition) @@ -106,7 +112,6 @@ public RemoteWriteResult
<P>
commitPendingFlush( final var flushResult = super.commitPendingFlush(pendingFlush, consumedOffset); tableMetadata.postCommit(kafkaPartition, epoch); - // TODO: should #advanceStreamTime return a RemoteWriteResult as well? return flushResult; } diff --git a/kafka-client/src/test/java/dev/responsive/kafka/internal/stores/CommitBufferTest.java b/kafka-client/src/test/java/dev/responsive/kafka/internal/stores/CommitBufferTest.java index f177e8879..99bce7d28 100644 --- a/kafka-client/src/test/java/dev/responsive/kafka/internal/stores/CommitBufferTest.java +++ b/kafka-client/src/test/java/dev/responsive/kafka/internal/stores/CommitBufferTest.java @@ -321,8 +321,7 @@ public void shouldFenceOffsetFlushBasedOnMetadataRowEpoch() { // reserve epoch for partition 8 to ensure it doesn't get flushed // if it did it would get fenced - LwtWriterFactory.initialize( - table, table, client, partitioner, changelog.partition(), List.of(8, 9)); + table.init(KAFKA_PARTITION); final Bytes k1 = Bytes.wrap(new byte[]{1}); final Bytes k2 = Bytes.wrap(new byte[]{2});