From f7aaa4187aec3adcac750707bd11ca1f73087865 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Mon, 15 Jul 2024 16:11:02 -0700 Subject: [PATCH] Enhance query rewriter rule to support partial covering index (#409) * Refactor rewrite rule to support partial indexing Signed-off-by: Chen Dai * Add more UT Signed-off-by: Chen Dai * Fix bound comparison bug and refactor UT Signed-off-by: Chen Dai * Add IT Signed-off-by: Chen Dai * Fix select all bug Signed-off-by: Chen Dai * Refactor bound Signed-off-by: Chen Dai * Add more UT Signed-off-by: Chen Dai --------- Signed-off-by: Chen Dai --- .../spark/FlintSparkQueryRewriteHelper.scala | 177 ++++++++++++++++++ .../ApplyFlintSparkCoveringIndex.scala | 112 +++++++---- .../ApplyFlintSparkCoveringIndexSuite.scala | 71 ++++++- .../FlintSparkCoveringIndexSqlITSuite.scala | 26 +++ 4 files changed, 342 insertions(+), 44 deletions(-) create mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala new file mode 100644 index 000000000..fede39a8c --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Query rewrite helper that provides common utilities for query rewrite rules of various Flint + * indexes. + */ +trait FlintSparkQueryRewriteHelper { + + /** + * Determines if the given filter expression consists solely of AND operations and no OR + * operations, implying that it's a conjunction of conditions. + * + * @param filter + * The filter expression to check. + * @return + * True if the filter contains only AND operations, False if any OR operations are found. + */ + def isConjunction(filter: Expression): Boolean = { + filter.collectFirst { case Or(_, _) => + true + }.isEmpty + } + + /** + * Determines if the conditions in an index filter can subsume those in a query filter. This is + * essential to verify if all outputs that satisfy the query filter also satisfy the index + * filter, indicating that the index can potentially optimize the query. + * + * @param indexFilter + * The filter expression defined from the index, required to be a conjunction. + * @param queryFilter + * The filter expression present in the user query, required to be a conjunction. + * @return + * True if the index filter can subsume the query filter, otherwise False.
+ */ + def subsume(indexFilter: Expression, queryFilter: Expression): Boolean = { + if (!isConjunction(indexFilter) || !isConjunction(queryFilter)) { + return false + } + + // Flatten a potentially nested conjunction into a sequence of individual conditions + def flattenConditions(filter: Expression): Seq[Expression] = filter match { + case And(left, right) => flattenConditions(left) ++ flattenConditions(right) + case other => Seq(other) + } + val indexConditions = flattenConditions(indexFilter) + val queryConditions = flattenConditions(queryFilter) + + // Ensures that every condition in the index filter is subsumed by at least one condition + // on the same column in the query filter + indexConditions.forall { indexCond => + queryConditions.exists { queryCond => + (indexCond, queryCond) match { + case ( + indexComp @ BinaryComparison(indexCol: Attribute, _), + queryComp @ BinaryComparison(queryCol: Attribute, _)) + if indexCol.name == queryCol.name => + Range(indexComp).subsume(Range(queryComp)) + case _ => false // consider as not subsumed for unsupported expression + } + } + } + } + + /** + * Represents a range with optional lower and upper bounds. + * + * @param lower + * The optional lower bound + * @param upper + * The optional upper bound + */ + private case class Range(lower: Option[Bound], upper: Option[Bound]) { + + /** + * Determines if this range subsumes (completely covers) another range. + * + * @param other + * The other range to compare against. + * @return + * True if this range subsumes the other, otherwise false. + */ + def subsume(other: Range): Boolean = { + // Unknown range cannot subsume or be subsumed by any + if (this == Range.UNKNOWN || other == Range.UNKNOWN) { + return false + } + + // Subsumption check helper for lower and upper bound + def subsumeHelper( + thisBound: Option[Bound], + otherBound: Option[Bound], + comp: (Bound, Bound) => Boolean): Boolean = + (thisBound, otherBound) match { + case (Some(a), Some(b)) => comp(a, b) + case (None, _) => true // this is unbounded and thus can subsume any other bound + case (_, None) => false // other is unbounded and thus cannot be subsumed by any + } + subsumeHelper(lower, other.lower, _.lessThanOrEqualTo(_)) && + subsumeHelper(upper, other.upper, _.greaterThanOrEqualTo(_)) + } + } + + private object Range { + + /** Unknown range for unsupported binary comparison expression */ + private val UNKNOWN: Range = Range(None, None) + + /** + * Constructs a Range object from a binary comparison expression, translating comparison + * operators into bounds with appropriate inclusiveness. + * + * @param condition + * The binary comparison + */ + def apply(condition: BinaryComparison): Range = condition match { + case GreaterThan(_, Literal(value: Comparable[Any], _)) => + Range(Some(Bound(value, inclusive = false)), None) + case GreaterThanOrEqual(_, Literal(value: Comparable[Any], _)) => + Range(Some(Bound(value, inclusive = true)), None) + case LessThan(_, Literal(value: Comparable[Any], _)) => + Range(None, Some(Bound(value, inclusive = false))) + case LessThanOrEqual(_, Literal(value: Comparable[Any], _)) => + Range(None, Some(Bound(value, inclusive = true))) + case EqualTo(_, Literal(value: Comparable[Any], _)) => + Range(Some(Bound(value, inclusive = true)), Some(Bound(value, inclusive = true))) + case _ => UNKNOWN + } + } + + /** + * Represents a bound (lower or upper) in a range, defined by a literal value and its + * inclusiveness. + * + * @param value + * The literal value defining the bound. 
+ * @param inclusive + * Indicates whether the bound is inclusive. + */ + private case class Bound(value: Comparable[Any], inclusive: Boolean) { + + /** + * Checks if this bound is less than or equal to another bound, considering inclusiveness. + * + * @param other + * The bound to compare against. + * @return + * True if this bound is less than or equal to the other bound, either because value is + * smaller, this bound is inclusive, or both bound are exclusive. + */ + def lessThanOrEqualTo(other: Bound): Boolean = { + val cmp = value.compareTo(other.value) + cmp < 0 || (cmp == 0 && (inclusive || !other.inclusive)) + } + + /** + * Checks if this bound is greater than or equal to another bound, considering inclusiveness. + * + * @param other + * The bound to compare against. + * @return + * True if this bound is greater than or equal to the other bound, either because value is + * greater, this bound is inclusive, or both bound are exclusive. + */ + def greaterThanOrEqualTo(other: Bound): Boolean = { + val cmp = value.compareTo(other.value) + cmp > 0 || (cmp == 0 && (inclusive || !other.inclusive)) + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala index 9217495e6..90cbdd0da 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala @@ -8,13 +8,14 @@ package org.opensearch.flint.spark.covering import java.util import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.DELETED -import org.opensearch.flint.spark.FlintSpark +import org.opensearch.flint.spark.{FlintSpark, FlintSparkQueryRewriteHelper} import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parseExpression +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, V2WriteCommand} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.flint.{qualifyTableName, FlintDataSourceV2} @@ -27,7 +28,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * @param flint * Flint Spark API */ -class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] { +class ApplyFlintSparkCoveringIndex(flint: FlintSpark) + extends Rule[LogicalPlan] + with FlintSparkQueryRewriteHelper { /** All supported source relation providers */ private val supportedProviders = FlintSparkSourceRelationProvider.getAllProviders(flint.spark) @@ -37,22 +40,38 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] plan } else { // Iterate each sub plan tree in the given plan - plan transform { case subPlan => - supportedProviders - .collectFirst { - case provider if provider.isSupported(subPlan) => - logInfo(s"Provider 
[${provider.name()}] can match sub plan ${subPlan.nodeName}") - val relation = provider.getRelation(subPlan) - val relationCols = collectRelationColumnsInQueryPlan(plan, relation) - - // Choose the first covering index that meets all criteria above - findAllCoveringIndexesOnTable(relation.tableName) - .sortBy(_.name()) - .find(index => isCoveringIndexApplicable(index, relationCols)) - .map(index => replaceTableRelationWithIndexRelation(index, relation)) - .getOrElse(subPlan) // If no index found, return the original node - } - .getOrElse(subPlan) // If not supported by any provider, return the original node + plan transform { + case filter @ Filter(condition, Relation(sourceRelation)) => + doApply(plan, sourceRelation, Some(condition)) + .map(newRelation => filter.copy(child = newRelation)) + .getOrElse(filter) + case relation @ Relation(sourceRelation) => + doApply(plan, sourceRelation, None) + .getOrElse(relation) + } + } + } + + private def doApply( + plan: LogicalPlan, + relation: FlintSparkSourceRelation, + queryFilter: Option[Expression]): Option[LogicalPlan] = { + val relationCols = collectRelationColumnsInQueryPlan(plan, relation) + + // Choose the first covering index that meets all criteria above + findAllCoveringIndexesOnTable(relation.tableName) + .sortBy(_.name()) + .find(index => isCoveringIndexApplicable(index, queryFilter, relationCols)) + .map(index => replaceTableRelationWithIndexRelation(index, relation)) + } + + private object Relation { + def unapply(subPlan: LogicalPlan): Option[FlintSparkSourceRelation] = { + // Check if any source relation can support the plan node + supportedProviders.collectFirst { + case provider if provider.isSupported(subPlan) => + logInfo(s"Provider [${provider.name()}] can match plan ${subPlan.nodeName}") + provider.getRelation(subPlan) } } } @@ -65,20 +84,28 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] * Because this rule executes before push down optimization, relation includes all columns. */ val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap - plan - .collect { - // Relation interface matches both file and Iceberg relation - case r: MultiInstanceRelation if r.eq(relation.plan) => Set.empty - case other => - other.expressions - .flatMap(_.references) - .flatMap(ref => { - relationColsById.get(ref.exprId) - }) // Ignore attribute not belong to current relation being rewritten - .map(attr => attr.name) - } - .flatten - .toSet + val relationCols = + plan + .collect { + // Relation interface matches both file and Iceberg relation + case r: MultiInstanceRelation if r.eq(relation.plan) => Set.empty + case other => + other.expressions + .flatMap(_.references) + .flatMap(ref => { + relationColsById.get(ref.exprId) + }) // Ignore attribute not belong to current relation being rewritten + .map(attr => attr.name) + } + .flatten + .toSet + + if (relationCols.isEmpty) { + // Return all if plan only has relation operator, e.g. 
SELECT * or all columns + relationColsById.values.map(_.name).toSet + } else { + relationCols + } } private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkCoveringIndex] = { @@ -87,8 +114,8 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] val indexes = flint .describeIndexes(indexPattern) - .collect { // cast to covering index - case index: FlintSparkCoveringIndex => index + .collect { // cast to covering index and double check table name + case index: FlintSparkCoveringIndex if index.tableName == qualifiedTableName => index } val indexNames = indexes.map(_.name()).mkString(",") @@ -98,17 +125,26 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] private def isCoveringIndexApplicable( index: FlintSparkCoveringIndex, + queryFilter: Option[Expression], relationCols: Set[String]): Boolean = { val indexedCols = index.indexedColumns.keySet + val subsumption = (index.filterCondition, queryFilter) match { + case (None, _) => true // full index can cover any query filter + case (Some(_), None) => false // partial index cannot cover query without filter + case (Some(indexFilter), Some(_)) => + subsume(parseExpression(indexFilter), queryFilter.get) + } val isApplicable = index.latestLogEntry.exists(_.state != DELETED) && - index.filterCondition.isEmpty && // TODO: support partial covering index later + subsumption && relationCols.subsetOf(indexedCols) logInfo(s""" | Is covering index ${index.name()} applicable: $isApplicable | Index state: ${index.latestLogEntry.map(_.state)} - | Index filter condition: ${index.filterCondition} + | Query filter: $queryFilter + | Index filter: ${index.filterCondition} + | Subsumption test: $subsumption | Columns required: $relationCols | Columns indexed: $indexedCols |""".stripMargin) diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala index 917d5aee7..8ede40f86 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala @@ -43,6 +43,12 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { super.beforeAll() sql(s"CREATE TABLE $testTable (name STRING, age INT) USING JSON") sql(s"CREATE TABLE $testTable2 (name STRING) USING JSON") + sql(s""" + | INSERT INTO $testTable + | VALUES + | ('A', 10), ('B', 15), ('C', 20), ('D', 25), ('E', 30), + | ('F', 35), ('G', 40), ('H', 45), ('I', 50), ('J', 55) + | """.stripMargin) // Mock static create method in FlintClientBuilder used by Flint data source clientBuilder @@ -63,15 +69,66 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { .assertIndexNotUsed(testTable) } - test("should not apply if covering index is partial") { + // Comprehensive test by cartesian product of the following condition + private val conditions = Seq( + null, + "age = 20", + "age > 20", + "age >= 20", + "age < 20", + "age <= 20", + "age = 50", + "age > 50", + "age >= 50", + "age < 50", + "age <= 50", + "age > 20 AND age < 50", + "age >= 20 AND age < 50", + "age > 20 AND age <= 50", + "age >=20 AND age <= 50") + (for { + indexFilter <- conditions + queryFilter <- conditions + } yield (indexFilter, queryFilter)).distinct + .foreach { case 
(indexFilter, queryFilter) => + test(s"apply partial covering index with [$indexFilter] to query filter [$queryFilter]") { + def queryWithFilter(condition: String): String = + Option(condition) match { + case None => s"SELECT name FROM $testTable" + case Some(cond) => s"SELECT name FROM $testTable WHERE $cond" + } + + // Expect index applied if query result is subset of index data (index filter result) + val queryData = sql(queryWithFilter(queryFilter)).collect().toSet + val indexData = sql(queryWithFilter(indexFilter)).collect().toSet + val expectedResult = queryData.subsetOf(indexData) + + val assertion = assertFlintQueryRewriter + .withQuery(queryWithFilter(queryFilter)) + .withIndex( + new FlintSparkCoveringIndex( + indexName = "partial", + tableName = testTable, + indexedColumns = Map("name" -> "string", "age" -> "int"), + filterCondition = Option(indexFilter))) + + if (expectedResult) { + assertion.assertIndexUsed(getFlintIndexName("partial", testTable)) + } else { + assertion.assertIndexNotUsed(testTable) + } + } + } + + test("should not apply if covering index with disjunction filtering condition") { assertFlintQueryRewriter - .withQuery(s"SELECT name FROM $testTable") + .withQuery(s"SELECT name FROM $testTable WHERE name = 'A' AND age > 30") .withIndex( new FlintSparkCoveringIndex( - indexName = "name", + indexName = "partial", tableName = testTable, - indexedColumns = Map("name" -> "string"), - filterCondition = Some("age > 30"))) + indexedColumns = Map("name" -> "string", "age" -> "int"), + filterCondition = Some("name = 'A' OR age > 30"))) .assertIndexNotUsed(testTable) } @@ -111,6 +168,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { s"SELECT name, age FROM $testTable", s"SELECT age, name FROM $testTable", s"SELECT name FROM $testTable WHERE age = 30", + s"SELECT name FROM $testTable WHERE name = 'A' AND age = 30", + s"SELECT name FROM $testTable WHERE name = 'A' OR age = 30", s"SELECT SUBSTR(name, 1) FROM $testTable WHERE ABS(age) = 30", s"SELECT COUNT(*) FROM $testTable GROUP BY age", s"SELECT name, COUNT(*) FROM $testTable WHERE age > 30 GROUP BY name", @@ -167,7 +226,7 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { private var indexes: Seq[FlintSparkCoveringIndex] = Seq() def withQuery(query: String): AssertionHelper = { - this.plan = sql(query).queryExecution.analyzed + this.plan = sql(query).queryExecution.optimizedPlan this } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index db14e395b..ffd956b1c 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -287,6 +287,19 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { checkAnswer(sql(query), Seq(Row("Hello", 30), Row("World", 25))) } + test("rewrite applicable simple query with partial covering index") { + awaitRefreshComplete(s""" + | CREATE INDEX $testIndex ON $testTable + | (name, age) + | WHERE age > 25 + | WITH (auto_refresh = true) + | """.stripMargin) + + val query = s"SELECT name, age FROM $testTable WHERE age >= 30" + checkKeywordsExist(sql(s"EXPLAIN $query"), "FlintScan") + checkAnswer(sql(query), Seq(Row("Hello", 30))) + } + test("rewrite applicable aggregate query with covering index") { awaitRefreshComplete(s""" | CREATE INDEX 
$testIndex ON $testTable @@ -320,6 +333,19 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { } } + test("should not rewrite with partial covering index if not applicable") { + awaitRefreshComplete(s""" + | CREATE INDEX $testIndex ON $testTable + | (name, age) + | WHERE age > 25 + | WITH (auto_refresh = true) + | """.stripMargin) + + val query = s"SELECT name, age FROM $testTable WHERE age > 20" + checkKeywordsNotExist(sql(s"EXPLAIN $query"), "FlintScan") + checkAnswer(sql(query), Seq(Row("Hello", 30), Row("World", 25))) + } + test("rewrite applicable query with covering index before skipping index") { try { sql(s"""
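
Editor's note (not part of the patch): the standalone sketch below is a simplified model of the Bound/Range subsumption check introduced in FlintSparkQueryRewriteHelper, restricted to integer bounds so it runs without Spark. The names SubsumptionSketch, ValueRange, and Bound are illustrative only; the real helper works on Catalyst expressions and Comparable literals. It shows the core idea of the rule: a partial covering index is applicable only when every row selected by the query filter is also selected by the index filter, decided by comparing per-column ranges and the inclusiveness of their bounds.

object SubsumptionSketch {

  // A lower or upper bound of a range, e.g. "age > 25" contributes the lower bound (25, exclusive).
  final case class Bound(value: Int, inclusive: Boolean) {

    def lessThanOrEqualTo(other: Bound): Boolean =
      value < other.value || (value == other.value && (inclusive || !other.inclusive))

    def greaterThanOrEqualTo(other: Bound): Boolean =
      value > other.value || (value == other.value && (inclusive || !other.inclusive))
  }

  // A per-column range; None means unbounded on that side.
  final case class ValueRange(lower: Option[Bound], upper: Option[Bound]) {

    // This range subsumes (covers) the other if every value in the other range also falls in this one.
    def subsume(other: ValueRange): Boolean = {
      def covers(
          thisBound: Option[Bound],
          otherBound: Option[Bound],
          cmp: (Bound, Bound) => Boolean): Boolean =
        (thisBound, otherBound) match {
          case (Some(a), Some(b)) => cmp(a, b)
          case (None, _) => true // this side is unbounded and covers anything
          case (_, None) => false // other side is unbounded and cannot be covered
        }

      covers(lower, other.lower, _.lessThanOrEqualTo(_)) &&
        covers(upper, other.upper, _.greaterThanOrEqualTo(_))
    }
  }

  def main(args: Array[String]): Unit = {
    // Index filter "age > 25": lower bound 25 exclusive, no upper bound.
    val indexRange = ValueRange(Some(Bound(25, inclusive = false)), None)

    // Query filter "age >= 30" lies entirely inside the index range, so the partial index applies.
    println(indexRange.subsume(ValueRange(Some(Bound(30, inclusive = true)), None))) // true

    // Query filter "age > 20" needs rows with 20 < age <= 25 that the index does not contain.
    println(indexRange.subsume(ValueRange(Some(Bound(20, inclusive = false)), None))) // false
  }
}

As in the patch, any comparison the helper cannot model (for example a non-literal operand or a disjunction) maps to an unknown range that neither subsumes nor is subsumed, so the rewrite rule conservatively keeps the original table scan.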