Refactor rewrite rule
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed May 1, 2024
1 parent f4416b1 commit 3b6a4a8
Showing 5 changed files with 22 additions and 22 deletions.

File 1 of 5: ApplyFlintSparkCoveringIndex (org.opensearch.flint.spark.covering)
@@ -8,7 +8,7 @@ package org.opensearch.flint.spark.covering
 import java.util

 import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.DELETED
-import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex}
+import org.opensearch.flint.spark.FlintSpark
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName
 import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}

@@ -30,30 +30,27 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] {

   /** All supported source relation providers */
-  private val relationProviders = FlintSparkSourceRelationProvider.getProviders(flint.spark)
+  private val supportedProviders = FlintSparkSourceRelationProvider.getAllProviders(flint.spark)

   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (plan.isInstanceOf[V2WriteCommand]) {
       plan
     } else {
       plan transform { case subPlan =>
-        relationProviders
+        supportedProviders
           .collectFirst {
-            case relationProvider if relationProvider.isSupported(subPlan) =>
-              val relation = relationProvider.getRelation(subPlan)
+            case provider if provider.isSupported(subPlan) =>
+              val relation = provider.getRelation(subPlan)
               val relationCols = collectRelationColumnsInQueryPlan(plan, relation)

               // Choose the first covering index that meets all criteria above
               findAllCoveringIndexesOnTable(relation.tableName)
                 .sortBy(_.name())
-                .collectFirst {
-                  case index: FlintSparkCoveringIndex
-                      if isCoveringIndexApplicable(index, relationCols) =>
-                    replaceTableRelationWithIndexRelation(index, relation)
-                }
-                .getOrElse(subPlan) // If no index found, return the original relation
+                .find(index => isCoveringIndexApplicable(index, relationCols))
+                .map(index => replaceTableRelationWithIndexRelation(index, relation))
+                .getOrElse(subPlan) // If no index found, return the original node
           }
-          .getOrElse(subPlan)
+          .getOrElse(subPlan) // If not supported by any provider, return the original node
       }
     }
   }
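
Note on the hunk above: the guarded collectFirst over the candidate indexes becomes an equivalent find/map chain, and the type check the old pattern performed (case index: FlintSparkCoveringIndex) moves into findAllCoveringIndexesOnTable below. A REPL-style Scala sketch of the two shapes, using a hypothetical Index type rather than the Flint classes:

    // Hypothetical Index type, for illustration only.
    case class Index(name: String, covered: Set[String])

    // Both shapes pick the name of the first applicable index, or fall back.
    def pick(indexes: Seq[Index], required: Set[String], fallback: String): String = {
      // Before: pattern match with a guard inside collectFirst.
      val before = indexes
        .collectFirst { case idx if required.subsetOf(idx.covered) => idx.name }
        .getOrElse(fallback)
      // After: find the first applicable element, then transform it.
      val after = indexes
        .find(idx => required.subsetOf(idx.covered))
        .map(_.name)
        .getOrElse(fallback)
      assert(before == after)
      after
    }
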
@@ -68,6 +65,7 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan]
     val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap
     plan
       .collect {
+        // Relation interface matches both file and Iceberg relation
         case r: MultiInstanceRelation if r.eq(relation.plan) => Set.empty
         case other =>
           other.expressions
@@ -81,10 +79,14 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan]
       .toSet
   }

-  private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkIndex] = {
+  private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkCoveringIndex] = {
     val qualifiedTableName = qualifyTableName(flint.spark, tableName)
     val indexPattern = getFlintIndexName("*", qualifiedTableName)
-    flint.describeIndexes(indexPattern)
+    flint
+      .describeIndexes(indexPattern)
+      .collect { // cast to covering index
+        case index: FlintSparkCoveringIndex => index
+      }
   }

   private def isCoveringIndexApplicable(
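
Note on findAllCoveringIndexesOnTable: describeIndexes can return indexes of any kind, so the refactored helper narrows the result with collect, which filters and downcasts in one pass and lets the method return Seq[FlintSparkCoveringIndex]. A REPL-style sketch of the idiom with hypothetical index types (not the Flint classes):

    sealed trait Index { def name: String }
    case class CoveringIndex(name: String, columns: Seq[String]) extends Index
    case class SkippingIndex(name: String) extends Index

    val all: Seq[Index] = Seq(CoveringIndex("ci", Seq("a", "b")), SkippingIndex("si"))

    // collect keeps only the CoveringIndex elements and types the result accordingly.
    val covering: Seq[CoveringIndex] = all.collect { case ci: CoveringIndex => ci }
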

File 2 of 5: FlintSparkSourceRelationProvider (org.opensearch.flint.spark.source)
@@ -60,7 +60,7 @@ object FlintSparkSourceRelationProvider {
    * @return
    *   a sequence of source relation provider
    */
-  def getProviders(spark: SparkSession): Seq[FlintSparkSourceRelationProvider] = {
+  def getAllProviders(spark: SparkSession): Seq[FlintSparkSourceRelationProvider] = {
     var relations = Seq[FlintSparkSourceRelationProvider]()

     // File source is built-in supported

File 3 of 5: FileSourceRelation (org.opensearch.flint.spark.source.file)
@@ -5,17 +5,16 @@

 package org.opensearch.flint.spark.source.file

-import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}
+import org.opensearch.flint.spark.source.FlintSparkSourceRelation

 import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.datasources.LogicalRelation

 /**
  * Concrete source relation implementation for Spark built-in file-based data sources.
  *
  * @param plan
- *   the relation plan associated with the file-based data source
+ *   the `LogicalRelation` that represents the plan associated with the File-based table
  */
 case class FileSourceRelation(override val plan: LogicalRelation)
     extends FlintSparkSourceRelation {

File 4 of 5: IcebergSourceRelation
@@ -22,7 +22,7 @@ case class IcebergSourceRelation(override val plan: DataSourceV2Relation)
     extends FlintSparkSourceRelation {

   override def tableName: String =
-    plan.table.name() // TODO: confirm
+    plan.table.name()

   override def output: Seq[AttributeReference] = plan.output
 }

File 5 of 5: ApplyFlintSparkCoveringIndexSuite
@@ -27,11 +27,11 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
   private val testTable = "spark_catalog.default.apply_covering_index_test"
   private val testTable2 = "spark_catalog.default.apply_covering_index_test_2"

-  // Mock FlintClient to avoid looking for real OpenSearch cluster
+  /** Mock FlintClient to avoid looking for real OpenSearch cluster */
   private val clientBuilder = mockStatic(classOf[FlintClientBuilder])
   private val client = mock[FlintClient](RETURNS_DEEP_STUBS)

-  /** Mock FlintSpark which is required by the rule */
+  /** Mock FlintSpark which is required by the rule. Deep stub required to replace spark val. */
   private val flint = mock[FlintSpark](RETURNS_DEEP_STUBS)

   /** Instantiate the rule once for all tests */
@@ -47,7 +47,6 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
       .when(() => FlintClientBuilder.build(any(classOf[FlintOptions])))
       .thenReturn(client)
     when(flint.spark).thenReturn(spark)
-    // when(flint.spark.conf.getOption(any())).thenReturn(None)
   }

   override protected def afterAll(): Unit = {
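
Note on the test changes: the updated comment notes that a deep stub is required to replace the spark val on the FlintSpark mock. A hedged sketch of the general deep-stub idiom in plain Mockito, with hypothetical Service/Session/Config traits (not the suite's actual setup):

    import org.mockito.Mockito.{mock, when, RETURNS_DEEP_STUBS}

    trait Config { def get(key: String): Option[String] }
    trait Session { def conf: Config }
    trait Service { def spark: Session }

    // Deep stubs auto-create mocks for intermediate calls, so a chained call
    // like service.spark.conf.get(...) can be stubbed directly.
    val service = mock(classOf[Service], RETURNS_DEEP_STUBS)
    when(service.spark.conf.get("flint.option")).thenReturn(Some("value"))
    assert(service.spark.conf.get("flint.option").contains("value"))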
