diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala index 0234ec35a..9217495e6 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala @@ -8,13 +8,14 @@ package org.opensearch.flint.spark.covering import java.util import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry.IndexState.DELETED -import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex} +import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName +import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider} +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.flint.{qualifyTableName, FlintDataSourceV2} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -28,24 +29,37 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case relation @ LogicalRelation(_, _, Some(table), false) - if !plan.isInstanceOf[V2WriteCommand] => // TODO: make sure only intercept SELECT query - val relationCols = collectRelationColumnsInQueryPlan(relation, plan) - - // Choose the first covering index that meets all criteria above - findAllCoveringIndexesOnTable(table.qualifiedName) - .sortBy(_.name()) - .collectFirst { - case index: FlintSparkCoveringIndex if isCoveringIndexApplicable(index, relationCols) => - replaceTableRelationWithIndexRelation(index, relation) - } - .getOrElse(relation) // If no index found, return the original relation + /** All supported source relation providers */ + private val supportedProviders = FlintSparkSourceRelationProvider.getAllProviders(flint.spark) + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (plan.isInstanceOf[V2WriteCommand]) { // TODO: bypass any non-select plan + plan + } else { + // Iterate each sub plan tree in the given plan + plan transform { case subPlan => + supportedProviders + .collectFirst { + case provider if provider.isSupported(subPlan) => + logInfo(s"Provider [${provider.name()}] can match sub plan ${subPlan.nodeName}") + val relation = provider.getRelation(subPlan) + val relationCols = collectRelationColumnsInQueryPlan(plan, relation) + + // Choose the first covering index that meets all criteria above + findAllCoveringIndexesOnTable(relation.tableName) + .sortBy(_.name()) + .find(index => isCoveringIndexApplicable(index, relationCols)) + .map(index => replaceTableRelationWithIndexRelation(index, relation)) + .getOrElse(subPlan) // If no index found, return the original node + } + .getOrElse(subPlan) // If not supported by any provider, return the original node + } + } } private def collectRelationColumnsInQueryPlan( - relation: LogicalRelation, - plan: LogicalPlan): Set[String] = { + plan: LogicalPlan, + relation: FlintSparkSourceRelation): Set[String] = { /* * Collect all columns of the relation present in query plan, except those in relation itself. * Because this rule executes before push down optimization, relation includes all columns. @@ -53,35 +67,57 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap plan .collect { - case _: LogicalRelation => Set.empty + // Relation interface matches both file and Iceberg relation + case r: MultiInstanceRelation if r.eq(relation.plan) => Set.empty case other => other.expressions .flatMap(_.references) - .flatMap(ref => - relationColsById.get(ref.exprId)) // Ignore attribute not belong to target relation + .flatMap(ref => { + relationColsById.get(ref.exprId) + }) // Ignore attribute not belong to current relation being rewritten .map(attr => attr.name) } .flatten .toSet } - private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkIndex] = { + private def findAllCoveringIndexesOnTable(tableName: String): Seq[FlintSparkCoveringIndex] = { val qualifiedTableName = qualifyTableName(flint.spark, tableName) val indexPattern = getFlintIndexName("*", qualifiedTableName) - flint.describeIndexes(indexPattern) + val indexes = + flint + .describeIndexes(indexPattern) + .collect { // cast to covering index + case index: FlintSparkCoveringIndex => index + } + + val indexNames = indexes.map(_.name()).mkString(",") + logInfo(s"Found covering index [$indexNames] on table $qualifiedTableName") + indexes } private def isCoveringIndexApplicable( index: FlintSparkCoveringIndex, relationCols: Set[String]): Boolean = { - index.latestLogEntry.exists(_.state != DELETED) && - index.filterCondition.isEmpty && // TODO: support partial covering index later - relationCols.subsetOf(index.indexedColumns.keySet) + val indexedCols = index.indexedColumns.keySet + val isApplicable = + index.latestLogEntry.exists(_.state != DELETED) && + index.filterCondition.isEmpty && // TODO: support partial covering index later + relationCols.subsetOf(indexedCols) + + logInfo(s""" + | Is covering index ${index.name()} applicable: $isApplicable + | Index state: ${index.latestLogEntry.map(_.state)} + | Index filter condition: ${index.filterCondition} + | Columns required: $relationCols + | Columns indexed: $indexedCols + |""".stripMargin) + isApplicable } private def replaceTableRelationWithIndexRelation( index: FlintSparkCoveringIndex, - relation: LogicalRelation): LogicalPlan = { + relation: FlintSparkSourceRelation): LogicalPlan = { // Make use of data source relation to avoid Spark looking for OpenSearch index in catalog val ds = new FlintDataSourceV2 val options = new CaseInsensitiveStringMap(util.Map.of("path", index.name())) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelation.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelation.scala new file mode 100644 index 000000000..f5d063ba7 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelation.scala @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.source + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * This source relation abstraction allows Flint to interact uniformly with different kinds of + * source data formats (like Spark built-in File, Delta table, Iceberg, etc.), hiding the + * specifics of each data source implementation. + */ +trait FlintSparkSourceRelation { + + /** + * @return + * the concrete logical plan of the relation associated + */ + def plan: LogicalPlan + + /** + * @return + * fully qualified table name represented by the relation + */ + def tableName: String + + /** + * @return + * output column list of the relation + */ + def output: Seq[AttributeReference] +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelationProvider.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelationProvider.scala new file mode 100644 index 000000000..b32fdc23e --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelationProvider.scala @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.source + +import org.opensearch.flint.spark.source.file.FileSourceRelationProvider + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A provider defines what kind of logical plan can be supported by Flint Spark integration. It + * serves similar purpose to Scala extractor which has to be used in match case statement. + * However, the problem here is we want to avoid hard dependency on some data source code, such as + * Iceberg. In this case, we have to maintain a list of provider and run it only if the 3rd party + * library is available in current Spark session. + */ +trait FlintSparkSourceRelationProvider { + + /** + * @return + * the name of the source relation provider + */ + def name(): String + + /** + * Determines whether the given logical plan is supported by this provider. + * + * @param plan + * the logical plan to evaluate + * @return + * true if the plan is supported, false otherwise + */ + def isSupported(plan: LogicalPlan): Boolean + + /** + * Creates a source relation based on the provided logical plan. + * + * @param plan + * the logical plan to wrap in source relation + * @return + * an instance of source relation + */ + def getRelation(plan: LogicalPlan): FlintSparkSourceRelation +} + +/** + * Companion object provides utility methods. + */ +object FlintSparkSourceRelationProvider extends Logging { + + /** + * Retrieve all supported source relation provider for the given Spark session. + * + * @param spark + * the Spark session + * @return + * a sequence of source relation provider + */ + def getAllProviders(spark: SparkSession): Seq[FlintSparkSourceRelationProvider] = { + var relations = Seq[FlintSparkSourceRelationProvider]() + + // File source is built-in supported + relations = relations :+ new FileSourceRelationProvider + + val providerNames = relations.map(_.name()).mkString(",") + logInfo(s"Loaded source relation providers [$providerNames]") + relations + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelation.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelation.scala new file mode 100644 index 000000000..af2bc55b8 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelation.scala @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.source.file + +import org.opensearch.flint.spark.source.FlintSparkSourceRelation + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.datasources.LogicalRelation + +/** + * Concrete source relation implementation for Spark built-in file-based data sources. + * + * @param plan + * the `LogicalRelation` that represents the plan associated with the File-based table + */ +case class FileSourceRelation(override val plan: LogicalRelation) + extends FlintSparkSourceRelation { + + override def tableName: String = + plan.catalogTable.get // catalogTable must be present as pre-checked in source relation provider's + .qualifiedName + + override def output: Seq[AttributeReference] = plan.output +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelationProvider.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelationProvider.scala new file mode 100644 index 000000000..d35309dcc --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelationProvider.scala @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.source.file + +import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider} + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.datasources.LogicalRelation + +/** + * Source relation provider for Spark built-in file-based source. + * + * @param name + * the name of the file source provider + */ +class FileSourceRelationProvider(override val name: String = "file") + extends FlintSparkSourceRelationProvider { + + override def isSupported(plan: LogicalPlan): Boolean = plan match { + case LogicalRelation(_, _, Some(_), false) => true + case _ => false + } + + override def getRelation(plan: LogicalPlan): FlintSparkSourceRelation = { + FileSourceRelation(plan.asInstanceOf[LogicalRelation]) + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala index 5231bdfa6..cadb6c93a 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala @@ -27,12 +27,12 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers { private val testTable = "spark_catalog.default.apply_covering_index_test" private val testTable2 = "spark_catalog.default.apply_covering_index_test_2" - // Mock FlintClient to avoid looking for real OpenSearch cluster + /** Mock FlintClient to avoid looking for real OpenSearch cluster */ private val clientBuilder = mockStatic(classOf[FlintClientBuilder]) private val client = mock[FlintClient](RETURNS_DEEP_STUBS) - /** Mock FlintSpark which is required by the rule */ - private val flint = mock[FlintSpark] + /** Mock FlintSpark which is required by the rule. Deep stub required to replace spark val. */ + private val flint = mock[FlintSpark](RETURNS_DEEP_STUBS) /** Instantiate the rule once for all tests */ private val rule = new ApplyFlintSparkCoveringIndex(flint)