From 5240d53ba4ea3639ef4110a904e85aa8074c2e90 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Mon, 29 Apr 2024 17:46:03 -0700
Subject: [PATCH] Extract source relation and provider abstraction

Signed-off-by: Chen Dai
---
 build.sbt                                     |  1 +
 .../ApplyFlintSparkCoveringIndex.scala        | 58 ++++++++++++++-----
 .../source/FlintSparkSourceRelation.scala     | 19 +++---
 ...elation.scala => FileSourceRelation.scala} | 17 +++++-
 .../iceberg/IcebergSourceRelation.scala       | 51 ++++++++++++++++
 .../ApplyFlintSparkCoveringIndexSuite.scala   |  3 +-
 6 files changed, 120 insertions(+), 29 deletions(-)
 rename flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/{FlintSparkFileSourceRelation.scala => FileSourceRelation.scala} (54%)
 create mode 100644 flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/iceberg/IcebergSourceRelation.scala

diff --git a/build.sbt b/build.sbt
index 4cf923fc2..f7f602051 100644
--- a/build.sbt
+++ b/build.sbt
@@ -130,6 +130,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
     libraryDependencies ++= Seq(
       "com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
         exclude ("com.fasterxml.jackson.core", "jackson-databind"),
+      "org.apache.iceberg" %% s"iceberg-spark-runtime-$sparkMinorVersion" % icebergVersion % "provided",
       "org.scalactic" %% "scalactic" % "3.2.15" % "test",
       "org.scalatest" %% "scalatest" % "3.2.15" % "test",
       "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
index 661be43a9..bbc1ada5a 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndex.scala
@@ -10,7 +10,9 @@ import java.util
 import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.DELETED
 import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex}
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName
-import org.opensearch.flint.spark.source.FlintSparkSourceRelation
+import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}
+import org.opensearch.flint.spark.source.file.FileSourceRelationProvider
+import org.opensearch.flint.spark.source.iceberg.IcebergSourceRelationProvider
 
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}
@@ -29,19 +31,42 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  */
 class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] {
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case FlintSparkSourceRelation(relation)
-        if !plan.isInstanceOf[V2WriteCommand] => // TODO: make sure only intercept SELECT query
-      val relationCols = collectRelationColumnsInQueryPlan(plan, relation)
+  private val supportedSourceRelations: Seq[FlintSparkSourceRelationProvider] = {
+    var relations = Seq[FlintSparkSourceRelationProvider]()
+    relations = relations :+ new FileSourceRelationProvider
 
-      // Choose the first covering index that meets all criteria above
-      findAllCoveringIndexesOnTable(relation.tableName)
-        .sortBy(_.name())
-        .collectFirst {
-          case index: FlintSparkCoveringIndex if isCoveringIndexApplicable(index, relationCols) =>
-            replaceTableRelationWithIndexRelation(index, relation)
-        }
-        .getOrElse(relation.plan) // If no index found, return the original relation
+    if (flint.spark.conf
+        .getOption("spark.sql.catalog.spark_catalog")
+        .contains("org.apache.iceberg.spark.SparkSessionCatalog")) {
+      relations = relations :+ new IcebergSourceRelationProvider
+    }
+    relations
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (plan.isInstanceOf[V2WriteCommand]) {
+      plan
+    } else {
+      plan transform { case subPlan =>
+        supportedSourceRelations
+          .collectFirst {
+            case relationProvider if relationProvider.isSupported(subPlan) =>
+              val relation = relationProvider.getRelation(subPlan)
+              val relationCols = collectRelationColumnsInQueryPlan(plan, relation)
+
+              // Choose the first covering index that meets all criteria above
+              findAllCoveringIndexesOnTable(relation.tableName)
+                .sortBy(_.name())
+                .collectFirst {
+                  case index: FlintSparkCoveringIndex
+                      if isCoveringIndexApplicable(index, relationCols) =>
+                    replaceTableRelationWithIndexRelation(index, relation)
+                }
+                .getOrElse(subPlan) // If no index found, return the original relation
+          }
+          .getOrElse(subPlan)
+      }
+    }
   }
 
   private def collectRelationColumnsInQueryPlan(
@@ -54,12 +79,13 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan]
     val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap
     plan
       .collect {
-        case _: LogicalRelation => Set.empty
+        case r: LogicalRelation if r.eq(relation.plan) => Set.empty
         case other =>
           other.expressions
             .flatMap(_.references)
-            .flatMap(ref =>
-              relationColsById.get(ref.exprId)) // Ignore attribute not belong to target relation
+            .flatMap(ref => {
+              relationColsById.get(ref.exprId)
+            }) // Ignore attributes not belonging to the current relation being rewritten
             .map(attr => attr.name)
       }
       .flatten
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelation.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelation.scala
index ea70f5a83..3e208569b 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelation.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/FlintSparkSourceRelation.scala
@@ -5,9 +5,13 @@
 
 package org.opensearch.flint.spark.source
 
+import org.opensearch.flint.spark.source.file.FileSourceRelation
+import org.opensearch.flint.spark.source.iceberg.IcebergSourceRelation
+
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 
 /**
  * This source relation abstraction allows Flint to interact uniformly with different kinds of
@@ -35,14 +39,9 @@ trait FlintSparkSourceRelation {
   def output: Seq[AttributeReference]
 }
 
-/**
- * Extractor that identifies source relation type and wrapping applicable logical plans into
- * appropriate `FlintSparkSourceRelation` instance.
- */
-object FlintSparkSourceRelation {
-  def unapply(plan: LogicalPlan): Option[FlintSparkSourceRelation] = plan match {
-    case relation @ LogicalRelation(_, _, Some(_), false) =>
-      Some(file.FlintSparkFileSourceRelation(relation))
-    case _ => None
-  }
+trait FlintSparkSourceRelationProvider {
+
+  def isSupported(plan: LogicalPlan): Boolean
+
+  def getRelation(plan: LogicalPlan): FlintSparkSourceRelation
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FlintSparkFileSourceRelation.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelation.scala
similarity index 54%
rename from flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FlintSparkFileSourceRelation.scala
rename to flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelation.scala
index 8b13ded26..2d8412de9 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FlintSparkFileSourceRelation.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/file/FileSourceRelation.scala
@@ -5,9 +5,10 @@
 
 package org.opensearch.flint.spark.source.file
 
-import org.opensearch.flint.spark.source.FlintSparkSourceRelation
+import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}
 
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 
 /**
@@ -16,7 +17,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
  * @param plan
  *   the relation plan associated with the file-based data source
  */
-case class FlintSparkFileSourceRelation(override val plan: LogicalRelation)
+case class FileSourceRelation(override val plan: LogicalRelation)
     extends FlintSparkSourceRelation {
 
   override def tableName: String =
@@ -26,3 +27,15 @@ case class FileSourceRelation(override val plan: LogicalRelation)
 
   override def output: Seq[AttributeReference] = plan.output
 }
+
+class FileSourceRelationProvider extends FlintSparkSourceRelationProvider {
+
+  override def isSupported(plan: LogicalPlan): Boolean = plan match {
+    case LogicalRelation(_, _, Some(_), false) => true
+    case _ => false
+  }
+
+  override def getRelation(plan: LogicalPlan): FlintSparkSourceRelation = {
+    FileSourceRelation(plan.asInstanceOf[LogicalRelation])
+  }
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/iceberg/IcebergSourceRelation.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/iceberg/IcebergSourceRelation.scala
new file mode 100644
index 000000000..6a8679f83
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/source/iceberg/IcebergSourceRelation.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.source.iceberg
+
+import org.apache.iceberg.spark.source.SparkTable
+import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}
+
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+/**
+ * Concrete implementation of `FlintSparkSourceRelation` for Iceberg-based data sources. This
+ * class encapsulates the handling of relations backed by Iceberg tables, which are built on top
+ * of Spark's DataSourceV2 and TableProvider interfaces.
+ *
+ * @param plan
+ *   the `DataSourceV2Relation` that represents the plan associated with the Iceberg table.
+ */
+case class IcebergSourceRelation(override val plan: DataSourceV2Relation)
+    extends FlintSparkSourceRelation {
+
+  /**
+   * Retrieves the fully qualified name of the table from the Iceberg table metadata. If the
+   * Iceberg table is not correctly referenced or the metadata is missing, an exception is thrown.
+   */
+  override def tableName: String =
+    plan.table.name() // TODO: confirm
+
+  /**
+   * Provides the output attributes of the logical plan. These attributes represent the schema of
+   * the Iceberg table as it appears in Spark's logical plan and are used to define the structure
+   * of the data returned by scans of the Iceberg table.
+   */
+  override def output: Seq[AttributeReference] = plan.output
+}
+
+class IcebergSourceRelationProvider extends FlintSparkSourceRelationProvider {
+
+  override def isSupported(plan: LogicalPlan): Boolean = plan match {
+    case DataSourceV2Relation(_: SparkTable, _, _, _, _) => true
+    case _ => false
+  }
+
+  override def getRelation(plan: LogicalPlan): FlintSparkSourceRelation = {
+    IcebergSourceRelation(plan.asInstanceOf[DataSourceV2Relation])
+  }
+}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala
index bef9118c7..f48f76080 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/ApplyFlintSparkCoveringIndexSuite.scala
@@ -32,7 +32,7 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
   private val client = mock[FlintClient](RETURNS_DEEP_STUBS)
 
   /** Mock FlintSpark which is required by the rule */
-  private val flint = mock[FlintSpark]
+  private val flint = mock[FlintSpark](RETURNS_DEEP_STUBS)
 
   /** Instantiate the rule once for all tests */
   private val rule = new ApplyFlintSparkCoveringIndex(flint)
@@ -47,6 +47,7 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
       .when(() => FlintClientBuilder.build(any(classOf[FlintOptions])))
       .thenReturn(client)
     when(flint.spark).thenReturn(spark)
+    // when(flint.spark.conf.getOption(any())).thenReturn(None)
   }
 
   override protected def afterAll(): Unit = {
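---

Note (illustrative, not part of the commit): the `FlintSparkSourceRelationProvider` trait above is the extension point this patch introduces. Supporting another table format means adding one relation wrapper plus one provider, then registering the provider in `supportedSourceRelations` behind the appropriate catalog check, as `IcebergSourceRelationProvider` does. Below is a minimal sketch for a hypothetical DataSourceV2-based connector; every `Example*` name and the `com.example.connector` package are made up for illustration.

package org.opensearch.flint.spark.source.example

import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

/**
 * Hypothetical wrapper for a relation backed by the example connector; mirrors
 * `IcebergSourceRelation` above.
 */
case class ExampleSourceRelation(override val plan: DataSourceV2Relation)
    extends FlintSparkSourceRelation {

  // Fully qualified table name as reported by the connector's Table implementation.
  override def tableName: String = plan.table.name()

  // Output attributes (schema) of the relation as seen by Spark's logical plan.
  override def output: Seq[AttributeReference] = plan.output
}

class ExampleSourceRelationProvider extends FlintSparkSourceRelationProvider {

  // Accept only DataSourceV2 relations produced by the example connector. The table class is
  // matched by name here, on the assumption that the hypothetical connector is not on Flint's
  // compile-time classpath (unlike Iceberg's SparkTable above).
  override def isSupported(plan: LogicalPlan): Boolean = plan match {
    case relation: DataSourceV2Relation =>
      relation.table.getClass.getName.startsWith("com.example.connector.")
    case _ => false
  }

  override def getRelation(plan: LogicalPlan): FlintSparkSourceRelation =
    ExampleSourceRelation(plan.asInstanceOf[DataSourceV2Relation])
}

Registration would then follow the Iceberg pattern in ApplyFlintSparkCoveringIndex, appending `new ExampleSourceRelationProvider` to `supportedSourceRelations` when the corresponding catalog implementation is configured.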