Fix collecting required column logic
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Apr 18, 2024
1 parent 513e490 commit d2a4a4f
Showing 2 changed files with 28 additions and 14 deletions.
ApplyFlintSparkCoveringIndex.scala
@@ -6,12 +6,10 @@
 package org.opensearch.flint.spark.covering
 
 import java.util
 
 import org.opensearch.flint.spark.FlintSpark
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName
 
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.NamedExpression.newExprId
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -46,8 +44,16 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan]
     val indexPattern = getFlintIndexName("*", qualifiedTableName)
     val indexes = flint.describeIndexes(indexPattern)
 
-    // Collect all columns needed by the query from the relation
-    val requiredCols = plan.output.map(_.name).toSet
+    // Collect all columns needed by the query except those from the relation. This is because this
+    // rule runs before pushdown optimization, so the relation still includes all columns of the table.
+    val requiredCols =
+      plan
+        .collect {
+          case _: LogicalRelation => Set.empty[String]
+          case other => other.expressions.flatMap(_.references).map(_.name).toSet
+        }
+        .flatten
+        .toSet
 
     // Choose the first covering index that meets all criteria
     indexes
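For readers who want to try the collection logic in isolation: the sketch below reproduces the same traversal outside the rule, gathering the attribute names referenced by every operator above the base relation while treating the relation itself as contributing nothing (before pushdown its output still lists every table column). The table name, query, and object name are hypothetical; this is a rough standalone illustration of the idea, not code from this commit.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation

object RequiredColsSketch {
  // Same idea as the new logic above: collect the column names each operator references,
  // but let the base relation contribute nothing, since it lists every table column.
  def requiredCols(plan: LogicalPlan): Set[String] =
    plan
      .collect {
        case _: LogicalRelation => Set.empty[String]
        case other => other.expressions.flatMap(_.references).map(_.name).toSet
      }
      .flatten
      .toSet

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("required-cols-sketch").getOrCreate()
    // Hypothetical table; any datasource table with these columns would do.
    spark.sql("CREATE TABLE IF NOT EXISTS people (name STRING, age INT, city STRING) USING parquet")
    val analyzed = spark.sql("SELECT name FROM people WHERE age > 30").queryExecution.analyzed
    println(requiredCols(analyzed)) // expected: Set(name, age) -- city is not needed by this query
    spark.stop()
  }
}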
ApplyFlintSparkCoveringIndexSuite.scala
@@ -17,7 +17,8 @@ import org.scalatestplus.mockito.MockitoSugar.mock
 import org.apache.spark.FlintSuite
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parseExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.sources.BaseRelation
@@ -49,30 +50,32 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
     super.afterAll()
   }
 
-  test("should not apply when no index present") {
+  test("should not apply if no index present") {
     assertFlintQueryRewriter
       .withTable(StructType.fromDDL("name STRING, age INT"))
       .assertIndexNotUsed()
   }
 
-  test("should not apply when index on other table") {
+  test("should not apply if some columns in project are not covered") {
     assertFlintQueryRewriter
-      .withTable(StructType.fromDDL("name STRING, age INT"))
+      .withTable(StructType.fromDDL("name STRING, age INT, city STRING"))
+      .withProject("name", "age", "city") // city is not covered
       .withIndex(
         new FlintSparkCoveringIndex(
-          indexName = "all",
-          tableName = s"other_s$testTable",
-          indexedColumns = Map("city" -> "string")))
+          indexName = "partial",
+          tableName = testTable,
+          indexedColumns = Map("name" -> "string", "age" -> "int")))
       .assertIndexNotUsed()
   }
 
-  test("should not apply when some columns are not covered") {
+  test("should not apply if some columns in filter are not covered") {
     assertFlintQueryRewriter
       .withTable(StructType.fromDDL("name STRING, age INT, city STRING"))
-      .withProject("name", "age", "city") // city is not covered
+      .withFilter("city = 'Seattle'")
+      .withProject("name", "age")
       .withIndex(
         new FlintSparkCoveringIndex(
-          indexName = "partial",
+          indexName = "all",
           tableName = testTable,
           indexedColumns = Map("name" -> "string", "age" -> "int")))
       .assertIndexNotUsed()
@@ -114,6 +117,11 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
       this
     }
 
+    def withFilter(predicate: String): AssertionHelper = {
+      this.plan = Filter(parseExpression(predicate), plan)
+      this
+    }
+
     def withIndex(index: FlintSparkCoveringIndex): AssertionHelper = {
       this.indexes = indexes :+ index
       this
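For context, the new withFilter helper simply parses a SQL predicate string with CatalystSqlParser and wraps the current plan in a Catalyst Filter node. A rough standalone illustration of those two calls follows; the people relation and the object name are made up for the example and are not part of this commit.

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parseExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}

object WithFilterSketch {
  // Parse a predicate string into an Expression and place a Filter node on top of the plan,
  // mirroring what the withFilter test helper does.
  def withFilter(plan: LogicalPlan, predicate: String): LogicalPlan =
    Filter(parseExpression(predicate), plan)

  def main(args: Array[String]): Unit = {
    val base: LogicalPlan = UnresolvedRelation(Seq("people")) // any relation works here
    println(withFilter(base, "city = 'Seattle'").treeString)
  }
}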
