Generate id column for CV and MV
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Nov 14, 2024
1 parent bf60e59 commit 5e0168b
Showing 6 changed files with 136 additions and 11 deletions.
@@ -11,8 +11,10 @@ import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.core.metadata.FlintJsonHelper._

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1}
import org.apache.spark.sql.types.StructType

/**
@@ -117,6 +119,38 @@ object FlintSparkIndex {
s"${parts(0)}.${parts(1)}.`${parts.drop(2).mkString(".")}`"
}

/**
* Generate an ID column according to the following precedence:
* ```
* 1. Use the ID expression provided in the index option;
* 2. SHA-1 over all output columns if the query is aggregated;
* 3. Otherwise, generate no ID column.
* ```
*
* @param df
* the DataFrame to generate an ID column for
* @param options
* Flint index options
* @return
* the DataFrame with an ID column appended, or the original DataFrame if none is generated
*/
def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
// Assume output rows are unique when a simple query plan contains an aggregate operator
def isAggregated: Boolean = {
df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])
}

val idExpr = options.idExpression()
if (idExpr.exists(_.nonEmpty)) {
df.withColumn(ID_COLUMN, expr(idExpr.get))
} else if (isAggregated) {
val allOutputCols = df.columns.map(col)
df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))
} else {
df
}
}
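
For reference, a minimal usage sketch of the precedence above (not part of this commit; it assumes a local SparkSession `spark` with `FlintSparkIndex._` and `FlintSparkIndexOptions` in scope):

```scala
import spark.implicits._

val src = Seq(("a", 30), ("a", 40)).toDF("name", "age")

// 1. The id_expression option takes precedence over everything else.
val byOption =
  generateIdColumn(src, FlintSparkIndexOptions(Map("id_expression" -> "sha1(name)")))

// 2. No option, but the plan is aggregated: SHA-1 over all output columns.
val byAggregate =
  generateIdColumn(src.groupBy("name").count(), FlintSparkIndexOptions.empty)

// 3. Neither applies: the DataFrame is returned unchanged, without an ID column.
val unchanged = generateIdColumn(src, FlintSparkIndexOptions.empty)
```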

/**
* Populate environment variables to persist in Flint metadata.
*
@@ -10,7 +10,7 @@ import java.util.{Collections, UUID}
import org.json4s.{Formats, NoTypeHints}
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY}
import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, ID_EXPRESSION, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY}
import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode
import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser
@@ -96,6 +96,14 @@ case class FlintSparkIndexOptions(options: Map[String, String]) {
*/
def indexSettings(): Option[String] = getOptionValue(INDEX_SETTINGS)

/**
* An expression that generates a unique value used as the source data row ID.
*
* @return
* ID expression
*/
def idExpression(): Option[String] = getOptionValue(ID_EXPRESSION)
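
As an illustration only (the column names below are hypothetical), the option value is an arbitrary Spark SQL expression string that is read back through `idExpression()`:

```scala
// Hedged sketch: id_expression carries a Spark SQL expression as a plain string.
val options = FlintSparkIndexOptions(
  Map(
    "auto_refresh" -> "true",
    "id_expression" -> "sha1(concat_ws(',', time, address))"))

options.idExpression() // Some("sha1(concat_ws(',', time, address))")
```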

/**
* Extra streaming source options that can be simply passed to DataStreamReader or
* Relation.options
@@ -187,6 +195,7 @@ object FlintSparkIndexOptions {
val WATERMARK_DELAY: OptionName.Value = Value("watermark_delay")
val OUTPUT_MODE: OptionName.Value = Value("output_mode")
val INDEX_SETTINGS: OptionName.Value = Value("index_settings")
val ID_EXPRESSION: OptionName.Value = Value("id_expression")
val EXTRA_OPTIONS: OptionName.Value = Value("extra_options")
}

@@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark._
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName}
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, quotedTableName, ID_COLUMN}
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}

@@ -71,10 +71,13 @@ case class FlintSparkCoveringIndex(
val job = df.getOrElse(spark.read.table(quotedTableName(tableName)))

// Add optional filtering condition
filterCondition
.map(job.where)
.getOrElse(job)
.select(colNames.head, colNames.tail: _*)
val batchDf =
filterCondition
.map(job.where)
.getOrElse(job)
.select(colNames.head, colNames.tail: _*)

generateIdColumn(batchDf, options)
}
}

@@ -13,7 +13,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala`
import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions}
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, StreamingRefresh}
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh}
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.function.TumbleFunction
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE}
@@ -81,7 +81,8 @@ case class FlintSparkMaterializedView(
override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
require(df.isEmpty, "materialized view doesn't support reading from other data frame")

spark.sql(query)
val batchDf = spark.sql(query)
generateIdColumn(batchDf, options)
}

override def buildStream(spark: SparkSession): DataFrame = {
@@ -99,7 +100,9 @@
case relation: UnresolvedRelation if !relation.isStreaming =>
relation.copy(isStreaming = true, options = optionsWithExtra(spark, relation))
}
logicalPlanToDataFrame(spark, streamingPlan)

val streamingDf = logicalPlanToDataFrame(spark, streamingPlan)
generateIdColumn(streamingDf, options)
}

private def watermark(timeCol: Attribute, child: LogicalPlan) = {
@@ -28,19 +28,23 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
private val testTable = "spark_catalog.default.covering_sql_test"
private val testIndex = "name_and_age"
private val testFlintIndex = getFlintIndexName(testIndex, testTable)
private val testTimeSeriesTable = "spark_catalog.default.covering_sql_ts_test"
private val testFlintTimeSeriesIndex = getFlintIndexName(testIndex, testTimeSeriesTable)

override def beforeEach(): Unit = {
super.beforeEach()

createPartitionedAddressTable(testTable)
createTimeSeriesTable(testTimeSeriesTable)
}

override def afterEach(): Unit = {
super.afterEach()

// Delete all test indices
deleteTestIndex(testFlintIndex)
deleteTestIndex(testFlintIndex, testFlintTimeSeriesIndex)
sql(s"DROP TABLE $testTable")
sql(s"DROP TABLE $testTimeSeriesTable")
}

test("create covering index with auto refresh") {
@@ -86,6 +90,37 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
}
}

test("create covering index with auto refresh and ID expression") {
sql(s"""
| CREATE INDEX $testIndex ON $testTimeSeriesTable
| (time, age, address)
| WITH (
| auto_refresh = true,
| id_expression = 'address'
| )
|""".stripMargin)

val job = spark.streams.active.find(_.name == testFlintTimeSeriesIndex)
awaitStreamingComplete(job.get.id.toString)

val indexData = flint.queryIndex(testFlintTimeSeriesIndex)
indexData.count() shouldBe 3 // only 3 rows remain because some rows share the same ID
}

test("create covering index with full refresh and ID expression") {
sql(s"""
| CREATE INDEX $testIndex ON $testTimeSeriesTable
| (time, age, address)
| WITH (
| id_expression = 'address'
| )
|""".stripMargin)
sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable")

val indexData = flint.queryIndex(testFlintTimeSeriesIndex)
indexData.count() shouldBe 3 // only 3 rows remain because some rows share the same ID
}

test("create covering index with index settings") {
sql(s"""
| CREATE INDEX $testIndex ON $testTable ( name )
@@ -154,6 +154,47 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite {
}
}

test("create materialized view with auto refresh and ID expression") {
withTempDir { checkpointDir =>
sql(s"""
| CREATE MATERIALIZED VIEW $testMvName
| AS $testQuery
| WITH (
| auto_refresh = true,
| checkpoint_location = '${checkpointDir.getAbsolutePath}',
| watermark_delay = '1 Second',
| id_expression = 'count'
| )
|""".stripMargin)

// Wait for the streaming job to complete the current micro batch
val job = spark.streams.active.find(_.name == testFlintIndex)
job shouldBe defined
failAfter(streamingTimeout) {
job.get.processAllAvailable()
}

// 1 row is intentionally missing due to ID conflict
flint.queryIndex(testFlintIndex).count() shouldBe 2
}
}

test("create materialized view with full refresh and ID expression") {
sql(s"""
| CREATE MATERIALIZED VIEW $testMvName
| AS $testQuery
| WITH (
| id_expression = 'count'
| )
|""".stripMargin)

sql(s"REFRESH MATERIALIZED VIEW $testMvName")

// 2 rows are intentionally missing due to ID conflict
val indexData = spark.read.format(FLINT_DATASOURCE).load(testFlintIndex)
indexData.count() shouldBe 2
}

test("create materialized view with index settings") {
sql(s"""
| CREATE MATERIALIZED VIEW $testMvName
