[HUDI-7830] Add predicate filter pruning for snapshot queries in hudi related sources #11396

Merged (1 commit) on Sep 19, 2024
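In short: SnapshotLoadQuerySplitter can now return a predicate filter alongside the end instant, QueryInfo carries that predicate, and the Hudi incremental sources apply it to the snapshot Dataset so partitions and files can be pruned. The sketch below is an illustrative wrapper (not part of the PR) that strings together only methods touched in this diff:

```java
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter;
import org.apache.hudi.utilities.sources.helpers.QueryInfo;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

final class PredicatePruningFlowSketch {
  // The splitter proposes an end instant plus a predicate; QueryInfo carries both; the source
  // then applies the predicate (when present) on top of its commit-time range filters.
  static Dataset<Row> pruneSnapshot(SnapshotLoadQuerySplitter splitter, Dataset<Row> snapshot, QueryInfo queryInfo) {
    QueryInfo updated = splitter
        .getNextCheckpointWithPredicates(snapshot, queryInfo.getStartInstant())
        .map(queryInfo::withUpdatedCheckpoint)
        .orElse(queryInfo);
    return updated.getPredicateFilter().map(snapshot::filter).orElse(snapshot);
  }
}
```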
@@ -130,7 +130,7 @@ public UTF8String getPartitionPath(InternalRow row, StructType schema) {
return UTF8String.fromString(getPartitionPath(Option.empty(), Option.empty(), Option.of(Pair.of(row, schema))));
}

private String getPartitionPath(Option<GenericRecord> record, Option<Row> row, Option<Pair<InternalRow, StructType>> internalRowStructTypePair) {
public String getPartitionPath(Option<GenericRecord> record, Option<Row> row, Option<Pair<InternalRow, StructType>> internalRowStructTypePair) {
if (getPartitionPathFields() == null) {
throw new HoodieKeyException("Unable to find field names for partition path in cfg");
}
@@ -229,6 +229,7 @@ public Pair<Option<Dataset<Row>>, String> fetchNextBatch(Option<String> lastCkpt
queryInfo.getStartInstant()))
.filter(String.format("%s <= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD,
queryInfo.getEndInstant()));
source = queryInfo.getPredicateFilter().map(source::filter).orElse(source);
}

HoodieRecord.HoodieRecordType recordType = createRecordMerger(props).getRecordType();
@@ -53,6 +53,27 @@ public static class Config {
public static final String SNAPSHOT_LOAD_QUERY_SPLITTER_CLASS_NAME = "hoodie.deltastreamer.snapshotload.query.splitter.class.name";
}
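
As a usage note, a concrete splitter is expected to be wired in through the property above; a minimal, hedged sketch (the implementation class name is hypothetical):

```java
import org.apache.hudi.common.config.TypedProperties;

final class SplitterConfigSketch {
  // Hedged sketch: point the source at a custom SnapshotLoadQuerySplitter implementation
  // via the key defined in Config above. The class name below is hypothetical.
  static TypedProperties splitterProps() {
    TypedProperties props = new TypedProperties();
    props.setProperty("hoodie.deltastreamer.snapshotload.query.splitter.class.name",
        "org.example.MyPartitionPruningSplitter");
    return props;
  }
}
```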

/**
* Checkpoint returned by the SnapshotLoadQuerySplitter, carrying an end instant and a predicate filter.
*/
public static class CheckpointWithPredicates {
String endInstant;
String predicateFilter;

public CheckpointWithPredicates(String endInstant, String predicateFilter) {
this.endInstant = endInstant;
this.predicateFilter = predicateFilter;
}

public String getEndInstant() {
return endInstant;
}

public String getPredicateFilter() {
return predicateFilter;
}
}

/**
* Constructor initializing the properties.
*
@@ -62,6 +83,15 @@ public SnapshotLoadQuerySplitter(TypedProperties properties) {
this.properties = properties;
}

/**
* Abstract method to retrieve the next checkpoint with predicates.
*
* @param df The dataset to process.
* @param beginCheckpointStr The starting checkpoint string.
* @return The next checkpoint, with predicates (e.g., on partition path) used to optimize the snapshot query.
*/
public abstract Option<CheckpointWithPredicates> getNextCheckpointWithPredicates(Dataset<Row> df, String beginCheckpointStr);
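
For illustration, a minimal concrete splitter could implement this as sketched below; the single-commit batching and partition-path predicate are assumptions of the sketch, and TestSnapshotQuerySplitterImpl further down in this diff is the variant the tests actually use:

```java
import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import java.util.List;

import static org.apache.spark.sql.functions.col;

public class SinglePartitionSplitterSketch extends SnapshotLoadQuerySplitter {

  public SinglePartitionSplitterSketch(TypedProperties properties) {
    super(properties);
  }

  @Override
  public Option<String> getNextCheckpoint(Dataset<Row> df, String beginCheckpointStr) {
    // Reuse the predicate-aware variant and drop the predicate.
    return getNextCheckpointWithPredicates(df, beginCheckpointStr).map(CheckpointWithPredicates::getEndInstant);
  }

  @Override
  public Option<CheckpointWithPredicates> getNextCheckpointWithPredicates(Dataset<Row> df, String beginCheckpointStr) {
    // Take the first commit after the checkpoint and restrict the read to its partition path.
    // Column names are Hudi metadata fields; batching one commit at a time is an assumption of this sketch.
    List<Row> rows = df.filter(col("_hoodie_commit_time").gt(beginCheckpointStr))
        .orderBy(col("_hoodie_commit_time"))
        .limit(1)
        .collectAsList();
    if (rows.isEmpty()) {
      return Option.empty();
    }
    String endInstant = rows.get(0).getAs("_hoodie_commit_time");
    String partitionPath = rows.get(0).getAs("_hoodie_partition_path");
    return Option.of(new CheckpointWithPredicates(
        endInstant, String.format("_hoodie_partition_path = '%s'", partitionPath)));
  }
}
```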

/**
* Abstract method to retrieve the next checkpoint.
*
@@ -83,8 +113,8 @@ public SnapshotLoadQuerySplitter(TypedProperties properties) {
* returning endPoint same as queryInfo.getEndInstant().
*/
public QueryInfo getNextCheckpoint(Dataset<Row> df, QueryInfo queryInfo, Option<SourceProfileSupplier> sourceProfileSupplier) {
return getNextCheckpoint(df, queryInfo.getStartInstant(), sourceProfileSupplier)
.map(checkpoint -> queryInfo.withUpdatedEndInstant(checkpoint))
return getNextCheckpointWithPredicates(df, queryInfo.getStartInstant())
.map(queryInfo::withUpdatedCheckpoint)
.orElse(queryInfo);
}

@@ -18,6 +18,10 @@

package org.apache.hudi.utilities.sources.helpers;

import org.apache.hudi.common.util.Option;
import org.apache.hudi.common.util.StringUtils;
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter;

import java.util.Arrays;
import java.util.List;

@@ -27,12 +31,24 @@
/**
* This class is used to prepare query information for the S3 and GCS incremental sources.
* Some of the information in this class is used for batching based on sourceLimit.
* <p>
queryType: Incremental or Snapshot query on the Hudi table.
previousInstant: instant before startInstant.
startInstant: start instant for the range query.
endInstant: end instant for the range query.
predicateFilter: predicate filter on columns used to prune partitions and files.
orderColumn: column used for ordering results, e.g., _hoodie_record_key.
keyColumn: column used for the range query, e.g., _hoodie_commit_time > startInstant and _hoodie_commit_time <= endInstant.
limitColumn: limits the number of rows returned by the query.
* orderByColumns: (orderColumn, keyColumn)
* </p>
*/
public class QueryInfo {
private final String queryType;
private final String previousInstant;
private final String startInstant;
private final String endInstant;
private final String predicateFilter;
private final String orderColumn;
private final String keyColumn;
private final String limitColumn;
@@ -43,10 +59,32 @@ public QueryInfo(
String startInstant, String endInstant,
String orderColumn, String keyColumn,
String limitColumn) {
this(
queryType,
previousInstant,
startInstant,
endInstant,
StringUtils.EMPTY_STRING,
orderColumn,
keyColumn,
limitColumn
);
}

public QueryInfo(
String queryType,
String previousInstant,
String startInstant,
String endInstant,
String predicateFilter,
String orderColumn,
String keyColumn,
String limitColumn) {
this.queryType = queryType;
this.previousInstant = previousInstant;
this.startInstant = startInstant;
this.endInstant = endInstant;
this.predicateFilter = predicateFilter;
this.orderColumn = orderColumn;
this.keyColumn = keyColumn;
this.limitColumn = limitColumn;
@@ -97,6 +135,13 @@ public List<String> getOrderByColumns() {
return orderByColumns;
}

public Option<String> getPredicateFilter() {
if (!StringUtils.isNullOrEmpty(predicateFilter)) {
return Option.of(predicateFilter);
}
return Option.empty();
}

public QueryInfo withUpdatedEndInstant(String newEndInstant) {
return new QueryInfo(
this.queryType,
@@ -109,6 +154,19 @@ public QueryInfo withUpdatedEndInstant(String newEndInstant) {
);
}

public QueryInfo withUpdatedCheckpoint(SnapshotLoadQuerySplitter.CheckpointWithPredicates checkpointWithPredicates) {
return new QueryInfo(
this.queryType,
this.previousInstant,
this.startInstant,
checkpointWithPredicates.getEndInstant(),
checkpointWithPredicates.getPredicateFilter(),
this.orderColumn,
this.keyColumn,
this.limitColumn
);
}
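
For reference, a small hedged example of the new checkpoint flow on QueryInfo; the literal values are hypothetical:

```java
import org.apache.hudi.common.util.Option;
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter.CheckpointWithPredicates;
import org.apache.hudi.utilities.sources.helpers.QueryInfo;

final class QueryInfoCheckpointSketch {
  // A splitter-produced checkpoint replaces the end instant and attaches a predicate;
  // getPredicateFilter() is empty when no predicate (or a blank one) was set.
  static Option<String> narrowAndGetPredicate(QueryInfo queryInfo) {
    CheckpointWithPredicates checkpoint = new CheckpointWithPredicates(
        "300", "_hoodie_partition_path <= '2016/03/15'");
    QueryInfo pruned = queryInfo.withUpdatedCheckpoint(checkpoint);
    return pruned.getPredicateFilter(); // Option.of("_hoodie_partition_path <= '2016/03/15'")
  }
}
```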

@Override
public String toString() {
return ("Query information for Incremental Source "
@@ -106,11 +106,12 @@ public Pair<QueryInfo, Dataset<Row>> runSnapshotQuery(QueryInfo queryInfo, Optio
}

public Dataset<Row> applySnapshotQueryFilters(Dataset<Row> snapshot, QueryInfo snapshotQueryInfo) {
return snapshot
Dataset<Row> df = snapshot
// add filtering so that only interested records are returned.
.filter(String.format("%s >= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD,
snapshotQueryInfo.getStartInstant()))
.filter(String.format("%s <= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD,
snapshotQueryInfo.getEndInstant()));
return snapshotQueryInfo.getPredicateFilter().map(df::filter).orElse(df);
}
}
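
Put together, when a predicate is present the snapshot read effectively becomes the chain below; without one, only the two commit-time filters apply. The column names are Hudi metadata fields and the literal values are hypothetical:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

final class SnapshotFilterChainSketch {
  // Effective filter chain after applySnapshotQueryFilters when the splitter supplied a
  // partition-path predicate (values hypothetical).
  static Dataset<Row> prune(Dataset<Row> snapshot) {
    return snapshot
        .filter("_hoodie_commit_time >= '100'")
        .filter("_hoodie_commit_time <= '300'")
        .filter("_hoodie_partition_path >= '2015/03/16' and _hoodie_partition_path <= '2016/03/15'");
  }
}
```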
@@ -53,19 +53,23 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.stream.Stream;

import static org.apache.hudi.common.model.HoodieTableType.COPY_ON_WRITE;
import static org.apache.hudi.common.model.HoodieTableType.MERGE_ON_READ;
import static org.apache.hudi.common.model.WriteOperationType.BULK_INSERT;
import static org.apache.hudi.common.model.WriteOperationType.INSERT;
import static org.apache.hudi.common.model.WriteOperationType.UPSERT;
import static org.apache.hudi.common.testutils.HoodieTestUtils.DEFAULT_PARTITION_PATHS;
import static org.apache.hudi.common.testutils.HoodieTestUtils.RAW_TRIPS_TEST_NAME;
import static org.apache.hudi.testutils.Assertions.assertNoWriteErrors;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -367,8 +371,51 @@ public void testHoodieIncrSourceWithDataSourceOptions(HoodieTableType tableType)
}
}

@Test
public void testPartitionPruningInHoodieIncrSource()
throws IOException {
this.tableType = MERGE_ON_READ;
metaClient = getHoodieMetaClient(storageConf(), basePath());
HoodieWriteConfig writeConfig = getConfigBuilder(basePath(), metaClient)
.withArchivalConfig(HoodieArchivalConfig.newBuilder().archiveCommitsWith(10, 12).build())
.withCleanConfig(HoodieCleanConfig.newBuilder().retainCommits(9).build())
.withCompactionConfig(
HoodieCompactionConfig.newBuilder()
.withScheduleInlineCompaction(true)
.withMaxNumDeltaCommitsBeforeCompaction(1)
.build())
.withMetadataConfig(HoodieMetadataConfig.newBuilder().enable(true).build())
.build();
List<Pair<String, List<HoodieRecord>>> inserts = new ArrayList<>();
try (SparkRDDWriteClient writeClient = getHoodieWriteClient(writeConfig)) {
inserts.add(writeRecordsForPartition(writeClient, BULK_INSERT, "100", DEFAULT_PARTITION_PATHS[0]));
inserts.add(writeRecordsForPartition(writeClient, BULK_INSERT, "200", DEFAULT_PARTITION_PATHS[1]));
inserts.add(writeRecordsForPartition(writeClient, BULK_INSERT, "300", DEFAULT_PARTITION_PATHS[2]));
// Go over all possible test cases to assert behaviour.
getArgsForPartitionPruningInHoodieIncrSource().forEach(argumentsStream -> {
Object[] arguments = argumentsStream.get();
String checkpointToPullFromHoodieInstant = (String) arguments[0];
int maxRowsPerSnapshotBatch = (int) arguments[1];
String expectedCheckpointHoodieInstant = (String) arguments[2];
int expectedCount = (int) arguments[3];
int expectedRDDPartitions = (int) arguments[4];

TypedProperties extraProps = new TypedProperties();
extraProps.setProperty(TestSnapshotQuerySplitterImpl.MAX_ROWS_PER_BATCH, String.valueOf(maxRowsPerSnapshotBatch));
readAndAssert(IncrSourceHelper.MissingCheckpointStrategy.READ_UPTO_LATEST_COMMIT,
Option.ofNullable(checkpointToPullFromHoodieInstant),
expectedCount,
expectedCheckpointHoodieInstant,
Option.of(TestSnapshotQuerySplitterImpl.class.getName()),
extraProps,
Option.ofNullable(expectedRDDPartitions)
);
});
}
}

private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingCheckpointStrategy, Option<String> checkpointToPull, int expectedCount,
String expectedCheckpoint, Option<String> snapshotCheckPointImplClassOpt, TypedProperties extraProps) {
String expectedCheckpoint, Option<String> snapshotCheckPointImplClassOpt, TypedProperties extraProps, Option<Integer> expectedRDDPartitions) {

Properties properties = new Properties();
properties.setProperty("hoodie.streamer.source.hoodieincr.path", basePath());
@@ -388,10 +435,16 @@ private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingChe
assertFalse(batchCheckPoint.getKey().isPresent());
} else {
assertEquals(expectedCount, batchCheckPoint.getKey().get().count());
expectedRDDPartitions.ifPresent(rddPartitions -> assertEquals(rddPartitions, batchCheckPoint.getKey().get().rdd().getNumPartitions()));
}
assertEquals(expectedCheckpoint, batchCheckPoint.getRight());
}

private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingCheckpointStrategy, Option<String> checkpointToPull, int expectedCount,
String expectedCheckpoint, Option<String> snapshotCheckPointImplClassOpt, TypedProperties extraProps) {
readAndAssert(missingCheckpointStrategy, checkpointToPull, expectedCount, expectedCheckpoint, snapshotCheckPointImplClassOpt, extraProps, Option.empty());
}

private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingCheckpointStrategy, Option<String> checkpointToPull,
int expectedCount, String expectedCheckpoint) {
readAndAssert(missingCheckpointStrategy, checkpointToPull, expectedCount, expectedCheckpoint, Option.empty(), new TypedProperties());
@@ -413,13 +466,39 @@ private Pair<String, List<HoodieRecord>> writeRecords(SparkRDDWriteClient writeC
return Pair.of(commit, records);
}

private Pair<String, List<HoodieRecord>> writeRecordsForPartition(SparkRDDWriteClient writeClient,
WriteOperationType writeOperationType,
String commit,
String partitionPath) {
writeClient.startCommitWithTime(commit);
List<HoodieRecord> records = dataGen.generateInsertsForPartition(commit, 100, partitionPath);
JavaRDD<WriteStatus> result = writeOperationType == WriteOperationType.BULK_INSERT
? writeClient.bulkInsert(jsc().parallelize(records, 1), commit)
: writeClient.upsert(jsc().parallelize(records, 1), commit);
List<WriteStatus> statuses = result.collect();
assertNoWriteErrors(statuses);
return Pair.of(commit, records);
}

private HoodieWriteConfig.Builder getConfigBuilder(String basePath, HoodieTableMetaClient metaClient) {
return HoodieWriteConfig.newBuilder().withPath(basePath).withSchema(HoodieTestDataGenerator.TRIP_EXAMPLE_SCHEMA)
.withParallelism(2, 2).withBulkInsertParallelism(2).withFinalizeWriteParallelism(2).withDeleteParallelism(2)
.withTimelineLayoutVersion(TimelineLayoutVersion.CURR_VERSION)
.forTable(metaClient.getTableConfig().getTableName());
}

private static Stream<Arguments> getArgsForPartitionPruningInHoodieIncrSource() {
// Arguments are in order -> checkpointToPullFromHoodieInstant, maxRowsPerSnapshotBatch, expectedCheckpointHoodieInstant, expectedCount, expectedRDDPartitions.
return Stream.of(
Arguments.of(null, 1, "100", 100, 1),
Arguments.of(null, 101, "200", 200, 3),
Arguments.of(null, 10001, "300", 300, 3),
Arguments.of("100", 101, "300", 200, 2),
Arguments.of("200", 101, "300", 100, 1),
Arguments.of("300", 101, "300", 0, 0)
);
}

private static class DummySchemaProvider extends SchemaProvider {

private final Schema schema;
@@ -19,7 +19,6 @@
package org.apache.hudi.utilities.sources.helpers;

import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.common.model.HoodieRecord;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter;
import org.apache.hudi.utilities.streamer.SourceProfileSupplier;
@@ -29,12 +28,16 @@

import java.util.List;

import static org.apache.hudi.common.model.HoodieRecord.COMMIT_TIME_METADATA_FIELD;
import static org.apache.hudi.common.model.HoodieRecord.PARTITION_PATH_METADATA_FIELD;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.min;

public class TestSnapshotQuerySplitterImpl extends SnapshotLoadQuerySplitter {

private static final String COMMIT_TIME_METADATA_FIELD = HoodieRecord.COMMIT_TIME_METADATA_FIELD;
public static final String MAX_ROWS_PER_BATCH = "test.snapshot.load.max.row.count";

/**
* Constructor initializing the properties.
@@ -51,4 +54,23 @@ public Option<String> getNextCheckpoint(Dataset<Row> df, String beginCheckpointS
.orderBy(col(COMMIT_TIME_METADATA_FIELD)).limit(1).collectAsList();
return Option.ofNullable(row.size() > 0 ? row.get(0).getAs(COMMIT_TIME_METADATA_FIELD) : null);
}

@Override
public Option<CheckpointWithPredicates> getNextCheckpointWithPredicates(Dataset<Row> df, String beginCheckpointStr) {
int maxRowsPerBatch = properties.getInteger(MAX_ROWS_PER_BATCH, 1);
List<Row> row = df.select(col(COMMIT_TIME_METADATA_FIELD)).filter(col(COMMIT_TIME_METADATA_FIELD).gt(lit(beginCheckpointStr)))
.orderBy(col(COMMIT_TIME_METADATA_FIELD)).limit(maxRowsPerBatch).collectAsList();
if (!row.isEmpty()) {
String endInstant = row.get(row.size() - 1).getAs(COMMIT_TIME_METADATA_FIELD);
List<Row> minMax =
df.filter(col(COMMIT_TIME_METADATA_FIELD).gt(lit(beginCheckpointStr)))
.filter(col(COMMIT_TIME_METADATA_FIELD).leq(endInstant))
.select(PARTITION_PATH_METADATA_FIELD).agg(min(PARTITION_PATH_METADATA_FIELD).alias("min_partition_path"), max(PARTITION_PATH_METADATA_FIELD).alias("max_partition_path"))
.collectAsList();
String partitionFilter = String.format("partition_path >= '%s' and partition_path <= '%s'", minMax.get(0).getAs("min_partition_path"), minMax.get(0).getAs("max_partition_path"));
return Option.of(new CheckpointWithPredicates(endInstant, partitionFilter));
} else {
return Option.empty();
}
}
}