diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala
index 65f1311be..d04072fd3 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkQueryRewriteHelper.scala
@@ -41,8 +41,9 @@ trait FlintSparkQueryRewriteHelper {
    *   True if the index filter can subsume the query filter, otherwise False.
    */
   def subsume(indexFilter: Expression, queryFilter: Expression): Boolean = {
-    require(isConjunction(indexFilter), "Index filter is not a conjunction")
-    require(isConjunction(queryFilter), "Query filter is not a conjunction")
+    if (!isConjunction(indexFilter) || !isConjunction(queryFilter)) {
+      return false
+    }
 
     // Flatten a potentially nested conjunction into a sequence of individual conditions
     def flattenConditions(filter: Expression): Seq[Expression] = filter match {
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
index 56d8404a6..636092c7c 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
@@ -41,13 +41,13 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark)
     } else {
       // Iterate each sub plan tree in the given plan
       plan transform {
-        case subPlan @ Filter(condition, ExtractRelation(relation)) if isConjunction(condition) =>
-          doApply(plan, relation, Some(condition))
-            .map(newRelation => subPlan.copy(child = newRelation))
-            .getOrElse(subPlan)
-        case subPlan @ ExtractRelation(relation) =>
-          doApply(plan, relation, None)
-            .getOrElse(subPlan)
+        case filter @ Filter(condition, Relation(sourceRelation)) if isConjunction(condition) =>
+          doApply(plan, sourceRelation, Some(condition))
+            .map(newRelation => filter.copy(child = newRelation))
+            .getOrElse(filter)
+        case relation @ Relation(sourceRelation) =>
+          doApply(plan, sourceRelation, None)
+            .getOrElse(relation)
       }
     }
   }
 
@@ -65,8 +65,9 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark)
       .map(index => replaceTableRelationWithIndexRelation(index, relation))
   }
 
-  private object ExtractRelation {
+  private object Relation {
     def unapply(subPlan: LogicalPlan): Option[FlintSparkSourceRelation] = {
+      // Check if any source relation provider supports the plan node
       supportedProviders.collectFirst {
         case provider if provider.isSupported(subPlan) =>
           logInfo(s"Provider [${provider.name()}] can match plan ${subPlan.nodeName}")
@@ -83,20 +84,28 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark)
      * Because this rule executes before push down optimization, relation includes all columns.
      */
     val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap
-    plan
-      .collect {
-        // Relation interface matches both file and Iceberg relation
-        case r: MultiInstanceRelation if r.eq(relation.plan) => Set.empty
-        case other =>
-          other.expressions
-            .flatMap(_.references)
-            .flatMap(ref => {
-              relationColsById.get(ref.exprId)
-            }) // Ignore attribute not belong to current relation being rewritten
-            .map(attr => attr.name)
-      }
-      .flatten
-      .toSet
+    val relationCols =
+      plan
+        .collect {
+          // Relation interface matches both file and Iceberg relation
+          case r: MultiInstanceRelation if r.eq(relation.plan) => Set.empty
+          case other =>
+            other.expressions
+              .flatMap(_.references)
+              .flatMap(ref => {
+                relationColsById.get(ref.exprId)
+              }) // Ignore attributes that don't belong to the relation being rewritten
+              .map(attr => attr.name)
+        }
+        .flatten
+        .toSet
+
+    if (relationCols.isEmpty) {
+      // Return all columns if the plan has only the relation operator, e.g. SELECT * or all columns
+      relationColsById.values.map(_.name).toSet
+    } else {
+      relationCols
+    }
   }
 
   private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkCoveringIndex] = {
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala
index 217c1461a..0d0bc805e 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala
@@ -84,7 +84,7 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
         "age <= 50",
         "age > 20 AND age < 50",
         "age >= 20 AND age < 50",
-        "age > 20 AND age < 50",
+        "age > 20 AND age <= 50",
         "age >=20 AND age <= 50")
       (for {
         indexFilter <- conditions
@@ -134,8 +134,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
 
     // Covering index doesn't cover column age
     Seq(
-      // s"SELECT * FROM $testTable", // FIXME: only relation operator
-      // s"SELECT name, age FROM $testTable", // FIXME: only relation operator
+      s"SELECT * FROM $testTable",
+      s"SELECT name, age FROM $testTable",
       s"SELECT name FROM $testTable WHERE age = 30",
       s"SELECT COUNT(*) FROM $testTable GROUP BY age").foreach { query =>
       test(s"should not apply if column is not covered in $query") {
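For reviewers, a minimal standalone sketch of the range-subsumption semantics that `subsume` checks for conjunctive filters (the `Interval` type and integer bounds below are illustrative assumptions, not the Flint API): an index filter subsumes a query filter when every row matching the query filter also matches the index filter, i.e. the index interval contains the query interval on both ends.

```scala
// Illustrative sketch only, not the Flint implementation.
object SubsumptionSketch {

  // A one-column numeric range with optional bounds and inclusivity flags
  case class Interval(
      lower: Option[Int],
      lowerInclusive: Boolean,
      upper: Option[Int],
      upperInclusive: Boolean) {

    // This interval subsumes `other` if it is at least as wide on both ends
    def subsumes(other: Interval): Boolean =
      lowerCovers(other) && upperCovers(other)

    private def lowerCovers(other: Interval): Boolean =
      (lower, other.lower) match {
        case (None, _) => true // unbounded below covers any lower bound
        case (Some(_), None) => false
        case (Some(a), Some(b)) =>
          a < b || (a == b && (lowerInclusive || !other.lowerInclusive))
      }

    private def upperCovers(other: Interval): Boolean =
      (upper, other.upper) match {
        case (None, _) => true // unbounded above covers any upper bound
        case (Some(_), None) => false
        case (Some(a), Some(b)) =>
          a > b || (a == b && (upperInclusive || !other.upperInclusive))
      }
  }

  def main(args: Array[String]): Unit = {
    // index filter: age > 20 AND age <= 50
    val index =
      Interval(lower = Some(20), lowerInclusive = false, upper = Some(50), upperInclusive = true)
    // query filter: age >= 30 AND age < 50
    val query =
      Interval(lower = Some(30), lowerInclusive = true, upper = Some(50), upperInclusive = false)

    println(index.subsumes(query)) // true: 30 <= age < 50 implies 20 < age <= 50
    println(query.subsumes(index)) // false: age = 50 matches the index filter but not the query filter
  }
}
```

Returning `false` for non-conjunctive filters, instead of failing the whole optimizer pass via `require`, keeps the rewrite rule a safe no-op on filter shapes it cannot reason about.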
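Similarly, a minimal sketch of the new SELECT * fallback in `extractRelationColumns` (hypothetical helper names; only the empty-set expansion mirrors the change): when no operator above the relation references any column, the query is treated as requiring every relation column, so a covering index applies only if it covers the full schema.

```scala
// Illustrative sketch only; names are hypothetical.
object ColumnFallbackSketch {

  def requiredColumns(referencedCols: Set[String], relationCols: Set[String]): Set[String] =
    if (referencedCols.isEmpty) relationCols // plan is only a relation scan, e.g. SELECT *
    else referencedCols

  def main(args: Array[String]): Unit = {
    val schema = Set("name", "age", "city")
    println(requiredColumns(Set.empty, schema))   // SELECT * -> Set(name, age, city)
    println(requiredColumns(Set("name"), schema)) // SELECT name ... -> Set(name)
  }
}
```

This is also why `SELECT * FROM $testTable` and `SELECT name, age FROM $testTable` move out of the FIXME comments and into the "should not apply if column is not covered" test cases: the test index does not cover column age, so these queries must no longer be rewritten.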