Commit

Merge branch 'main' into improve-error-handling

dai-chen committed Jun 26, 2024
2 parents 6e72473 + beac01a commit ed105c9

Showing 6 changed files with 230 additions and 29 deletions.
ApplyFlintSparkCoveringIndex.scala

@@ -8,13 +8,14 @@ package org.opensearch.flint.spark.covering
import java.util

import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.DELETED
-import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex}
+import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName
+import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}

+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.flint.{qualifyTableName, FlintDataSourceV2}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -28,60 +29,95 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
*/
class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] {

-override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-case relation @ LogicalRelation(_, _, Some(table), false)
-if !plan.isInstanceOf[V2WriteCommand] => // TODO: make sure only intercept SELECT query
-val relationCols = collectRelationColumnsInQueryPlan(relation, plan)
-
-// Choose the first covering index that meets all criteria above
-findAllCoveringIndexesOnTable(table.qualifiedName)
-.sortBy(_.name())
-.collectFirst {
-case index: FlintSparkCoveringIndex if isCoveringIndexApplicable(index, relationCols) =>
-replaceTableRelationWithIndexRelation(index, relation)
-}
-.getOrElse(relation) // If no index found, return the original relation
+/** All supported source relation providers */
+private val supportedProviders = FlintSparkSourceRelationProvider.getAllProviders(flint.spark)
+
+override def apply(plan: LogicalPlan): LogicalPlan = {
+if (plan.isInstanceOf[V2WriteCommand]) { // TODO: bypass any non-select plan
+plan
+} else {
+// Iterate each sub plan tree in the given plan
+plan transform { case subPlan =>
+supportedProviders
+.collectFirst {
+case provider if provider.isSupported(subPlan) =>
+logInfo(s"Provider [${provider.name()}] can match sub plan ${subPlan.nodeName}")
+val relation = provider.getRelation(subPlan)
+val relationCols = collectRelationColumnsInQueryPlan(plan, relation)
+
+// Choose the first covering index that meets all criteria above
+findAllCoveringIndexesOnTable(relation.tableName)
+.sortBy(_.name())
+.find(index => isCoveringIndexApplicable(index, relationCols))
+.map(index => replaceTableRelationWithIndexRelation(index, relation))
+.getOrElse(subPlan) // If no index found, return the original node
+}
+.getOrElse(subPlan) // If not supported by any provider, return the original node
+}
+}
+}
}

private def collectRelationColumnsInQueryPlan(
-relation: LogicalRelation,
-plan: LogicalPlan): Set[String] = {
+plan: LogicalPlan,
+relation: FlintSparkSourceRelation): Set[String] = {
/*
* Collect all columns of the relation present in query plan, except those in relation itself.
* Because this rule executes before push down optimization, relation includes all columns.
*/
val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap
plan
.collect {
-case _: LogicalRelation => Set.empty
+// Relation interface matches both file and Iceberg relation
+case r: MultiInstanceRelation if r.eq(relation.plan) => Set.empty
case other =>
other.expressions
.flatMap(_.references)
-.flatMap(ref =>
-relationColsById.get(ref.exprId)) // Ignore attribute not belong to target relation
+.flatMap(ref => {
+relationColsById.get(ref.exprId)
+}) // Ignore attributes that do not belong to the relation being rewritten
.map(attr => attr.name)
}
.flatten
.toSet
}

-private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkIndex] = {
+private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkCoveringIndex] = {
val qualifiedTableName = qualifyTableName(flint.spark, tableName)
val indexPattern = getFlintIndexName("*", qualifiedTableName)
-flint.describeIndexes(indexPattern)
+val indexes =
+flint
+.describeIndexes(indexPattern)
+.collect { // cast to covering index
+case index: FlintSparkCoveringIndex => index
+}
+
+val indexNames = indexes.map(_.name()).mkString(",")
+logInfo(s"Found covering index [$indexNames] on table $qualifiedTableName")
+indexes
}

private def isCoveringIndexApplicable(
index: FlintSparkCoveringIndex,
relationCols: Set[String]): Boolean = {
-index.latestLogEntry.exists(_.state != DELETED) &&
-index.filterCondition.isEmpty && // TODO: support partial covering index later
-relationCols.subsetOf(index.indexedColumns.keySet)
+val indexedCols = index.indexedColumns.keySet
+val isApplicable =
+index.latestLogEntry.exists(_.state != DELETED) &&
+index.filterCondition.isEmpty && // TODO: support partial covering index later
+relationCols.subsetOf(indexedCols)
+
+logInfo(s"""
+| Is covering index ${index.name()} applicable: $isApplicable
+| Index state: ${index.latestLogEntry.map(_.state)}
+| Index filter condition: ${index.filterCondition}
+| Columns required: $relationCols
+| Columns indexed: $indexedCols
+|""".stripMargin)
+isApplicable
}

private def replaceTableRelationWithIndexRelation(
index: FlintSparkCoveringIndex,
-relation: LogicalRelation): LogicalPlan = {
+relation: FlintSparkSourceRelation): LogicalPlan = {
// Make use of data source relation to avoid Spark looking for OpenSearch index in catalog
val ds = new FlintDataSourceV2
val options = new CaseInsensitiveStringMap(util.Map.of("path", index.name()))
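Taken together, the rewritten rule walks each sub-plan, asks every registered provider whether it recognizes the node, and replaces a matched table relation with a scan over the Flint covering index. A minimal sketch of how the rule might be exercised end to end; the SparkSession `spark`, the FlintSpark handle `flint`, and the presence of a covering index on (name, age) are assumptions for illustration, not part of this commit:

import org.opensearch.flint.spark.covering.ApplyFlintSparkCoveringIndex

// Illustrative sketch only: assumes `spark`, `flint`, and a file-based table with a
// covering index on (name, age) already set up.
val rule = new ApplyFlintSparkCoveringIndex(flint)

val analyzedPlan = spark
  .sql("SELECT name FROM spark_catalog.default.apply_covering_index_test WHERE age > 30")
  .queryExecution
  .analyzed

val rewrittenPlan = rule.apply(analyzedPlan)
// If an applicable covering index is found (not DELETED, no filter condition, and it indexes
// both `name` and `age`), the table relation is swapped for a DataSourceV2Relation reading
// the Flint index; otherwise the original plan is returned unchanged.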
FlintSparkSourceRelation.scala (new file)

@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.source

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
* This source relation abstraction allows Flint to interact uniformly with different kinds of
* source data formats (like Spark built-in File, Delta table, Iceberg, etc.), hiding the
* specifics of each data source implementation.
*/
trait FlintSparkSourceRelation {

/**
* @return
* the concrete logical plan associated with the relation
*/
def plan: LogicalPlan

/**
* @return
* fully qualified table name represented by the relation
*/
def tableName: String

/**
* @return
* output column list of the relation
*/
def output: Seq[AttributeReference]
}
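To make the abstraction concrete, here is a rough sketch of what a non-file implementation of this trait could look like. The Iceberg-backed relation below, wrapping Spark's DataSourceV2Relation, is an assumption purely for illustration and is not part of this commit:

import org.opensearch.flint.spark.source.FlintSparkSourceRelation

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

// Hypothetical sketch: an Iceberg table surfaces as a DataSourceV2Relation in the analyzed plan.
case class IcebergSourceRelation(override val plan: DataSourceV2Relation)
    extends FlintSparkSourceRelation {

  // Assumes the catalog registers the table under its qualified name.
  override def tableName: String = plan.table.name()

  override def output: Seq[AttributeReference] = plan.output
}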
FlintSparkSourceRelationProvider.scala (new file)

@@ -0,0 +1,73 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.source

import org.opensearch.flint.spark.source.file.FileSourceRelationProvider

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
* A provider defines what kind of logical plan can be supported by the Flint Spark integration. It
* serves a similar purpose to a Scala extractor used in a match-case statement. However, we want to
* avoid a hard dependency on third-party data source code, such as Iceberg, so we maintain a list of
* providers and run each one only if its library is available in the current Spark session.
*/
trait FlintSparkSourceRelationProvider {

/**
* @return
* the name of the source relation provider
*/
def name(): String

/**
* Determines whether the given logical plan is supported by this provider.
*
* @param plan
* the logical plan to evaluate
* @return
* true if the plan is supported, false otherwise
*/
def isSupported(plan: LogicalPlan): Boolean

/**
* Creates a source relation based on the provided logical plan.
*
* @param plan
* the logical plan to wrap in source relation
* @return
* an instance of source relation
*/
def getRelation(plan: LogicalPlan): FlintSparkSourceRelation
}

/**
* Companion object provides utility methods.
*/
object FlintSparkSourceRelationProvider extends Logging {

/**
* Retrieve all supported source relation providers for the given Spark session.
*
* @param spark
* the Spark session
* @return
* a sequence of source relation providers
*/
def getAllProviders(spark: SparkSession): Seq[FlintSparkSourceRelationProvider] = {
var relations = Seq[FlintSparkSourceRelationProvider]()

// File source is built-in supported
relations = relations :+ new FileSourceRelationProvider

val providerNames = relations.map(_.name()).mkString(",")
logInfo(s"Loaded source relation providers [$providerNames]")
relations
}
}
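The doc comment above explains why providers are looked up at runtime: Flint should avoid a hard compile-time dependency on libraries such as Iceberg. One way getAllProviders could be extended along those lines is a classpath probe. The Iceberg class name and the IcebergSourceRelationProvider below are assumptions for illustration, not part of this commit:

import scala.util.Try

// Hypothetical sketch: register the Iceberg provider only when the Iceberg runtime
// is present on the classpath.
def loadOptionalProviders(): Seq[FlintSparkSourceRelationProvider] = {
  val icebergOnClasspath =
    Try(Class.forName("org.apache.iceberg.catalog.Catalog")).isSuccess

  if (icebergOnClasspath) {
    // IcebergSourceRelationProvider is an assumed class, analogous to FileSourceRelationProvider.
    Seq(new IcebergSourceRelationProvider)
  } else {
    Seq.empty
  }
}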
FileSourceRelation.scala (new file)

@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.source.file

import org.opensearch.flint.spark.source.FlintSparkSourceRelation

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.LogicalRelation

/**
* Concrete source relation implementation for Spark built-in file-based data sources.
*
* @param plan
* the `LogicalRelation` that represents the plan associated with the File-based table
*/
case class FileSourceRelation(override val plan: LogicalRelation)
extends FlintSparkSourceRelation {

override def tableName: String =
plan.catalogTable.get // catalogTable must be present, as pre-checked by the source relation provider
.qualifiedName

override def output: Seq[AttributeReference] = plan.output
}
FileSourceRelationProvider.scala (new file)

@@ -0,0 +1,30 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.source.file

import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation

/**
* Source relation provider for Spark built-in file-based source.
*
* @param name
* the name of the file source provider
*/
class FileSourceRelationProvider(override val name: String = "file")
extends FlintSparkSourceRelationProvider {

override def isSupported(plan: LogicalPlan): Boolean = plan match {
case LogicalRelation(_, _, Some(_), false) => true
case _ => false
}

override def getRelation(plan: LogicalPlan): FlintSparkSourceRelation = {
FileSourceRelation(plan.asInstanceOf[LogicalRelation])
}
}
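For context, a small standalone usage sketch of this provider outside the rule. The table name is an assumption, and the analyzed plan is walked because the file relation usually sits beneath a subquery alias rather than at the plan root:

import org.opensearch.flint.spark.source.file.FileSourceRelationProvider

// Illustrative sketch only: assumes a SparkSession `spark` and an existing file-based table.
val provider = new FileSourceRelationProvider()
val analyzed = spark
  .table("spark_catalog.default.apply_covering_index_test")
  .queryExecution
  .analyzed

analyzed.foreach {
  case subPlan if provider.isSupported(subPlan) =>
    val relation = provider.getRelation(subPlan)
    println(s"${relation.tableName} -> ${relation.output.map(_.name).mkString(", ")}")
  case _ => // not a file-based relation node
}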
ApplyFlintSparkCoveringIndexSuite.scala

Expand Up @@ -27,12 +27,12 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
private val testTable = "spark_catalog.default.apply_covering_index_test"
private val testTable2 = "spark_catalog.default.apply_covering_index_test_2"

-// Mock FlintClient to avoid looking for real OpenSearch cluster
+/** Mock FlintClient to avoid looking for real OpenSearch cluster */
private val clientBuilder = mockStatic(classOf[FlintClientBuilder])
private val client = mock[FlintClient](RETURNS_DEEP_STUBS)

-/** Mock FlintSpark which is required by the rule */
-private val flint = mock[FlintSpark]
+/** Mock FlintSpark which is required by the rule. Deep stub required to replace spark val. */
+private val flint = mock[FlintSpark](RETURNS_DEEP_STUBS)

/** Instantiate the rule once for all tests */
private val rule = new ApplyFlintSparkCoveringIndex(flint)