diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala
index 8fe8bcc57..5af6a373a 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala
@@ -5,7 +5,7 @@
 
 package org.opensearch.flint.spark
 
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BinaryComparison, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BinaryComparison, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Or}
 
 /**
  * Query rewrite helper that provides common utilities for query rewrite rule of various Flint
@@ -13,79 +13,156 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BinaryComparis
  */
 trait FlintSparkQueryRewriteHelper {
 
+  /**
+   * Determines if the given filter expression consists solely of AND operations and no OR
+   * operations, implying that it's a conjunction of conditions.
+   *
+   * @param filter
+   *   The filter expression to check.
+   * @return
+   *   True if the filter contains only AND operations, False if any OR operations are found.
+   */
+  def isConjunction(filter: Expression): Boolean = {
+    filter.collectFirst { case Or(_, _) =>
+      true
+    }.isEmpty
+  }
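
Illustration only, not part of the patch: a minimal sketch of how isConjunction classifies two sample Catalyst predicates. The object name IsConjunctionDemo is hypothetical; spark-catalyst and this trait are assumed to be on the classpath.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType
import org.opensearch.flint.spark.FlintSparkQueryRewriteHelper

object IsConjunctionDemo extends App with FlintSparkQueryRewriteHelper {
  val age = AttributeReference("age", IntegerType)()

  // age > 20 AND age < 50: a pure conjunction, eligible for the rewrite
  println(isConjunction(And(GreaterThan(age, Literal(20)), LessThan(age, Literal(50))))) // true

  // age < 20 OR age > 50: contains an OR, so the rewrite must skip it
  println(isConjunction(Or(LessThan(age, Literal(20)), GreaterThan(age, Literal(50))))) // false
}
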
+
   /**
    * Determines if the conditions in an index filter can subsume those in a query filter. This is
    * essential to verify that all outputs satisfying the query filter also satisfy the index
    * filter, indicating that the index data can cover and thus optimize the query.
    *
    * @param indexFilter
-   *   The filter expression defined from the index
+   *   The filter expression defined from the index, required to be a conjunction.
    * @param queryFilter
-   *   The filter expression present in the user query
+   *   The filter expression present in the user query, required to be a conjunction.
    * @return
-   *   True if the index filter can subsume the query filter, otherwise False
+   *   True if the index filter can subsume the query filter, otherwise False.
    */
   def subsume(indexFilter: Expression, queryFilter: Expression): Boolean = {
+    require(isConjunction(indexFilter), "Index filter is not a conjunction")
+    require(isConjunction(queryFilter), "Query filter is not a conjunction")
+
+    // Flatten a potentially nested conjunction into a sequence of individual conditions
     def flattenConditions(filter: Expression): Seq[Expression] = filter match {
       case And(left, right) => flattenConditions(left) ++ flattenConditions(right)
       case other => Seq(other)
     }
-
     val indexConditions = flattenConditions(indexFilter)
     val queryConditions = flattenConditions(queryFilter)
 
-    // Each index condition must subsume in a query condition
-    // otherwise it means index data cannot "cover" query condition
+    // Ensure that every condition in the index filter subsumes at least one condition in the
+    // query filter, so that the query filter implies the index filter
     indexConditions.forall { indexCondition =>
       queryConditions.exists { queryCondition =>
         (indexCondition, queryCondition) match {
           case (
-                i @ BinaryComparison(indexCol: Attribute, _),
-                q @ BinaryComparison(queryCol: Attribute, _)) if indexCol.name == queryCol.name =>
-            Range(i).subsume(Range(q))
+                indexComparison @ BinaryComparison(indexCol: Attribute, _),
+                queryComparison @ BinaryComparison(queryCol: Attribute, _))
+              if indexCol.name == queryCol.name =>
+            Range(indexComparison).subsume(Range(queryComparison))
           case _ => false
         }
       }
     }
   }
 
-  case class Bound(value: Literal, inclusive: Boolean) {
-
-    def lessThanOrEqualTo(other: Bound): Boolean = {
-      val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
-      cmp < 0 || (cmp == 0 && inclusive && other.inclusive)
-    }
-  }
-
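
Illustration only, not part of the patch: a sketch of subsume on Catalyst predicates (the object name SubsumeDemo is hypothetical; spark-catalyst and this trait are assumed to be on the classpath). An index filter subsumes a strictly narrower query filter but not a wider one:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType
import org.opensearch.flint.spark.FlintSparkQueryRewriteHelper

object SubsumeDemo extends App with FlintSparkQueryRewriteHelper {
  val age = AttributeReference("age", IntegerType)()

  val indexFilter = And(GreaterThan(age, Literal(20)), LessThan(age, Literal(50))) // age > 20 AND age < 50
  val narrowQuery = And(GreaterThanOrEqual(age, Literal(30)), LessThan(age, Literal(40))) // age >= 30 AND age < 40
  val wideQuery = GreaterThan(age, Literal(10)) // age > 10

  println(subsume(indexFilter, narrowQuery)) // true: every row with 30 <= age < 40 is in the index
  println(subsume(indexFilter, wideQuery)) // false: the index misses rows with 10 < age <= 20
}
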
+  /**
+   * Represents a range with optional lower and upper bounds.
+   *
+   * @param lower
+   *   The optional lower bound.
+   * @param upper
+   *   The optional upper bound.
+   */
   case class Range(lower: Option[Bound], upper: Option[Bound]) {
 
+    /**
+     * Determines if this range subsumes (completely covers) another range, which holds when its
+     * lower bound is less than or equal to the other's lower bound and its upper bound is
+     * greater than or equal to the other's upper bound.
+     *
+     * @param other
+     *   The other range to compare against.
+     * @return
+     *   True if this range subsumes the other, otherwise false.
+     */
     def subsume(other: Range): Boolean = {
-      val isLowerSubsumed = (lower, other.lower) match {
-        case (Some(a), Some(b)) => a.lessThanOrEqualTo(b)
-        case (None, _) => true // `bound1` is unbounded and thus can subsume anything
-        case (_, None) => false // `bound2` is unbounded and thus cannot be subsumed
-        case (None, None) => true
-      }
-      val isUpperSubsumed = (upper, other.upper) match {
-        case (Some(a), Some(b)) => b.lessThanOrEqualTo(a)
-        case (None, _) => true // `bound1` is unbounded and thus can subsume anything
-        case (_, None) => false // `bound2` is unbounded and thus cannot be subsumed
-        case (None, None) => true
-      }
-      isLowerSubsumed && isUpperSubsumed
+      // Subsumption check helper shared by the lower and upper bounds
+      def subsume(
+          thisBound: Option[Bound],
+          otherBound: Option[Bound],
+          comp: (Bound, Bound) => Boolean): Boolean =
+        (thisBound, otherBound) match {
+          case (Some(a), Some(b)) => comp(a, b)
+          case (None, _) => true // this is unbounded and thus can subsume any other bound
+          case (_, None) => false // other is unbounded and thus cannot be subsumed by any
+        }
+      subsume(lower, other.lower, _.lessThanOrEqualTo(_)) &&
+      subsume(upper, other.upper, _.greaterThanOrEqualTo(_))
     }
   }
 
   object Range {
+
+    /**
+     * Constructs a Range object from a binary comparison expression, translating comparison
+     * operators into bounds with the appropriate inclusivity.
+     *
+     * @param condition
+     *   The binary comparison to convert.
+     */
     def apply(condition: BinaryComparison): Range = condition match {
-      case GreaterThan(_, value: Literal) => Range(Some(Bound(value, inclusive = false)), None)
+      case GreaterThan(_, value: Literal) =>
+        Range(Some(Bound(value, inclusive = false)), None)
       case GreaterThanOrEqual(_, value: Literal) =>
         Range(Some(Bound(value, inclusive = true)), None)
-      case LessThan(_, value: Literal) => Range(None, Some(Bound(value, inclusive = false)))
-      case LessThanOrEqual(_, value: Literal) => Range(None, Some(Bound(value, inclusive = true)))
+      case LessThan(_, value: Literal) =>
+        Range(None, Some(Bound(value, inclusive = false)))
+      case LessThanOrEqual(_, value: Literal) =>
+        Range(None, Some(Bound(value, inclusive = true)))
       case EqualTo(_, value: Literal) =>
         Range(Some(Bound(value, inclusive = true)), Some(Bound(value, inclusive = true)))
       case _ => Range(None, None) // For unsupported or complex conditions
     }
   }
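
The Bound comparisons added below also fix a boundary-equality case: the old lessThanOrEqualTo required both bounds to be inclusive when their values were equal, so two identical exclusive bounds (e.g. the lower bounds of two age > 30 predicates) did not subsume each other. Illustration only, not part of the patch (BoundDemo is a hypothetical name; spark-catalyst and this trait are assumed on the classpath):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.opensearch.flint.spark.FlintSparkQueryRewriteHelper

object BoundDemo extends App with FlintSparkQueryRewriteHelper {
  val inclusive30 = Bound(Literal(30), inclusive = true) // lower bound of age >= 30
  val exclusive30 = Bound(Literal(30), inclusive = false) // lower bound of age > 30

  println(inclusive30.lessThanOrEqualTo(exclusive30)) // true: age >= 30 admits every row that age > 30 does
  println(exclusive30.lessThanOrEqualTo(inclusive30)) // false: age > 30 excludes the row with age = 30
  println(exclusive30.lessThanOrEqualTo(exclusive30)) // true with this patch (the old check returned false)
}
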
+
+  /**
+   * Represents a bound (lower or upper) in a range, defined by a literal value and its
+   * inclusiveness.
+   *
+   * @param value
+   *   The literal value defining the bound.
+   * @param inclusive
+   *   Indicates whether the bound is inclusive.
+   */
+  case class Bound(value: Literal, inclusive: Boolean) {
+
+    /**
+     * Checks if this bound is less than or equal to another bound, considering inclusiveness.
+     *
+     * @param other
+     *   The bound to compare against.
+     * @return
+     *   True if this bound is less than or equal to the other bound.
+     */
+    def lessThanOrEqualTo(other: Bound): Boolean = {
+      val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
+      cmp < 0 || (cmp == 0 && (inclusive || !other.inclusive))
+    }
+
+    /**
+     * Checks if this bound is greater than or equal to another bound, considering inclusiveness.
+     *
+     * @param other
+     *   The bound to compare against.
+     * @return
+     *   True if this bound is greater than or equal to the other bound.
+     */
+    def greaterThanOrEqualTo(other: Bound): Boolean = {
+      val cmp = value.value.asInstanceOf[Comparable[Any]].compareTo(other.value.value)
+      cmp > 0 || (cmp == 0 && (inclusive || !other.inclusive))
+    }
+  }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
index 35090a3a9..33c7637e2 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
@@ -41,7 +41,7 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark)
     } else {
       // Iterate each sub plan tree in the given plan
       plan transform {
-        case subPlan @ Filter(condition, ExtractRelation(relation)) =>
+        case subPlan @ Filter(condition, ExtractRelation(relation)) if isConjunction(condition) =>
           doApply(plan, relation, Some(condition))
             .map(newRelation => subPlan.copy(child = newRelation))
             .getOrElse(subPlan)
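
The new isConjunction guard keeps disjunctive filters out of the rewrite, since a partial index generally cannot cover both branches of an OR. A standalone sketch in plain Scala (the object name GuardDemo is hypothetical) of why, using the rows the test suite below seeds:

object GuardDemo extends App {
  val ages = (10 to 55 by 5).toSet // ages seeded into the test table below

  val indexRows = ages.filter(_ > 20) // partial index filter: age > 20
  val queryRows = ages.filter(a => a < 20 || a > 50) // query filter: age < 20 OR age > 50

  // The query needs ages 10 and 15, which the index does not contain
  println(queryRows.subsetOf(indexRows)) // false
}

The rewritten test below applies the same subset check to derive the expected outcome from data instead of hard-coding it.
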
Map("name" -> "string", "age" -> "int"), - filterCondition = Option(indexFilter))) - - if (expectedResult) { - assertion.assertIndexUsed(getFlintIndexName("partial", testTable)) - } else { - assertion.assertIndexNotUsed(testTable) + // Comprehensive test by cartesian product of the following condition + private val conditions = Seq( + null, + "age = 20", + "age > 20", + "age >= 20", + "age < 20", + "age <= 20", + "age = 50", + "age > 50", + "age >= 50", + "age < 50", + "age <= 50", + "age > 20 AND age < 50", + "age >= 20 AND age < 50", + "age > 20 AND age < 50", + "age >=20 AND age <= 50") + (for { + indexFilter <- conditions + queryFilter <- conditions + } yield (indexFilter, queryFilter)).distinct + .foreach { case (indexFilter, queryFilter) => + test(s"apply partial covering index with [$indexFilter] to query filter [$queryFilter]") { + def queryWithFilter(condition: String): String = + Option(condition) match { + case None => s"SELECT name FROM $testTable" + case Some(cond) => s"SELECT name FROM $testTable WHERE $cond" + } + + // Expect index applied if query result is subset of index data (index filter result) + val queryData = sql(queryWithFilter(queryFilter)).collect().toSet + val indexData = sql(queryWithFilter(indexFilter)).collect().toSet + val expectedResult = queryData.subsetOf(indexData) + + val assertion = assertFlintQueryRewriter + .withQuery(queryWithFilter(queryFilter)) + .withIndex( + new FlintSparkCoveringIndex( + indexName = "partial", + tableName = testTable, + indexedColumns = Map("name" -> "string", "age" -> "int"), + filterCondition = Option(indexFilter))) + + if (expectedResult) { + assertion.assertIndexUsed(getFlintIndexName("partial", testTable)) + } else { + assertion.assertIndexNotUsed(testTable) + } } } - } test("should not apply if covering index is logically deleted") { assertFlintQueryRewriter