diff --git a/docs/index.md b/docs/index.md index abc801bde..684ba7da6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -394,6 +394,7 @@ User can provide the following options in `WITH` clause of create statement: + `watermark_delay`: a string as time expression for how late data can come and still be processed, e.g. 1 minute, 10 seconds. This is required by auto and incremental refresh on materialized view if it has aggregation in the query. + `output_mode`: a mode string that describes how data will be written to streaming sink. If unspecified, default append mode will be applied. + `index_settings`: a JSON string as index settings for OpenSearch index that will be created. Please follow the format in OpenSearch documentation. If unspecified, default OpenSearch index settings will be applied. ++ `id_expression`: an expression string that generates an ID column to guarantee idempotency when the index refresh job restarts or retries during an index refresh. If an empty string is provided, no ID column will be generated. + `extra_options`: a JSON string as extra options that can be passed to Spark streaming source and sink API directly. Use qualified source table name (because there could be multiple) and "sink", e.g. '{"sink": "{key: val}", "table1": {key: val}}' Note that the index option name is case-sensitive. Here is an example: @@ -406,6 +407,7 @@ WITH ( watermark_delay = '1 Second', output_mode = 'complete', index_settings = '{"number_of_shards": 2, "number_of_replicas": 3}', + id_expression = "sha1(concat_ws('\0',startTime,status))", extra_options = '{"spark_catalog.default.alb_logs": {"maxFilesPerTrigger": "1"}}' ) ``` diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 44ea5188f..300233777 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -11,8 +11,10 @@ import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.FlintJsonHelper._ +import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.flint.datatype.FlintDataType +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.types.StructType /** @@ -62,7 +64,7 @@ trait FlintSparkIndex { def build(spark: SparkSession, df: Option[DataFrame]): DataFrame } -object FlintSparkIndex { +object FlintSparkIndex extends Logging { /** * Interface indicates a Flint index has custom streaming refresh capability other than foreach @@ -117,6 +119,25 @@ object FlintSparkIndex { s"${parts(0)}.${parts(1)}.`${parts.drop(2).mkString(".")}`" } + /** + * Generate an ID column using the ID expression provided in the index options. + * + * @param df + * the DataFrame to generate the ID column for + * @param options + * Flint index options + * @return + * DataFrame with an ID column if an ID expression is provided, otherwise the original DataFrame + */ + def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = { + options.idExpression() match { + case Some(idExpr) if idExpr.nonEmpty => + logInfo(s"Using user-provided ID expression: $idExpr") + df.withColumn(ID_COLUMN, expr(idExpr)) + case _ => df + } + } + /** + * Populate environment variables to persist in Flint metadata. 
* diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala index 9b58a696c..1ad88de6d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala @@ -10,7 +10,7 @@ import java.util.{Collections, UUID} import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization -import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY} +import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, ID_EXPRESSION, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY} import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser @@ -96,6 +96,14 @@ case class FlintSparkIndexOptions(options: Map[String, String]) { */ def indexSettings(): Option[String] = getOptionValue(INDEX_SETTINGS) + /** + * An expression that generates a unique value as the index data row ID. + * + * @return + * ID expression + */ + def idExpression(): Option[String] = getOptionValue(ID_EXPRESSION) + /** * Extra streaming source options that can be simply passed to DataStreamReader or * Relation.options @@ -187,6 +195,7 @@ object FlintSparkIndexOptions { val WATERMARK_DELAY: OptionName.Value = Value("watermark_delay") val OUTPUT_MODE: OptionName.Value = Value("output_mode") val INDEX_SETTINGS: OptionName.Value = Value("index_settings") + val ID_EXPRESSION: OptionName.Value = Value("id_expression") val EXTRA_OPTIONS: OptionName.Value = Value("extra_options") } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index 8748bf874..901c3006c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark._ -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName} +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} @@ -71,10 +71,13 @@ case class FlintSparkCoveringIndex( val job = df.getOrElse(spark.read.table(quotedTableName(tableName))) // Add optional filtering condition - filterCondition - .map(job.where) - .getOrElse(job) - .select(colNames.head, colNames.tail: _*) + 
val batchDf = + filterCondition + .map(job.where) + .getOrElse(job) + .select(colNames.head, colNames.tail: _*) + + addIdColumn(batchDf, options) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index d5c450e7e..e3b09661a 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -13,7 +13,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala` import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, StreamingRefresh} +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.function.TumbleFunction import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} @@ -81,7 +81,8 @@ case class FlintSparkMaterializedView( override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { require(df.isEmpty, "materialized view doesn't support reading from other data frame") - spark.sql(query) + val batchDf = spark.sql(query) + addIdColumn(batchDf, options) } override def buildStream(spark: SparkSession): DataFrame = { @@ -99,7 +100,9 @@ case class FlintSparkMaterializedView( case relation: UnresolvedRelation if !relation.isStreaming => relation.copy(isStreaming = true, options = optionsWithExtra(spark, relation)) } - logicalPlanToDataFrame(spark, streamingPlan) + + val streamingDf = logicalPlanToDataFrame(spark, streamingPlan) + addIdColumn(streamingDf, options) } private def watermark(timeCol: Attribute, child: LogicalPlan) = { diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index 1d301087f..78debda35 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -6,9 +6,12 @@ package org.apache.spark import org.opensearch.flint.spark.FlintSparkExtensions +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.{Alias, CodegenObjectFactoryMode, Expression} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf} import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE} import org.apache.spark.sql.internal.SQLConf @@ -68,4 +71,27 @@ trait FlintSuite extends SharedSparkSession { setFlintSparkConf(METADATA_CACHE_WRITE, "false") } } + + /** + * Implicit class to extend DataFrame functionality with additional utilities. 
+ * + * @param df + * the DataFrame to which the additional methods are added + */ + protected implicit class DataFrameExtensions(val df: DataFrame) { + + /** + * Retrieves the ID column expression from the logical plan of the DataFrame, if it exists. + * + * @return + * an `Option` containing the `Expression` for the ID column if present, or `None` otherwise + */ + def idColumn(): Option[Expression] = { + df.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala index d7de6d29b..f752ae68a 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala @@ -6,7 +6,6 @@ package org.opensearch.flint.spark import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._ -import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite @@ -22,6 +21,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { WATERMARK_DELAY.toString shouldBe "watermark_delay" OUTPUT_MODE.toString shouldBe "output_mode" INDEX_SETTINGS.toString shouldBe "index_settings" + ID_EXPRESSION.toString shouldBe "id_expression" EXTRA_OPTIONS.toString shouldBe "extra_options" } @@ -36,6 +36,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { "watermark_delay" -> "30 Seconds", "output_mode" -> "complete", "index_settings" -> """{"number_of_shards": 3}""", + "id_expression" -> """sha1(col("timestamp"))""", "extra_options" -> """ { | "alb_logs": { @@ -55,6 +56,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe Some("30 Seconds") options.outputMode() shouldBe Some("complete") options.indexSettings() shouldBe Some("""{"number_of_shards": 3}""") + options.idExpression() shouldBe Some("""sha1(col("timestamp"))""") options.extraSourceOptions("alb_logs") shouldBe Map("opt1" -> "val1") options.extraSinkOptions() shouldBe Map("opt2" -> "val2", "opt3" -> "val3") } @@ -83,6 +85,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe empty options.outputMode() shouldBe empty options.indexSettings() shouldBe empty + options.idExpression() shouldBe empty options.extraSourceOptions("alb_logs") shouldBe empty options.extraSinkOptions() shouldBe empty options.optionsWithDefault should contain("auto_refresh" -> "false") diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala new file mode 100644 index 000000000..8ec4bec40 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, ID_COLUMN} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.{QueryTest, Row} +import 
org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson} +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String + +class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { + + test("should add ID column if ID expression is provided") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10")) + + val resultDf = addIdColumn(df, options) + resultDf.idColumn() shouldBe Some(Add(UnresolvedAttribute("id"), Literal(10))) + checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12))) + } + + test("should not add ID column if ID expression is not provided") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = FlintSparkIndexOptions.empty + + val resultDf = addIdColumn(df, options) + resultDf.columns should not contain ID_COLUMN + } + + test("should not add ID column if ID expression is empty") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = FlintSparkIndexOptions(Map("id_expression" -> "")) + + val resultDf = addIdColumn(df, options) + resultDf.columns should not contain ID_COLUMN + } + + test("should generate ID column for various column types") { + val schema = StructType.fromDDL(""" + boolean_col BOOLEAN, + string_col STRING, + long_col LONG, + int_col INT, + double_col DOUBLE, + float_col FLOAT, + timestamp_col TIMESTAMP, + date_col DATE, + struct_col STRUCT<subfield1: STRING, subfield2: INT> + """) + val data = Seq( + Row( + true, + "Alice", + 100L, + 10, + 10.5, + 3.14f, + java.sql.Timestamp.valueOf("2024-01-01 10:00:00"), + java.sql.Date.valueOf("2024-01-01"), + Row("sub1", 1))) + + val aggregatedDf = spark + .createDataFrame(sparkContext.parallelize(data), schema) + .groupBy( + "boolean_col", + "string_col", + "long_col", + "int_col", + "double_col", + "float_col", + "timestamp_col", + "date_col", + "struct_col", + "struct_col.subfield2") + .count() + val options = FlintSparkIndexOptions(Map("id_expression" -> + "sha1(concat_ws('\0',boolean_col,string_col,long_col,int_col,double_col,float_col,timestamp_col,date_col,to_json(struct_col),struct_col.subfield2))")) + + val resultDf = addIdColumn(aggregatedDf, options) + resultDf.idColumn() shouldBe Some( + UnresolvedFunction( + "sha1", + Seq(UnresolvedFunction( + "concat_ws", + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("boolean_col")), + UnresolvedAttribute(Seq("string_col")), + UnresolvedAttribute(Seq("long_col")), + UnresolvedAttribute(Seq("int_col")), + UnresolvedAttribute(Seq("double_col")), + UnresolvedAttribute(Seq("float_col")), + UnresolvedAttribute(Seq("timestamp_col")), + UnresolvedAttribute(Seq("date_col")), + UnresolvedFunction( + "to_json", + Seq(UnresolvedAttribute(Seq("struct_col"))), + isDistinct = false), + UnresolvedAttribute(Seq("struct_col", "subfield2"))), + isDistinct = false)), + isDistinct = false)) + resultDf.select(ID_COLUMN).distinct().count() shouldBe 1 + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index 1cce47d1a..d23ad875e 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala 
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -5,13 +5,20 @@ package org.opensearch.flint.spark.covering +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.opensearch.flint.spark.FlintSparkIndexOptions import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.Project class FlintSparkCoveringIndexSuite extends FlintSuite { + private val testTable = "spark_catalog.default.ci_test" + test("get covering index name") { val index = new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string")) @@ -54,4 +61,34 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { new FlintSparkCoveringIndex("ci", "default.test", Map.empty) } } + + test("build batch with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON") + val index = + FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = FlintSparkIndexOptions(Map("id_expression" -> "name"))) + + val batchDf = index.build(spark, None) + batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("name"))) + } + } + + test("build stream with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) + + val streamDf = index.build(spark, Some(spark.table(testTable))) + streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("name"))) + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 78d2eb09e..838eddf21 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -15,7 +15,7 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.Attribute @@ -36,6 +36,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testMvName = "spark_catalog.default.mv" val testQuery = "SELECT 1" + override def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + } + + override def afterAll(): Unit = { + sql(s"DROP TABLE $testTable") + super.afterAll() + } + test("get mv name") { val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe 
"flint_spark_catalog_default_mv" @@ -174,19 +184,59 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build stream should fail if there is aggregation but no windowing function") { - val testTable = "mv_build_test" - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), + Map.empty) - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", - Array(testTable), - Map.empty) + the[IllegalStateException] thrownBy + mv.buildStream(spark) + } - the[IllegalStateException] thrownBy - mv.buildStream(spark) - } + test("build batch with ID expression option") { + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> "time"))) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) + } + + test("build batch should not have ID column if not provided") { + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe None + } + + test("build stream with ID expression option") { + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "time"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) + } + + test("build stream should not have ID column if not provided") { + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe None } private def withAggregateMaterializedView( @@ -194,19 +244,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { sourceTables: Array[String], options: Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView( - testMvName, - query, - sourceTables, - Map.empty, - FlintSparkIndexOptions(options)) - - val actualPlan = mv.buildStream(spark).queryExecution.logical - codeBlock(actualPlan) - } + val mv = + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) + + val actualPlan = mv.buildStream(spark).queryExecution.logical + codeBlock(actualPlan) } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index aac06a2c1..0791f9b7a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -28,19 +28,23 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { private val testTable = "spark_catalog.default.covering_sql_test" private val 
testIndex = "name_and_age" private val testFlintIndex = getFlintIndexName(testIndex, testTable) + private val testTimeSeriesTable = "spark_catalog.default.covering_sql_ts_test" + private val testFlintTimeSeriesIndex = getFlintIndexName(testIndex, testTimeSeriesTable) override def beforeEach(): Unit = { super.beforeEach() createPartitionedAddressTable(testTable) + createTimeSeriesTable(testTimeSeriesTable) } override def afterEach(): Unit = { super.afterEach() // Delete all test indices - deleteTestIndex(testFlintIndex) + deleteTestIndex(testFlintIndex, testFlintTimeSeriesIndex) sql(s"DROP TABLE $testTable") + sql(s"DROP TABLE $testTimeSeriesTable") } test("create covering index with auto refresh") { @@ -86,6 +90,41 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { } } + test("create covering index with auto refresh and ID expression") { + sql(s""" + | CREATE INDEX $testIndex ON $testTimeSeriesTable + | (time, age, address) + | WITH ( + | auto_refresh = true, + | id_expression = 'address' + | ) + |""".stripMargin) + + val job = spark.streams.active.find(_.name == testFlintTimeSeriesIndex) + awaitStreamingComplete(job.get.id.toString) + + val indexData = flint.queryIndex(testFlintTimeSeriesIndex) + indexData.count() shouldBe 3 // only 3 rows left due to same ID + } + + test("create covering index with full refresh and ID expression") { + sql(s""" + | CREATE INDEX $testIndex ON $testTimeSeriesTable + | (time, age, address) + | WITH ( + | id_expression = 'address' + | ) + |""".stripMargin) + sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable") + + val indexData = flint.queryIndex(testFlintTimeSeriesIndex) + indexData.count() shouldBe 3 // only 3 rows left due to same ID + + // Rerun should not generate duplicate data + sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable") + indexData.count() shouldBe 3 + } + test("create covering index with index settings") { sql(s""" | CREATE INDEX $testIndex ON $testTable ( name ) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index bf5e6309e..7dcd83897 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -154,6 +154,50 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { } } + test("create materialized view with auto refresh and ID expression") { + withTempDir { checkpointDir => + sql(s""" + | CREATE MATERIALIZED VIEW $testMvName + | AS $testQuery + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}', + | watermark_delay = '1 Second', + | id_expression = "sha1(concat_ws('\0',startTime))" + | ) + |""".stripMargin) + + // Wait for streaming job complete current micro batch + val job = spark.streams.active.find(_.name == testFlintIndex) + job shouldBe defined + failAfter(streamingTimeout) { + job.get.processAllAvailable() + } + + flint.queryIndex(testFlintIndex).count() shouldBe 3 + } + } + + test("create materialized view with full refresh and ID expression") { + sql(s""" + | CREATE MATERIALIZED VIEW $testMvName + | AS $testQuery + | WITH ( + | id_expression = 'count' + | ) + |""".stripMargin) + + sql(s"REFRESH MATERIALIZED VIEW $testMvName") + + // 2 rows missing due to ID conflict intentionally + val indexData = 
flint.queryIndex(testFlintIndex) + indexData.count() shouldBe 2 + + // Rerun should not generate duplicate data + sql(s"REFRESH MATERIALIZED VIEW $testMvName") + indexData.count() shouldBe 2 + } + test("create materialized view with index settings") { sql(s""" | CREATE MATERIALIZED VIEW $testMvName