Skip to content

Commit

Permalink
Row-level TTL PR 5: automatically extract StateSerdes (#375)
Browse files Browse the repository at this point in the history
Avoid forcing the user to pass in their serdes again by extracting the StateSerdes from the MeteredKeyValueStore layer of the StateStore hierarchy.

Also addresses feedback from PR 4 to increase the ttl times used in the CassandraKV/FactTableIntegrationTest tests
  • Loading branch information
ableegoldman authored Oct 29, 2024
1 parent 09c25d7 commit ee59069
Show file tree
Hide file tree
Showing 12 changed files with 246 additions and 254 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.kafka.common.serialization.Serde;

public class TtlProvider<K, V> {

/**
* Creates a new TtlProvider with the given default duration to retain records for unless
* overridden. To allow ttl overrides for individual records, you can use one of the
* {@link #fromKey(Function, Serde)}, {@link #fromValue(Function, Serde)},
* or {@link #fromKeyAndValue(BiFunction, Serde, Serde)} methods to define the row-level
* {@link #fromKey(Function)}, {@link #fromValue(Function)},
* or {@link #fromKeyAndValue(BiFunction)} methods to define the row-level
* override function.
*
* @return a new TtlProvider that will retain records for the specified default duration
Expand All @@ -38,17 +37,15 @@ public static <K, V> TtlProvider<K, V> withDefault(final Duration defaultTtl) {
return new TtlProvider<>(
TtlType.DEFAULT_ONLY,
TtlDuration.of(defaultTtl),
(ignoredK, ignoredV) -> Optional.empty(),
null,
null
(ignoredK, ignoredV) -> Optional.empty()
);
}

/**
* Creates a new TtlProvider that has no default (equivalent to infinite retention for
* all records unless an override is specified). Must be used in combination with
* exactly one of the {@link #fromKey(Function, Serde)}, {@link #fromValue(Function, Serde)},
* and {@link #fromKeyAndValue(BiFunction, Serde, Serde)} methods to define the row-level
* exactly one of the {@link #fromKey(Function)}, {@link #fromValue(Function)},
* and {@link #fromKeyAndValue(BiFunction)} methods to define the row-level
* override function.
*
* @return a new TtlProvider that will retain records indefinitely by default
Expand All @@ -57,83 +54,65 @@ public static <K, V> TtlProvider<K, V> withNoDefault() {
return new TtlProvider<>(
TtlType.DEFAULT_ONLY,
TtlDuration.noTtl(),
(ignoredK, ignoredV) -> Optional.empty(),
null,
null
(ignoredK, ignoredV) -> Optional.empty()
);
}

/**
* @param computeTtlFromKey function that returns the ttl override for this specific key,
* or {@link Optional#empty()} to use the default ttl
*
* @return the same TtlProvider with a key-based override function
*/
public TtlProvider<K, V> fromKey(
final Function<K, Optional<TtlDuration>> computeTtlFromKey,
final Serde<K> keySerde
final Function<K, Optional<TtlDuration>> computeTtlFromKey
) {
if (ttlType.equals(TtlType.VALUE) || ttlType.equals(TtlType.KEY_AND_VALUE)) {
throw new IllegalArgumentException("Must choose only key, value, or key-and-value ttl");
}

if (keySerde == null || keySerde.deserializer() == null) {
throw new IllegalArgumentException("The key Serde and Deserializer must not be null");
}

return new TtlProvider<>(
TtlType.KEY,
defaultTtl,
(k, ignored) -> computeTtlFromKey.apply(k),
keySerde,
null
(k, ignored) -> computeTtlFromKey.apply(k)
);
}

/**
* @param computeTtlFromValue function that returns the ttl override for this specific value,
* or {@link Optional#empty()} to use the default ttl
* @return the same TtlProvider with a value-based override function
*/
public TtlProvider<K, V> fromValue(
final Function<V, Optional<TtlDuration>> computeTtlFromValue,
final Serde<V> valueSerde
final Function<V, Optional<TtlDuration>> computeTtlFromValue
) {
if (ttlType.equals(TtlType.KEY) || ttlType.equals(TtlType.KEY_AND_VALUE)) {
throw new IllegalArgumentException("Must choose only key, value, or key-and-value ttl");
}

if (valueSerde == null || valueSerde.deserializer() == null) {
throw new IllegalArgumentException("The value Serde and Deserializer must not be null");
}

return new TtlProvider<>(
TtlType.VALUE,
defaultTtl,
(ignored, v) -> computeTtlFromValue.apply(v),
null,
valueSerde);
(ignored, v) -> computeTtlFromValue.apply(v)
);
}

/**
* @param computeTtlFromKeyAndValue function that returns the ttl override for this specific key
* and value, or {@link Optional#empty()} to use the default ttl
* @return the same TtlProvider with a key-and-value-based override function
*/
public TtlProvider<K, V> fromKeyAndValue(
final BiFunction<K, V, Optional<TtlDuration>> computeTtlFromKeyAndValue,
final Serde<K> keySerde,
final Serde<V> valueSerde
final BiFunction<K, V, Optional<TtlDuration>> computeTtlFromKeyAndValue
) {
if (ttlType.equals(TtlType.KEY) || ttlType.equals(TtlType.VALUE)) {
throw new IllegalArgumentException("Must choose only key, value, or key-and-value ttl");
}

if (keySerde == null || keySerde.deserializer() == null) {
throw new IllegalArgumentException("The key Serde and Deserializer must not be null");
} else if (valueSerde == null || valueSerde.deserializer() == null) {
throw new IllegalArgumentException("The value Serde and Deserializer must not be null");
}

return new TtlProvider<>(
TtlType.KEY_AND_VALUE,
defaultTtl,
computeTtlFromKeyAndValue,
keySerde,
valueSerde
computeTtlFromKeyAndValue
);
}

Expand Down Expand Up @@ -218,34 +197,18 @@ private enum TtlType {
private final TtlType ttlType;
private final TtlDuration defaultTtl;

// Only non-null for key/value-based ttl providers
private final Serde<K> keySerde;
private final Serde<V> valueSerde;

private final BiFunction<K, V, Optional<TtlDuration>> computeTtl;

private TtlProvider(
final TtlType ttlType,
final TtlDuration defaultTtl,
final BiFunction<K, V, Optional<TtlDuration>> computeTtl,
final Serde<K> keySerde,
final Serde<V> valueSerde
final BiFunction<K, V, Optional<TtlDuration>> computeTtl
) {
this.ttlType = ttlType;
this.defaultTtl = defaultTtl;
this.keySerde = keySerde;
this.valueSerde = valueSerde;
this.computeTtl = computeTtl;
}

public Serde<K> keySerde() {
return keySerde;
}

public Serde<V> valueSerde() {
return valueSerde;
}

public TtlDuration defaultTtl() {
return defaultTtl;
}
Expand All @@ -265,6 +228,7 @@ public Optional<TtlDuration> computeTtl(
) {
final K key;
final V value;

switch (ttlType) {
case DEFAULT_ONLY:
key = null; //ignored
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import dev.responsive.kafka.internal.metrics.ResponsiveRestoreListener;
import dev.responsive.kafka.internal.utils.SessionClients;
import java.util.Collection;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
Expand All @@ -49,10 +50,11 @@ public class GlobalOperations implements KeyValueOperations {

public static GlobalOperations create(
final StateStoreContext storeContext,
final ResponsiveKeyValueParams params
final ResponsiveKeyValueParams params,
final Optional<TtlResolver<?, ?>> ttlResolver
) throws InterruptedException, TimeoutException {

if (params.ttlProvider().isPresent()) {
if (ttlResolver.isPresent()) {
throw new UnsupportedOperationException("Global stores are not yet compatible with ttl");
}

Expand All @@ -67,7 +69,7 @@ public static GlobalOperations create(
final var spec = RemoteTableSpecFactory.fromKVParams(
params,
defaultPartitioner(),
TtlResolver.fromTtlProvider(false, changelogTopic.topic(), params.ttlProvider())
ttlResolver
);

final var table = client.globalFactory().create(spec);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package dev.responsive.kafka.internal.stores;

import dev.responsive.kafka.api.stores.ResponsiveKeyValueParams;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import org.apache.kafka.streams.processor.StateStoreContext;
import org.apache.kafka.streams.processor.internals.Task;
Expand All @@ -26,7 +27,7 @@ public interface KVOperationsProvider {

KeyValueOperations provide(
final ResponsiveKeyValueParams params,
final boolean isTimestamped,
final Optional<TtlResolver<?, ?>> ttlResolver,
final StateStoreContext context,
final Task.TaskType type
) throws InterruptedException, TimeoutException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public class PartitionedOperations implements KeyValueOperations {

public static PartitionedOperations create(
final TableName name,
final boolean isTimestamped,
final Optional<TtlResolver<?, ?>> ttlResolver,
final StateStoreContext storeContext,
final ResponsiveKeyValueParams params
) throws InterruptedException, TimeoutException {
Expand All @@ -98,12 +98,6 @@ public static PartitionedOperations create(
context.taskId().partition()
);

final Optional<TtlResolver<?, ?>> ttlResolver = TtlResolver.fromTtlProvider(
isTimestamped,
changelog.topic(),
params.ttlProvider()
);

final RemoteKVTable<?> table;
switch (sessionClients.storageBackend()) {
case CASSANDRA:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dev.responsive.kafka.api.stores.ResponsiveKeyValueParams;
import dev.responsive.kafka.internal.utils.TableName;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeoutException;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.common.utils.LogContext;
Expand All @@ -33,6 +34,8 @@
import org.apache.kafka.streams.query.Position;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.StateSerdes;
import org.apache.kafka.streams.state.internals.StoreAccessorUtil;
import org.apache.kafka.streams.state.internals.StoreQueryUtils;
import org.slf4j.Logger;

Expand Down Expand Up @@ -115,7 +118,13 @@ public void init(final StateStoreContext storeContext, final StateStore root) {
log.warn("Unexpected standby task created, should transition to active shortly");
}

operations = opsProvider.provide(params, isTimestamped, storeContext, taskType);
final StateSerdes<?, ?> stateSerdes = StoreAccessorUtil.extractKeyValueStoreSerdes(root);
final Optional<TtlResolver<?, ?>> ttlResolver = TtlResolver.fromTtlProviderAndStateSerdes(
stateSerdes,
params.ttlProvider()
);

operations = opsProvider.provide(params, ttlResolver, storeContext, taskType);
log.info("Completed initializing state store");

open = true;
Expand All @@ -127,13 +136,13 @@ public void init(final StateStoreContext storeContext, final StateStore root) {

private static KeyValueOperations provideOperations(
final ResponsiveKeyValueParams params,
final boolean isTimestamped,
final Optional<TtlResolver<?, ?>> ttlResolver,
final StateStoreContext context,
final TaskType taskType
) throws InterruptedException, TimeoutException {
return (taskType == TaskType.GLOBAL)
? GlobalOperations.create(context, params)
: PartitionedOperations.create(params.name(), isTimestamped, context, params);
? GlobalOperations.create(context, params, ttlResolver)
: PartitionedOperations.create(params.name(), ttlResolver, context, params);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dev.responsive.kafka.internal.utils.StateDeserializer;
import java.util.Optional;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.state.StateSerdes;

public class TtlResolver<K, V> {

Expand All @@ -29,27 +30,28 @@ public class TtlResolver<K, V> {
private final StateDeserializer<K, V> stateDeserializer;
private final TtlProvider<K, V> ttlProvider;

public static Optional<TtlResolver<?, ?>> fromTtlProvider(
final boolean isTimestamped,
final String changelogTopic,
@SuppressWarnings("unchecked")
public static <K, V> Optional<TtlResolver<?, ?>> fromTtlProviderAndStateSerdes(
final StateSerdes<?, ?> stateSerdes,
final Optional<TtlProvider<?, ?>> ttlProvider
) {
return ttlProvider.isPresent()
? Optional.of(new TtlResolver<>(isTimestamped, changelogTopic, ttlProvider.get()))
? Optional.of(
new TtlResolver<>(
(StateDeserializer<K, V>) new StateDeserializer<>(
stateSerdes.topic(),
stateSerdes.keyDeserializer(),
stateSerdes.valueDeserializer()),
(TtlProvider<K, V>) ttlProvider.get()
))
: Optional.empty();
}

public TtlResolver(
final boolean isTimestamped,
final String changelogTopic,
final StateDeserializer<K, V> stateDeserializer,
final TtlProvider<K, V> ttlProvider
) {
this.stateDeserializer = new StateDeserializer<>(
isTimestamped,
changelogTopic,
ttlProvider.keySerde(),
ttlProvider.valueSerde()
);
this.stateDeserializer = stateDeserializer;
this.ttlProvider = ttlProvider;
}

Expand Down
Loading

0 comments on commit ee59069

Please sign in to comment.