Extract source relation and provider abstraction
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Apr 30, 2024
1 parent 69086cc commit 5240d53
Showing 6 changed files with 120 additions and 29 deletions.
1 change: 1 addition & 0 deletions build.sbt
@@ -130,6 +130,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
"org.apache.iceberg" %% s"iceberg-spark-runtime-$sparkMinorVersion" % icebergVersion % "provided",
"org.scalactic" %% "scalactic" % "3.2.15" % "test",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
@@ -10,7 +10,9 @@ import java.util
import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.DELETED
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.getFlintIndexName
import org.opensearch.flint.spark.source.FlintSparkSourceRelation
import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}
import org.opensearch.flint.spark.source.file.FileSourceRelationProvider
import org.opensearch.flint.spark.source.iceberg.IcebergSourceRelationProvider

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V2WriteCommand}
@@ -29,19 +31,42 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
*/
class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case FlintSparkSourceRelation(relation)
if !plan.isInstanceOf[V2WriteCommand] => // TODO: make sure only intercept SELECT query
val relationCols = collectRelationColumnsInQueryPlan(plan, relation)
private val supportedSourceRelations: Seq[FlintSparkSourceRelationProvider] = {
var relations = Seq[FlintSparkSourceRelationProvider]()
relations = relations :+ new FileSourceRelationProvider

// Choose the first covering index that meets all criteria above
findAllCoveringIndexesOnTable(relation.tableName)
.sortBy(_.name())
.collectFirst {
case index: FlintSparkCoveringIndex if isCoveringIndexApplicable(index, relationCols) =>
replaceTableRelationWithIndexRelation(index, relation)
}
.getOrElse(relation.plan) // If no index found, return the original relation
if (flint.spark.conf
.getOption("spark.sql.catalog.spark_catalog")
.contains("org.apache.iceberg.spark.SparkSessionCatalog")) {
relations = relations :+ new IcebergSourceRelationProvider
}
relations
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (plan.isInstanceOf[V2WriteCommand]) {
plan
} else {
plan transform { case subPlan =>
supportedSourceRelations
.collectFirst {
case relationProvider if relationProvider.isSupported(subPlan) =>
val relation = relationProvider.getRelation(subPlan)
val relationCols = collectRelationColumnsInQueryPlan(plan, relation)

// Choose the first covering index that meets all criteria above
findAllCoveringIndexesOnTable(relation.tableName)
.sortBy(_.name())
.collectFirst {
case index: FlintSparkCoveringIndex
if isCoveringIndexApplicable(index, relationCols) =>
replaceTableRelationWithIndexRelation(index, relation)
}
.getOrElse(subPlan) // If no index found, return the original relation
}
.getOrElse(subPlan)
}
}
}

private def collectRelationColumnsInQueryPlan(
@@ -54,12 +79,13 @@ class ApplyFlintSparkCoveringIndex(flint: FlintSpark) extends Rule[LogicalPlan]
val relationColsById = relation.output.map(attr => (attr.exprId, attr)).toMap
plan
.collect {
case _: LogicalRelation => Set.empty
case r: LogicalRelation if r.eq(relation.plan) => Set.empty
case other =>
other.expressions
.flatMap(_.references)
.flatMap(ref =>
relationColsById.get(ref.exprId)) // Ignore attribute not belong to target relation
.flatMap(ref => {
relationColsById.get(ref.exprId)
}) // Ignore attribute not belong to current relation being rewritten
.map(attr => attr.name)
}
.flatten
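A minimal sketch of how the rewritten rule can be exercised against an analyzed plan, mirroring the test suite at the bottom of this diff; the `SparkSession` handle, the table name `t`, the package of the rule class, and obtaining a `FlintSpark` instance are assumptions for illustration only.

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.covering.ApplyFlintSparkCoveringIndex // assumed package

// Sketch only: `flint` would normally come from the Flint runtime (it is mocked in the
// test suite below); the table name "t" is a placeholder.
val flint: FlintSpark = ???
val rule = new ApplyFlintSparkCoveringIndex(flint)
val analyzed: LogicalPlan = spark.table("t").queryExecution.analyzed
val rewritten: LogicalPlan = rule.apply(analyzed) // unchanged plan if no covering index applies
```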
@@ -5,9 +5,13 @@

package org.opensearch.flint.spark.source

import org.opensearch.flint.spark.source.file.FileSourceRelation
import org.opensearch.flint.spark.source.iceberg.IcebergSourceRelation

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

/**
* This source relation abstraction allows Flint to interact uniformly with different kinds of
@@ -35,14 +39,9 @@ trait FlintSparkSourceRelation {
def output: Seq[AttributeReference]
}

/**
* Extractor that identifies source relation type and wrapping applicable logical plans into
* appropriate `FlintSparkSourceRelation` instance.
*/
object FlintSparkSourceRelation {
def unapply(plan: LogicalPlan): Option[FlintSparkSourceRelation] = plan match {
case relation @ LogicalRelation(_, _, Some(_), false) =>
Some(file.FlintSparkFileSourceRelation(relation))
case _ => None
}
trait FlintSparkSourceRelationProvider {

def isSupported(plan: LogicalPlan): Boolean

def getRelation(plan: LogicalPlan): FlintSparkSourceRelation
}
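A small sketch of how the new provider contract is meant to be consumed, mirroring the `collectFirst` pattern used by `ApplyFlintSparkCoveringIndex` above; the helper name `extractRelation` is illustrative and not part of this change.

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Probe each registered provider and wrap the first supported plan node into a
// FlintSparkSourceRelation; callers fall back to the original plan when None is returned.
def extractRelation(
    providers: Seq[FlintSparkSourceRelationProvider],
    plan: LogicalPlan): Option[FlintSparkSourceRelation] =
  providers.collectFirst {
    case provider if provider.isSupported(plan) => provider.getRelation(plan)
  }
```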
@@ -5,9 +5,10 @@

package org.opensearch.flint.spark.source.file

import org.opensearch.flint.spark.source.FlintSparkSourceRelation
import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation

/**
@@ -16,7 +17,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
* @param plan
* the relation plan associated with the file-based data source
*/
case class FlintSparkFileSourceRelation(override val plan: LogicalRelation)
case class FileSourceRelation(override val plan: LogicalRelation)
extends FlintSparkSourceRelation {

override def tableName: String =
@@ -26,3 +27,15 @@ case class FlintSparkFileSourceRelation(override val plan: LogicalRelation)

override def output: Seq[AttributeReference] = plan.output
}

class FileSourceRelationProvider extends FlintSparkSourceRelationProvider {

override def isSupported(plan: LogicalPlan): Boolean = plan match {
case LogicalRelation(_, _, Some(_), false) => true
case _ => false
}

override def getRelation(plan: LogicalPlan): FlintSparkSourceRelation = {
FileSourceRelation(plan.asInstanceOf[LogicalRelation])
}
}
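A sketch of probing an analyzed plan with the file-based provider; the table name is a placeholder, and each sub-plan node is checked (as the rule does) because the provider matches the nested `LogicalRelation` rather than the plan root.

```scala
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.source.FlintSparkSourceRelation
import org.opensearch.flint.spark.source.file.FileSourceRelationProvider

// Assumed active SparkSession; "default.my_table" stands in for any file/Hive-backed table.
val spark = SparkSession.active
val provider = new FileSourceRelationProvider
val analyzed = spark.table("default.my_table").queryExecution.analyzed

val relation: Option[FlintSparkSourceRelation] =
  analyzed.collect { case p if provider.isSupported(p) => provider.getRelation(p) }.headOption
relation.foreach(r => println(r.tableName)) // qualified name of the file-based source table
```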
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.source.iceberg

import org.apache.iceberg.spark.source.SparkTable
import org.opensearch.flint.spark.source.{FlintSparkSourceRelation, FlintSparkSourceRelationProvider}

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

/**
* Concrete implementation of `FlintSparkSourceRelation` for Iceberg-based data sources. This
* class encapsulates the handling of relations backed by Iceberg tables, which are built on top
* of Spark's DataSourceV2 and TableProvider interfaces.
*
* @param plan
* the `DataSourceV2Relation` that represents the plan associated with the Iceberg table.
*/
case class IcebergSourceRelation(override val plan: DataSourceV2Relation)
extends FlintSparkSourceRelation {

/**
* Retrieves the fully qualified name of the table from the Iceberg table metadata. If the
* Iceberg table is not correctly referenced or the metadata is missing, an exception is thrown.
*/
override def tableName: String =
plan.table.name() // TODO: confirm

/**
* Provides the output attributes of the logical plan. These attributes represent the schema of
* the Iceberg table as it appears in Spark's logical plan and are used to define the structure
* of the data returned by scans of the Iceberg table.
*/
override def output: Seq[AttributeReference] = plan.output
}

class IcebergSourceRelationProvider extends FlintSparkSourceRelationProvider {

override def isSupported(plan: LogicalPlan): Boolean = plan match {
case DataSourceV2Relation(_: SparkTable, _, _, _, _) => true
case _ => false
}

override def getRelation(plan: LogicalPlan): FlintSparkSourceRelation = {
IcebergSourceRelation(plan.asInstanceOf[DataSourceV2Relation])
}
}
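The Iceberg provider above is only registered by `ApplyFlintSparkCoveringIndex` when the Spark session catalog is Iceberg's `SparkSessionCatalog`, so a session would typically be configured along these lines; the catalog type and warehouse path are illustrative assumptions.

```scala
import org.apache.spark.sql.SparkSession

// Iceberg session catalog configuration that activates IcebergSourceRelationProvider;
// the "hadoop" catalog type and warehouse location below are placeholders.
val spark = SparkSession
  .builder()
  .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
  .config("spark.sql.catalog.spark_catalog.type", "hadoop")
  .config("spark.sql.catalog.spark_catalog.warehouse", "/tmp/iceberg-warehouse")
  .getOrCreate()
```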
@@ -32,7 +32,7 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
private val client = mock[FlintClient](RETURNS_DEEP_STUBS)

/** Mock FlintSpark which is required by the rule */
private val flint = mock[FlintSpark]
private val flint = mock[FlintSpark](RETURNS_DEEP_STUBS)

/** Instantiate the rule once for all tests */
private val rule = new ApplyFlintSparkCoveringIndex(flint)
@@ -47,6 +47,7 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
.when(() => FlintClientBuilder.build(any(classOf[FlintOptions])))
.thenReturn(client)
when(flint.spark).thenReturn(spark)
// when(flint.spark.conf.getOption(any())).thenReturn(None)
}

override protected def afterAll(): Unit = {