From 0771827b0981989f711f6462ce72d20d4131a086 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 14 Nov 2024 15:25:30 -0800 Subject: [PATCH 01/10] Generate id column for CV and MV Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndex.scala | 36 +++++++++++++++- .../flint/spark/FlintSparkIndexOptions.scala | 11 ++++- .../covering/FlintSparkCoveringIndex.scala | 13 +++--- .../spark/mv/FlintSparkMaterializedView.scala | 9 ++-- .../FlintSparkCoveringIndexSqlITSuite.scala | 37 ++++++++++++++++- ...FlintSparkMaterializedViewSqlITSuite.scala | 41 +++++++++++++++++++ 6 files changed, 136 insertions(+), 11 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 44ea5188f..9e200a38d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -11,8 +11,10 @@ import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.FlintJsonHelper._ -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.flint.datatype.FlintDataType +import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1} import org.apache.spark.sql.types.StructType /** @@ -117,6 +119,38 @@ object FlintSparkIndex { s"${parts(0)}.${parts(1)}.`${parts.drop(2).mkString(".")}`" } + /** + * Generate an ID column in the precedence below: + * ``` + * 1. Use ID expression provided in the index option; + * 2. SHA-1 based on all output columns if aggregated; + * 3. Otherwise, no ID column generated. + * ``` + * + * @param df + * which DataFrame to generate ID column + * @param options + * Flint index options + * @return + * DataFrame with/without ID column + */ + def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = { + // Assume output rows must be unique if a simple query plan has aggregate operator + def isAggregated: Boolean = { + df.queryExecution.logical.exists(_.isInstanceOf[Aggregate]) + } + + val idExpr = options.idExpression() + if (idExpr.exists(_.nonEmpty)) { + df.withColumn(ID_COLUMN, expr(idExpr.get)) + } else if (isAggregated) { + val allOutputCols = df.columns.map(col) + df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) + } else { + df + } + } + /** * Populate environment variables to persist in Flint metadata. 
* diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala index 9b58a696c..b866463c8 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala @@ -10,7 +10,7 @@ import java.util.{Collections, UUID} import org.json4s.{Formats, NoTypeHints} import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization -import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY} +import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, ID_EXPRESSION, INCREMENTAL_REFRESH, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, SCHEDULER_MODE, WATERMARK_DELAY} import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser @@ -96,6 +96,14 @@ case class FlintSparkIndexOptions(options: Map[String, String]) { */ def indexSettings(): Option[String] = getOptionValue(INDEX_SETTINGS) + /** + * An expression that generates unique value as source data row ID. + * + * @return + * ID expression + */ + def idExpression(): Option[String] = getOptionValue(ID_EXPRESSION) + /** * Extra streaming source options that can be simply passed to DataStreamReader or * Relation.options @@ -187,6 +195,7 @@ object FlintSparkIndexOptions { val WATERMARK_DELAY: OptionName.Value = Value("watermark_delay") val OUTPUT_MODE: OptionName.Value = Value("output_mode") val INDEX_SETTINGS: OptionName.Value = Value("index_settings") + val ID_EXPRESSION: OptionName.Value = Value("id_expression") val EXTRA_OPTIONS: OptionName.Value = Value("extra_options") } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index 8748bf874..66cb59869 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark._ -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName} +import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, quotedTableName, ID_COLUMN} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} @@ -71,10 +71,13 @@ case class FlintSparkCoveringIndex( val job = df.getOrElse(spark.read.table(quotedTableName(tableName))) // Add optional filtering condition - filterCondition - .map(job.where) - .getOrElse(job) - .select(colNames.head, 
colNames.tail: _*) + val batchDf = + filterCondition + .map(job.where) + .getOrElse(job) + .select(colNames.head, colNames.tail: _*) + + generateIdColumn(batchDf, options) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index d5c450e7e..3892ccb82 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -13,7 +13,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala` import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchema, metadataBuilder, StreamingRefresh} +import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.function.TumbleFunction import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} @@ -81,7 +81,8 @@ case class FlintSparkMaterializedView( override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { require(df.isEmpty, "materialized view doesn't support reading from other data frame") - spark.sql(query) + val batchDf = spark.sql(query) + generateIdColumn(batchDf, options) } override def buildStream(spark: SparkSession): DataFrame = { @@ -99,7 +100,9 @@ case class FlintSparkMaterializedView( case relation: UnresolvedRelation if !relation.isStreaming => relation.copy(isStreaming = true, options = optionsWithExtra(spark, relation)) } - logicalPlanToDataFrame(spark, streamingPlan) + + val streamingDf = logicalPlanToDataFrame(spark, streamingPlan) + generateIdColumn(streamingDf, options) } private def watermark(timeCol: Attribute, child: LogicalPlan) = { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index aac06a2c1..9acef8c25 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -28,19 +28,23 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { private val testTable = "spark_catalog.default.covering_sql_test" private val testIndex = "name_and_age" private val testFlintIndex = getFlintIndexName(testIndex, testTable) + private val testTimeSeriesTable = "spark_catalog.default.covering_sql_ts_test" + private val testFlintTimeSeriesIndex = getFlintIndexName(testIndex, testTimeSeriesTable) override def beforeEach(): Unit = { super.beforeEach() createPartitionedAddressTable(testTable) + createTimeSeriesTable(testTimeSeriesTable) } override def afterEach(): Unit = { super.afterEach() // Delete all test indices - deleteTestIndex(testFlintIndex) + deleteTestIndex(testFlintIndex, testFlintTimeSeriesIndex) sql(s"DROP TABLE $testTable") + sql(s"DROP TABLE $testTimeSeriesTable") } 
test("create covering index with auto refresh") { @@ -86,6 +90,37 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { } } + test("create covering index with auto refresh and ID expression") { + sql(s""" + | CREATE INDEX $testIndex ON $testTimeSeriesTable + | (time, age, address) + | WITH ( + | auto_refresh = true, + | id_expression = 'address' + | ) + |""".stripMargin) + + val job = spark.streams.active.find(_.name == testFlintTimeSeriesIndex) + awaitStreamingComplete(job.get.id.toString) + + val indexData = flint.queryIndex(testFlintTimeSeriesIndex) + indexData.count() shouldBe 3 // only 3 rows left due to same ID + } + + test("create covering index with full refresh and ID expression") { + sql(s""" + | CREATE INDEX $testIndex ON $testTimeSeriesTable + | (time, age, address) + | WITH ( + | id_expression = 'address' + | ) + |""".stripMargin) + sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable") + + val indexData = flint.queryIndex(testFlintTimeSeriesIndex) + indexData.count() shouldBe 3 // only 3 rows left due to same ID + } + test("create covering index with index settings") { sql(s""" | CREATE INDEX $testIndex ON $testTable ( name ) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index bf5e6309e..8fb220a77 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -154,6 +154,47 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { } } + test("create materialized view with auto refresh and ID expression") { + withTempDir { checkpointDir => + sql(s""" + | CREATE MATERIALIZED VIEW $testMvName + | AS $testQuery + | WITH ( + | auto_refresh = true, + | checkpoint_location = '${checkpointDir.getAbsolutePath}', + | watermark_delay = '1 Second', + | id_expression = 'count' + | ) + |""".stripMargin) + + // Wait for streaming job complete current micro batch + val job = spark.streams.active.find(_.name == testFlintIndex) + job shouldBe defined + failAfter(streamingTimeout) { + job.get.processAllAvailable() + } + + // 1 row missing due to ID conflict intentionally + flint.queryIndex(testFlintIndex).count() shouldBe 2 + } + } + + test("create materialized view with full refresh and ID expression") { + sql(s""" + | CREATE MATERIALIZED VIEW $testMvName + | AS $testQuery + | WITH ( + | id_expression = 'count' + | ) + |""".stripMargin) + + sql(s"REFRESH MATERIALIZED VIEW $testMvName") + + // 2 rows missing due to ID conflict intentionally + val indexData = spark.read.format(FLINT_DATASOURCE).load(testFlintIndex) + indexData.count() shouldBe 2 + } + test("create materialized view with index settings") { sql(s""" | CREATE MATERIALIZED VIEW $testMvName From e3155d568a18d681767af5a32757e14179444ea9 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 15 Nov 2024 09:39:09 -0800 Subject: [PATCH 02/10] Add UT for CV and MV Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndex.scala | 20 +-- .../FlintSparkCoveringIndexSuite.scala | 86 ++++++++++ .../mv/FlintSparkMaterializedViewSuite.scala | 149 +++++++++++++++++- 3 files changed, 240 insertions(+), 15 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala 
b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 9e200a38d..e2a60f319 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -136,18 +136,18 @@ object FlintSparkIndex { */ def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = { // Assume output rows must be unique if a simple query plan has aggregate operator - def isAggregated: Boolean = { + def isAggregated: Boolean = df.queryExecution.logical.exists(_.isInstanceOf[Aggregate]) - } - val idExpr = options.idExpression() - if (idExpr.exists(_.nonEmpty)) { - df.withColumn(ID_COLUMN, expr(idExpr.get)) - } else if (isAggregated) { - val allOutputCols = df.columns.map(col) - df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) - } else { - df + options.idExpression() match { + case Some(idExpr) if idExpr.nonEmpty => + df.withColumn(ID_COLUMN, expr(idExpr)) + + case None if isAggregated => + val allOutputCols = df.columns.map(col) + df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) + + case _ => df } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index 1cce47d1a..e032ac122 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -5,13 +5,19 @@ package org.opensearch.flint.spark.covering +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.opensearch.flint.spark.FlintSparkIndexOptions import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, expr} class FlintSparkCoveringIndexSuite extends FlintSuite { + private val testTable = "spark_catalog.default.ci_test" + test("get covering index name") { val index = new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string")) @@ -54,4 +60,84 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { new FlintSparkCoveringIndex("ci", "default.test", Map.empty) } } + + test("build batch with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON") + val index = + FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = FlintSparkIndexOptions(Map("id_expression" -> "name"))) + + comparePlans( + index.build(spark, None).queryExecution.logical, + spark + .table(testTable) + .select(col("name")) + .withColumn(ID_COLUMN, expr("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build batch should not have ID column without ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string")) + + comparePlans( + index.build(spark, None).queryExecution.logical, + spark + .table(testTable) + .select(col("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build stream with ID 
expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) + + comparePlans( + index.build(spark, Some(spark.table(testTable))).queryExecution.logical, + spark + .table(testTable) + .select("name") + .withColumn(ID_COLUMN, col("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build stream should not have ID column without ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + + comparePlans( + index.build(spark, Some(spark.table(testTable))).queryExecution.logical, + spark + .table(testTable) + .select(col("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 78d2eb09e..4cc06a1b6 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark.mv import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter} +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} @@ -15,12 +16,14 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Literal, Sha1} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1} +import org.apache.spark.sql.types.StringType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -107,7 +110,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { | FROM $testTable | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val options = Map("watermark_delay" -> "30 Seconds") + val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "") withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( @@ -132,7 
+135,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { | WHERE age > 30 | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val options = Map("watermark_delay" -> "30 Seconds") + val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "") withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( @@ -189,6 +192,142 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } } + test("build batch with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> "time"))) + + comparePlans( + mv.build(spark, None).queryExecution.logical, + spark + .sql(testMvQuery) + .withColumn(ID_COLUMN, expr("time")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build batch should not have ID column if non-aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) + + comparePlans( + mv.build(spark, None).queryExecution.logical, + spark.sql(testMvQuery).queryExecution.logical, + checkAnalysis = false) + } + } + + test("build batch should have ID column if aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT time, name, AVG(age) + | FROM $testTable + | GROUP BY time, name""".stripMargin, + Array.empty, + Map.empty) + + comparePlans( + mv.build(spark, None).queryExecution.logical, + spark + .table(testTable) + .groupBy("time", "name") + .avg("age") + .withColumn(ID_COLUMN, sha1(concat_ws("\0", col("time"), col("name"), col("avg(age)")))) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build stream with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) + + mv.buildStream(spark).queryExecution.logical.exists { + case Project(projectList, _) => + projectList.exists { + case Alias(UnresolvedAttribute(Seq("name")), ID_COLUMN) => true + case _ => false + } + case _ => false + } shouldBe true + } + } + + test("build stream should not have ID column if non-aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + + mv.buildStream(spark).queryExecution.logical.exists { + case Project(projectList, _) => + projectList.forall(_.name != ID_COLUMN) + case _ => false + } shouldBe true + } + } + + test("build stream should have ID column if aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val testMvQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable 
+ | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) + + mv.buildStream(spark).queryExecution.logical.exists { + case Project(projectList, _) => + val asciiNull = UTF8String.fromString("\0") + projectList.exists { + case Alias( + Sha1( + ConcatWs( + Seq( + Literal(`asciiNull`, StringType), + UnresolvedAttribute(Seq("startTime")), + UnresolvedAttribute(Seq("count"))))), + ID_COLUMN) => + true + case _ => false + } + case _ => false + } shouldBe true + } + } + private def withAggregateMaterializedView( query: String, sourceTables: Array[String], From 042b30e61fd53d32f881dd0fbd990351a71895f5 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 20 Nov 2024 11:10:51 -0800 Subject: [PATCH 03/10] Update with doc and UT Signed-off-by: Chen Dai --- docs/index.md | 2 + .../flint/spark/FlintSparkIndexSuite.scala | 104 ++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala diff --git a/docs/index.md b/docs/index.md index abc801bde..1c02f0cb0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -394,6 +394,7 @@ User can provide the following options in `WITH` clause of create statement: + `watermark_delay`: a string as time expression for how late data can come and still be processed, e.g. 1 minute, 10 seconds. This is required by auto and incremental refresh on materialized view if it has aggregation in the query. + `output_mode`: a mode string that describes how data will be written to streaming sink. If unspecified, default append mode will be applied. + `index_settings`: a JSON string as index settings for OpenSearch index that will be created. Please follow the format in OpenSearch documentation. If unspecified, default OpenSearch index settings will be applied. ++ `id_expression`: an expression string that generates an ID column to avoid duplicate data when index refresh job restart or any retry attempt during an index refresh. If an empty string is provided, no ID column will be generated. + `extra_options`: a JSON string as extra options that can be passed to Spark streaming source and sink API directly. Use qualified source table name (because there could be multiple) and "sink", e.g. '{"sink": "{key: val}", "table1": {key: val}}' Note that the index option name is case-sensitive. 
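+As a rule of thumb, `id_expression` should evaluate to a value that is unique per source row, for example an existing unique key column (`id_expression = 'request_id'`, assuming such a column exists in the source table) or a deterministic hash of several columns such as `sha1(concat(client_ip, cast(request_timestamp as string)))`. When the option is not set and the materialized view query contains an aggregation, a SHA-1 digest of all output columns is generated as the ID automatically.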
Here is an example: @@ -406,6 +407,7 @@ WITH ( watermark_delay = '1 Second', output_mode = 'complete', index_settings = '{"number_of_shards": 2, "number_of_replicas": 3}', + id_expression = 'uuid()', extra_options = '{"spark_catalog.default.alb_logs": {"maxFilesPerTrigger": "1"}}' ) ``` diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala new file mode 100644 index 000000000..1ab0f8e02 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.opensearch.flint.spark.FlintSparkIndex.{generateIdColumn, ID_COLUMN} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.types.StructType + +class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { + + test("should generate ID column if ID expression is provided") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10")) + + val resultDf = generateIdColumn(df, options) + checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12))) + } + + test("should not generate ID column if ID expression is empty") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = FlintSparkIndexOptions.empty + + val resultDf = generateIdColumn(df, options) + resultDf.columns should not contain ID_COLUMN + } + + test("should generate ID column for aggregated query") { + val df = spark + .createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice"))) + .toDF("id", "name") + .groupBy("name") + .count() + val options = FlintSparkIndexOptions.empty + + val resultDf = generateIdColumn(df, options) + resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 + } + + test("should not generate ID column for aggregated query if ID expression is empty") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = FlintSparkIndexOptions.empty + + val resultDf = generateIdColumn(df, options) + resultDf.columns should not contain ID_COLUMN + } + + test("should not generate ID column if ID expression is not provided") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = FlintSparkIndexOptions.empty + + val resultDf = generateIdColumn(df, options) + resultDf.columns should not contain ID_COLUMN + } + + test("should generate ID column for aggregated query with multiple columns") { + val schema = StructType.fromDDL(""" + boolean_col BOOLEAN, + string_col STRING, + long_col LONG, + int_col INT, + double_col DOUBLE, + float_col FLOAT, + timestamp_col TIMESTAMP, + date_col DATE, + struct_col STRUCT + """) + val data = Seq( + Row( + true, + "Alice", + 100L, + 10, + 10.5, + 3.14f, + java.sql.Timestamp.valueOf("2024-01-01 10:00:00"), + java.sql.Date.valueOf("2024-01-01"), + Row("sub1", 1))) + + val aggregatedDf = spark + .createDataFrame(sparkContext.parallelize(data), schema) + .groupBy( + "boolean_col", + "string_col", + "long_col", + "int_col", + "double_col", + "float_col", + "timestamp_col", + "date_col", + "struct_col.subfield1", + "struct_col.subfield2") + .count() + + val options 
= FlintSparkIndexOptions.empty + val resultDf = generateIdColumn(aggregatedDf, options) + resultDf.select(ID_COLUMN).distinct().count() shouldBe 1 + } +} From 32169e2302d6c6d2c51fc4b538d5a37e26b65e0a Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 20 Nov 2024 15:10:16 -0800 Subject: [PATCH 04/10] Handle struct type in tumble function Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndex.scala | 19 +++++++++++--- .../flint/spark/FlintSparkIndexSuite.scala | 25 +++++++++++++++++-- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index e2a60f319..130be61fc 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -12,10 +12,10 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.FlintJsonHelper._ import org.apache.spark.sql.{Column, DataFrame, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} import org.apache.spark.sql.flint.datatype.FlintDataType -import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json} +import org.apache.spark.sql.types.{MapType, StructType} /** * Flint index interface in Spark. @@ -144,7 +144,18 @@ object FlintSparkIndex { df.withColumn(ID_COLUMN, expr(idExpr)) case None if isAggregated => - val allOutputCols = df.columns.map(col) + // Since concat doesn't support struct or map type, convert these to json which is more + // deterministic than casting to string, as its format may vary across Spark versions. 
+ val allOutputCols = df.schema.fields.map { field => + field.dataType match { + case _: StructType | _: MapType => + to_json(col(field.name)) + case _ => + col(field.name) + } + } + + // TODO: use only grouping columns df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) case _ => df diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala index 1ab0f8e02..e67818532 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -5,6 +5,8 @@ package org.opensearch.flint.spark +import java.sql.Timestamp + import org.opensearch.flint.spark.FlintSparkIndex.{generateIdColumn, ID_COLUMN} import org.scalatest.matchers.should.Matchers @@ -42,6 +44,25 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 } + test("should generate ID column for aggregated query using tumble function") { + val df = spark + .createDataFrame( + Seq( + (Timestamp.valueOf("2023-01-01 00:00:00"), 1, "Alice"), + (Timestamp.valueOf("2023-01-01 00:10:00"), 2, "Bob"), + (Timestamp.valueOf("2023-01-01 00:15:00"), 3, "Alice"))) + .toDF("timestamp", "id", "name") + val groupByDf = df + .selectExpr("TUMBLE(timestamp, '10 minutes') as window", "name") + .groupBy("window", "name") + .count() + .select("window.start", "name", "count") + val options = FlintSparkIndexOptions.empty + + val resultDf = generateIdColumn(groupByDf, options) + resultDf.select(ID_COLUMN).distinct().count() shouldBe 3 + } + test("should not generate ID column for aggregated query if ID expression is empty") { val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") val options = FlintSparkIndexOptions.empty @@ -58,7 +79,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { resultDf.columns should not contain ID_COLUMN } - test("should generate ID column for aggregated query with multiple columns") { + test("should generate ID column for aggregated query with various column types") { val schema = StructType.fromDDL(""" boolean_col BOOLEAN, string_col STRING, @@ -93,7 +114,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { "float_col", "timestamp_col", "date_col", - "struct_col.subfield1", + "struct_col", "struct_col.subfield2") .count() From 71a9a346aa9a486f50a1f4ebcdbc34a90c71d623 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 21 Nov 2024 16:13:38 -0800 Subject: [PATCH 05/10] Refactor UT and doc Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndex.scala | 2 +- .../spark/FlintSparkIndexOptionsSuite.scala | 5 +- .../flint/spark/FlintSparkIndexSuite.scala | 49 +-- .../FlintSparkCoveringIndexSuite.scala | 70 +---- .../mv/FlintSparkMaterializedViewSuite.scala | 292 ++++++++++-------- 5 files changed, 190 insertions(+), 228 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 130be61fc..9bddfaf22 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -155,7 +155,7 @@ object FlintSparkIndex { } } - 
// TODO: use only grouping columns + // TODO: 1) use only grouping columns; 2) ensure aggregation is on top level df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) case _ => df diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala index d7de6d29b..f752ae68a 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala @@ -6,7 +6,6 @@ package org.opensearch.flint.spark import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._ -import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite @@ -22,6 +21,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { WATERMARK_DELAY.toString shouldBe "watermark_delay" OUTPUT_MODE.toString shouldBe "output_mode" INDEX_SETTINGS.toString shouldBe "index_settings" + ID_EXPRESSION.toString shouldBe "id_expression" EXTRA_OPTIONS.toString shouldBe "extra_options" } @@ -36,6 +36,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { "watermark_delay" -> "30 Seconds", "output_mode" -> "complete", "index_settings" -> """{"number_of_shards": 3}""", + "id_expression" -> """sha1(col("timestamp"))""", "extra_options" -> """ { | "alb_logs": { @@ -55,6 +56,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe Some("30 Seconds") options.outputMode() shouldBe Some("complete") options.indexSettings() shouldBe Some("""{"number_of_shards": 3}""") + options.idExpression() shouldBe Some("""sha1(col("timestamp"))""") options.extraSourceOptions("alb_logs") shouldBe Map("opt1" -> "val1") options.extraSinkOptions() shouldBe Map("opt2" -> "val2", "opt3" -> "val3") } @@ -83,6 +85,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe empty options.outputMode() shouldBe empty options.indexSettings() shouldBe empty + options.idExpression() shouldBe empty options.extraSourceOptions("alb_logs") shouldBe empty options.extraSinkOptions() shouldBe empty options.optionsWithDefault should contain("auto_refresh" -> "false") diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala index e67818532..25e1ed591 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -24,6 +24,14 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12))) } + test("should not generate ID column if ID expression is not provided") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = FlintSparkIndexOptions.empty + + val resultDf = generateIdColumn(df, options) + resultDf.columns should not contain ID_COLUMN + } + test("should not generate ID column if ID expression is empty") { val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") val options = FlintSparkIndexOptions.empty @@ -41,45 +49,11 @@ 
class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { val options = FlintSparkIndexOptions.empty val resultDf = generateIdColumn(df, options) + resultDf.columns should contain(ID_COLUMN) resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 } - test("should generate ID column for aggregated query using tumble function") { - val df = spark - .createDataFrame( - Seq( - (Timestamp.valueOf("2023-01-01 00:00:00"), 1, "Alice"), - (Timestamp.valueOf("2023-01-01 00:10:00"), 2, "Bob"), - (Timestamp.valueOf("2023-01-01 00:15:00"), 3, "Alice"))) - .toDF("timestamp", "id", "name") - val groupByDf = df - .selectExpr("TUMBLE(timestamp, '10 minutes') as window", "name") - .groupBy("window", "name") - .count() - .select("window.start", "name", "count") - val options = FlintSparkIndexOptions.empty - - val resultDf = generateIdColumn(groupByDf, options) - resultDf.select(ID_COLUMN).distinct().count() shouldBe 3 - } - - test("should not generate ID column for aggregated query if ID expression is empty") { - val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") - val options = FlintSparkIndexOptions.empty - - val resultDf = generateIdColumn(df, options) - resultDf.columns should not contain ID_COLUMN - } - - test("should not generate ID column if ID expression is not provided") { - val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") - val options = FlintSparkIndexOptions.empty - - val resultDf = generateIdColumn(df, options) - resultDf.columns should not contain ID_COLUMN - } - - test("should generate ID column for aggregated query with various column types") { + test("should generate ID column for various column types") { val schema = StructType.fromDDL(""" boolean_col BOOLEAN, string_col STRING, @@ -117,9 +91,10 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { "struct_col", "struct_col.subfield2") .count() - val options = FlintSparkIndexOptions.empty + val resultDf = generateIdColumn(aggregatedDf, options) + resultDf.columns should contain(ID_COLUMN) resultDf.select(ID_COLUMN).distinct().count() shouldBe 1 } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index e032ac122..75a219fb6 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -11,7 +11,9 @@ import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.{col, expr} class FlintSparkCoveringIndexSuite extends FlintSuite { @@ -71,31 +73,12 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { Map("name" -> "string"), options = FlintSparkIndexOptions(Map("id_expression" -> "name"))) - comparePlans( - index.build(spark, None).queryExecution.logical, - spark - .table(testTable) - .select(col("name")) - .withColumn(ID_COLUMN, expr("name")) - .queryExecution - .logical, - checkAnalysis = false) - } - } - - test("build 
batch should not have ID column without ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") - val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string")) - - comparePlans( - index.build(spark, None).queryExecution.logical, - spark - .table(testTable) - .select(col("name")) - .queryExecution - .logical, - checkAnalysis = false) + val batchDf = index.build(spark, None) + batchDf.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten shouldBe Some(UnresolvedAttribute(Seq("name"))) } } @@ -109,35 +92,12 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { options = FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) - comparePlans( - index.build(spark, Some(spark.table(testTable))).queryExecution.logical, - spark - .table(testTable) - .select("name") - .withColumn(ID_COLUMN, col("name")) - .queryExecution - .logical, - checkAnalysis = false) - } - } - - test("build stream should not have ID column without ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") - val index = FlintSparkCoveringIndex( - "name_idx", - testTable, - Map("name" -> "string"), - options = FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) - - comparePlans( - index.build(spark, Some(spark.table(testTable))).queryExecution.logical, - spark - .table(testTable) - .select(col("name")) - .queryExecution - .logical, - checkAnalysis = false) + val streamDf = index.build(spark, Some(spark.table(testTable))) + streamDf.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten shouldBe Some(UnresolvedAttribute(Seq("name"))) } } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 4cc06a1b6..200efbe97 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConv import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE -import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} +import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, DataFrameIdColumnExtractor, StreamingDslLogicalPlan} import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar.mock @@ -19,10 +19,9 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Literal, Sha1} +import 
org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Expression, Literal, Sha1} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.IntervalUtils -import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1} import org.apache.spark.sql.types.StringType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -39,6 +38,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testMvName = "spark_catalog.default.mv" val testQuery = "SELECT 1" + override def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + } + + override def afterAll(): Unit = { + sql(s"DROP TABLE $testTable") + super.afterAll() + } + test("get mv name") { val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv" @@ -177,155 +186,162 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build stream should fail if there is aggregation but no windowing function") { - val testTable = "mv_build_test" - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", - Array(testTable), - Map.empty) + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), + Map.empty) - the[IllegalStateException] thrownBy - mv.buildStream(spark) - } + the[IllegalStateException] thrownBy + mv.buildStream(spark) } test("build batch with ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val testMvQuery = s"SELECT time, name FROM $testTable" - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("id_expression" -> "time"))) + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> "time"))) - comparePlans( - mv.build(spark, None).queryExecution.logical, - spark - .sql(testMvQuery) - .withColumn(ID_COLUMN, expr("time")) - .queryExecution - .logical, - checkAnalysis = false) - } + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) } test("build batch should not have ID column if non-aggregated") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val testMvQuery = s"SELECT time, name FROM $testTable" - val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) - comparePlans( - mv.build(spark, None).queryExecution.logical, - spark.sql(testMvQuery).queryExecution.logical, - checkAnalysis = false) - } + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe None } test("build batch should have ID column if aggregated") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = FlintSparkMaterializedView( - testMvName, - s""" SELECT 
time, name, AVG(age) + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT time, name, AVG(age) AS avg | FROM $testTable | GROUP BY time, name""".stripMargin, - Array.empty, - Map.empty) + Array.empty, + Map.empty) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some( + Sha1( + ConcatWs( + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("time")), + UnresolvedAttribute(Seq("name")), + UnresolvedAttribute(Seq("avg")))))) + } - comparePlans( - mv.build(spark, None).queryExecution.logical, - spark - .table(testTable) - .groupBy("time", "name") - .avg("age") - .withColumn(ID_COLUMN, sha1(concat_ws("\0", col("time"), col("name"), col("avg(age)")))) - .queryExecution - .logical, - checkAnalysis = false) - } + test("build batch should not have ID column if aggregated with ID expression empty") { + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT time, name, AVG(age) AS avg + | FROM $testTable + | GROUP BY time, name""".stripMargin, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> ""))) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe None + } + + test("build batch should have ID column if aggregated join") { + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT t1.time, t1.name, AVG(t1.age) AS avg + | FROM $testTable AS t1 + | JOIN $testTable AS t2 + | ON t1.time = t2.time + | GROUP BY t1.time, t1.name""".stripMargin, + Array.empty, + Map.empty) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some( + Sha1( + ConcatWs( + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("time")), + UnresolvedAttribute(Seq("name")), + UnresolvedAttribute(Seq("avg")))))) } test("build stream with ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT time, name FROM $testTable", - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) - - mv.buildStream(spark).queryExecution.logical.exists { - case Project(projectList, _) => - projectList.exists { - case Alias(UnresolvedAttribute(Seq("name")), ID_COLUMN) => true - case _ => false - } - case _ => false - } shouldBe true - } + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "time"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) } test("build stream should not have ID column if non-aggregated") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT time, name FROM $testTable", - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) - mv.buildStream(spark).queryExecution.logical.exists { - case Project(projectList, _) => - projectList.forall(_.name != ID_COLUMN) - case _ => false - } shouldBe true - } + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe None } test("build stream should have ID column if aggregated") { - 
withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val testMvQuery = - s""" + val testMvQuery = + s""" | SELECT | window.start AS startTime, | COUNT(*) AS count | FROM $testTable | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) - - mv.buildStream(spark).queryExecution.logical.exists { - case Project(projectList, _) => - val asciiNull = UTF8String.fromString("\0") - projectList.exists { - case Alias( - Sha1( - ConcatWs( - Seq( - Literal(`asciiNull`, StringType), - UnresolvedAttribute(Seq("startTime")), - UnresolvedAttribute(Seq("count"))))), - ID_COLUMN) => - true - case _ => false - } - case _ => false - } shouldBe true - } + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe Some( + Sha1( + ConcatWs( + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("startTime")), + UnresolvedAttribute(Seq("count")))))) + } + + test("build stream should not have ID column if aggregated with ID expression empty") { + val testMvQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions( + Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds", "id_expression" -> ""))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe None } private def withAggregateMaterializedView( @@ -333,19 +349,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { sourceTables: Array[String], options: Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView( - testMvName, - query, - sourceTables, - Map.empty, - FlintSparkIndexOptions(options)) - - val actualPlan = mv.buildStream(spark).queryExecution.logical - codeBlock(actualPlan) - } + val mv = + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) + + val actualPlan = mv.buildStream(spark).queryExecution.logical + codeBlock(actualPlan) } } @@ -372,4 +385,15 @@ object FlintSparkMaterializedViewSuite { logicalPlan) } } + + implicit class DataFrameIdColumnExtractor(val df: DataFrame) { + + def idColumn(): Option[Expression] = { + df.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten + } + } } From 003a1d6a75bdfaded38bfeb2190611ba803249b4 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 22 Nov 2024 15:14:25 -0800 Subject: [PATCH 06/10] Add logging and more IT Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndex.scala | 15 ++++++++----- .../covering/FlintSparkCoveringIndex.scala | 4 ++-- .../spark/mv/FlintSparkMaterializedView.scala | 6 ++--- .../flint/spark/FlintSparkIndexSuite.scala | 22 +++++++++---------- .../FlintSparkCoveringIndexSuite.scala | 1 - .../FlintSparkCoveringIndexSqlITSuite.scala | 4 ++++ 
...FlintSparkMaterializedViewSqlITSuite.scala | 6 ++++- 7 files changed, 33 insertions(+), 25 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 9bddfaf22..56a928aa5 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -11,8 +11,9 @@ import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.core.metadata.FlintJsonHelper._ -import org.apache.spark.sql.{Column, DataFrame, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.flint.datatype.FlintDataType import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json} import org.apache.spark.sql.types.{MapType, StructType} @@ -64,7 +65,7 @@ trait FlintSparkIndex { def build(spark: SparkSession, df: Option[DataFrame]): DataFrame } -object FlintSparkIndex { +object FlintSparkIndex extends Logging { /** * Interface indicates a Flint index has custom streaming refresh capability other than foreach @@ -134,13 +135,13 @@ object FlintSparkIndex { * @return * DataFrame with/without ID column */ - def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = { - // Assume output rows must be unique if a simple query plan has aggregate operator + def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = { def isAggregated: Boolean = df.queryExecution.logical.exists(_.isInstanceOf[Aggregate]) options.idExpression() match { case Some(idExpr) if idExpr.nonEmpty => + logInfo(s"Using user-provided ID expression: $idExpr") df.withColumn(ID_COLUMN, expr(idExpr)) case None if isAggregated => @@ -156,7 +157,9 @@ object FlintSparkIndex { } // TODO: 1) use only grouping columns; 2) ensure aggregation is on top level - df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) + val idCol = sha1(concat_ws("\0", allOutputCols: _*)) + logInfo(s"Generated ID column for aggregated query: $idCol") + df.withColumn(ID_COLUMN, idCol) case _ => df } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index 66cb59869..901c3006c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark._ -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, quotedTableName, ID_COLUMN} +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty 
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} @@ -77,7 +77,7 @@ case class FlintSparkCoveringIndex( .getOrElse(job) .select(colNames.head, colNames.tail: _*) - generateIdColumn(batchDf, options) + addIdColumn(batchDf, options) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala index 3892ccb82..e3b09661a 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -13,7 +13,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala` import org.opensearch.flint.common.metadata.FlintMetadata import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh} +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh} import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.function.TumbleFunction import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} @@ -82,7 +82,7 @@ case class FlintSparkMaterializedView( require(df.isEmpty, "materialized view doesn't support reading from other data frame") val batchDf = spark.sql(query) - generateIdColumn(batchDf, options) + addIdColumn(batchDf, options) } override def buildStream(spark: SparkSession): DataFrame = { @@ -102,7 +102,7 @@ case class FlintSparkMaterializedView( } val streamingDf = logicalPlanToDataFrame(spark, streamingPlan) - generateIdColumn(streamingDf, options) + addIdColumn(streamingDf, options) } private def watermark(timeCol: Attribute, child: LogicalPlan) = { diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala index 25e1ed591..8415613e8 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -5,9 +5,7 @@ package org.opensearch.flint.spark -import java.sql.Timestamp - -import org.opensearch.flint.spark.FlintSparkIndex.{generateIdColumn, ID_COLUMN} +import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, ID_COLUMN} import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite @@ -16,31 +14,31 @@ import org.apache.spark.sql.types.StructType class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { - test("should generate ID column if ID expression is provided") { + test("should add ID column if ID expression is provided") { val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10")) - val resultDf = generateIdColumn(df, options) + val resultDf = addIdColumn(df, options) checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12))) } - test("should not generate ID 
column if ID expression is not provided") { + test("should not add ID column if ID expression is not provided") { val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") val options = FlintSparkIndexOptions.empty - val resultDf = generateIdColumn(df, options) + val resultDf = addIdColumn(df, options) resultDf.columns should not contain ID_COLUMN } - test("should not generate ID column if ID expression is empty") { + test("should not add ID column if ID expression is empty") { val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") val options = FlintSparkIndexOptions.empty - val resultDf = generateIdColumn(df, options) + val resultDf = addIdColumn(df, options) resultDf.columns should not contain ID_COLUMN } - test("should generate ID column for aggregated query") { + test("should add ID column for aggregated query") { val df = spark .createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice"))) .toDF("id", "name") @@ -48,7 +46,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { .count() val options = FlintSparkIndexOptions.empty - val resultDf = generateIdColumn(df, options) + val resultDf = addIdColumn(df, options) resultDf.columns should contain(ID_COLUMN) resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 } @@ -93,7 +91,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { .count() val options = FlintSparkIndexOptions.empty - val resultDf = generateIdColumn(aggregatedDf, options) + val resultDf = addIdColumn(aggregatedDf, options) resultDf.columns should contain(ID_COLUMN) resultDf.select(ID_COLUMN).distinct().count() shouldBe 1 } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index 75a219fb6..eb255e493 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -14,7 +14,6 @@ import org.apache.spark.FlintSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.functions.{col, expr} class FlintSparkCoveringIndexSuite extends FlintSuite { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index 9acef8c25..0791f9b7a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -119,6 +119,10 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { val indexData = flint.queryIndex(testFlintTimeSeriesIndex) indexData.count() shouldBe 3 // only 3 rows left due to same ID + + // Rerun should not generate duplicate data + sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable") + indexData.count() shouldBe 3 } test("create covering index with index settings") { diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index 8fb220a77..0e15e5f8a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -191,7 +191,11 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { sql(s"REFRESH MATERIALIZED VIEW $testMvName") // 2 rows missing due to ID conflict intentionally - val indexData = spark.read.format(FLINT_DATASOURCE).load(testFlintIndex) + val indexData = flint.queryIndex(testFlintIndex) + indexData.count() shouldBe 2 + + // Rerun should not generate duplicate data + sql(s"REFRESH MATERIALIZED VIEW $testMvName") indexData.count() shouldBe 2 } From cd10c8ee2cfc31307edd913f26d78fac5fd82677 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Mon, 25 Nov 2024 09:01:02 -0800 Subject: [PATCH 07/10] Fix id expression comment Signed-off-by: Chen Dai --- .../org/opensearch/flint/spark/FlintSparkIndexOptions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala index b866463c8..1ad88de6d 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexOptions.scala @@ -97,7 +97,7 @@ case class FlintSparkIndexOptions(options: Map[String, String]) { def indexSettings(): Option[String] = getOptionValue(INDEX_SETTINGS) /** - * An expression that generates unique value as source data row ID. + * An expression that generates unique value as index data row ID. 
* * @return * ID expression From b1fc84825fdf41a318c17cc29ee992e6eba3261f Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Tue, 26 Nov 2024 15:45:27 -0800 Subject: [PATCH 08/10] Refactor UT assertions Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndex.scala | 5 +- .../scala/org/apache/spark/FlintSuite.scala | 16 +++++- .../flint/spark/FlintSparkIndexSuite.scala | 54 +++++++++++++++++-- .../mv/FlintSparkMaterializedViewSuite.scala | 18 ++----- 4 files changed, 72 insertions(+), 21 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 56a928aa5..609fd7b4c 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -14,6 +14,7 @@ import org.opensearch.flint.core.metadata.FlintJsonHelper._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.util.quoteIfNeeded import org.apache.spark.sql.flint.datatype.FlintDataType import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json} import org.apache.spark.sql.types.{MapType, StructType} @@ -150,9 +151,9 @@ object FlintSparkIndex extends Logging { val allOutputCols = df.schema.fields.map { field => field.dataType match { case _: StructType | _: MapType => - to_json(col(field.name)) + to_json(col(quoteIfNeeded(field.name))) case _ => - col(field.name) + col(quoteIfNeeded(field.name)) } } diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index 1d301087f..a6d771534 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -6,9 +6,12 @@ package org.apache.spark import org.opensearch.flint.spark.FlintSparkExtensions +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.{Alias, CodegenObjectFactoryMode, Expression} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf} import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE} import org.apache.spark.sql.internal.SQLConf @@ -68,4 +71,15 @@ trait FlintSuite extends SharedSparkSession { setFlintSparkConf(METADATA_CACHE_WRITE, "false") } } + + protected implicit class DataFrameExtensions(val df: DataFrame) { + + def idColumn(): Option[Expression] = { + df.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala index 8415613e8..6d0d972f6 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala +++ 
b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -10,7 +10,11 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { @@ -19,6 +23,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10")) val resultDf = addIdColumn(df, options) + resultDf.idColumn() shouldBe Some(Add(UnresolvedAttribute("id"), Literal(10))) checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12))) } @@ -47,7 +52,37 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { val options = FlintSparkIndexOptions.empty val resultDf = addIdColumn(df, options) - resultDf.columns should contain(ID_COLUMN) + resultDf.idColumn() shouldBe Some( + Sha1( + ConcatWs( + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("name")), + UnresolvedAttribute(Seq("count")))))) + resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 + } + + test("should add ID column for aggregated query with quoted alias") { + val df = spark + .createDataFrame( + sparkContext.parallelize( + Seq( + Row(1, "Alice", Row("WA", "Seattle")), + Row(2, "Bob", Row("OR", "Portland")), + Row(3, "Alice", Row("WA", "Seattle")))), + StructType.fromDDL("id INT, name STRING, address STRUCT")) + .toDF("id", "name", "address") + .groupBy(col("name").as("test.name"), col("address").as("test.address")) + .count() + val options = FlintSparkIndexOptions.empty + + val resultDf = addIdColumn(df, options) + resultDf.idColumn() shouldBe Some( + Sha1(ConcatWs(Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("test.name")), + new StructsToJson(UnresolvedAttribute(Seq("test.address"))), + UnresolvedAttribute(Seq("count")))))) resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 } @@ -92,7 +127,20 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { val options = FlintSparkIndexOptions.empty val resultDf = addIdColumn(aggregatedDf, options) - resultDf.columns should contain(ID_COLUMN) + resultDf.idColumn() shouldBe Some( + Sha1(ConcatWs(Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("boolean_col")), + UnresolvedAttribute(Seq("string_col")), + UnresolvedAttribute(Seq("long_col")), + UnresolvedAttribute(Seq("int_col")), + UnresolvedAttribute(Seq("double_col")), + UnresolvedAttribute(Seq("float_col")), + UnresolvedAttribute(Seq("timestamp_col")), + UnresolvedAttribute(Seq("date_col")), + new StructsToJson(UnresolvedAttribute(Seq("struct_col"))), + UnresolvedAttribute(Seq("subfield2")), + UnresolvedAttribute(Seq("count")))))) resultDf.select(ID_COLUMN).distinct().count() shouldBe 1 } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 200efbe97..7774eb2fb 100644 --- 
a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -7,10 +7,9 @@ package org.opensearch.flint.spark.mv import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter} -import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE -import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, DataFrameIdColumnExtractor, StreamingDslLogicalPlan} +import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar.mock @@ -19,8 +18,8 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Expression, Literal, Sha1} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Attribute, ConcatWs, Literal, Sha1} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.types.StringType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -385,15 +384,4 @@ object FlintSparkMaterializedViewSuite { logicalPlan) } } - - implicit class DataFrameIdColumnExtractor(val df: DataFrame) { - - def idColumn(): Option[Expression] = { - df.queryExecution.logical.collectFirst { case Project(projectList, _) => - projectList.collectFirst { case Alias(child, ID_COLUMN) => - child - } - }.flatten - } - } } From 15ed31b951db4537004f4a98199ec1b409637e44 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 20 Dec 2024 13:42:15 -0800 Subject: [PATCH 09/10] Remove auto gen logic for MV Signed-off-by: Chen Dai --- .../flint/spark/FlintSparkIndex.scala | 27 +---- .../flint/spark/FlintSparkIndexSuite.scala | 83 ++++--------- .../mv/FlintSparkMaterializedViewSuite.scala | 114 +----------------- ...FlintSparkMaterializedViewSqlITSuite.scala | 5 +- 4 files changed, 33 insertions(+), 196 deletions(-) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 609fd7b4c..32988a5b2 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -13,11 +13,9 @@ import org.opensearch.flint.core.metadata.FlintJsonHelper._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.Aggregate -import org.apache.spark.sql.catalyst.util.quoteIfNeeded import org.apache.spark.sql.flint.datatype.FlintDataType -import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json} -import org.apache.spark.sql.types.{MapType, StructType} 
+import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.types.StructType /** * Flint index interface in Spark. @@ -137,31 +135,10 @@ object FlintSparkIndex extends Logging { * DataFrame with/without ID column */ def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = { - def isAggregated: Boolean = - df.queryExecution.logical.exists(_.isInstanceOf[Aggregate]) - options.idExpression() match { case Some(idExpr) if idExpr.nonEmpty => logInfo(s"Using user-provided ID expression: $idExpr") df.withColumn(ID_COLUMN, expr(idExpr)) - - case None if isAggregated => - // Since concat doesn't support struct or map type, convert these to json which is more - // deterministic than casting to string, as its format may vary across Spark versions. - val allOutputCols = df.schema.fields.map { field => - field.dataType match { - case _: StructType | _: MapType => - to_json(col(quoteIfNeeded(field.name))) - case _ => - col(quoteIfNeeded(field.name)) - } - } - - // TODO: 1) use only grouping columns; 2) ensure aggregation is on top level - val idCol = sha1(concat_ws("\0", allOutputCols: _*)) - logInfo(s"Generated ID column for aggregated query: $idCol") - df.withColumn(ID_COLUMN, idCol) - case _ => df } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala index 6d0d972f6..8ec4bec40 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -10,9 +10,8 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson} -import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -43,49 +42,6 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { resultDf.columns should not contain ID_COLUMN } - test("should add ID column for aggregated query") { - val df = spark - .createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice"))) - .toDF("id", "name") - .groupBy("name") - .count() - val options = FlintSparkIndexOptions.empty - - val resultDf = addIdColumn(df, options) - resultDf.idColumn() shouldBe Some( - Sha1( - ConcatWs( - Seq( - Literal(UTF8String.fromString("\0"), StringType), - UnresolvedAttribute(Seq("name")), - UnresolvedAttribute(Seq("count")))))) - resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 - } - - test("should add ID column for aggregated query with quoted alias") { - val df = spark - .createDataFrame( - sparkContext.parallelize( - Seq( - Row(1, "Alice", Row("WA", "Seattle")), - Row(2, "Bob", Row("OR", "Portland")), - Row(3, "Alice", Row("WA", "Seattle")))), - StructType.fromDDL("id INT, name STRING, address STRUCT")) - .toDF("id", "name", "address") - .groupBy(col("name").as("test.name"), col("address").as("test.address")) - .count() - val options = FlintSparkIndexOptions.empty - - val resultDf = addIdColumn(df, options) - resultDf.idColumn() shouldBe Some( - Sha1(ConcatWs(Seq( - Literal(UTF8String.fromString("\0"), StringType), - 
UnresolvedAttribute(Seq("test.name")), - new StructsToJson(UnresolvedAttribute(Seq("test.address"))), - UnresolvedAttribute(Seq("count")))))) - resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 - } - test("should generate ID column for various column types") { val schema = StructType.fromDDL(""" boolean_col BOOLEAN, @@ -124,23 +80,32 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { "struct_col", "struct_col.subfield2") .count() - val options = FlintSparkIndexOptions.empty + val options = FlintSparkIndexOptions(Map("id_expression" -> + "sha1(concat_ws('\0',boolean_col,string_col,long_col,int_col,double_col,float_col,timestamp_col,date_col,to_json(struct_col),struct_col.subfield2))")) val resultDf = addIdColumn(aggregatedDf, options) resultDf.idColumn() shouldBe Some( - Sha1(ConcatWs(Seq( - Literal(UTF8String.fromString("\0"), StringType), - UnresolvedAttribute(Seq("boolean_col")), - UnresolvedAttribute(Seq("string_col")), - UnresolvedAttribute(Seq("long_col")), - UnresolvedAttribute(Seq("int_col")), - UnresolvedAttribute(Seq("double_col")), - UnresolvedAttribute(Seq("float_col")), - UnresolvedAttribute(Seq("timestamp_col")), - UnresolvedAttribute(Seq("date_col")), - new StructsToJson(UnresolvedAttribute(Seq("struct_col"))), - UnresolvedAttribute(Seq("subfield2")), - UnresolvedAttribute(Seq("count")))))) + UnresolvedFunction( + "sha1", + Seq(UnresolvedFunction( + "concat_ws", + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("boolean_col")), + UnresolvedAttribute(Seq("string_col")), + UnresolvedAttribute(Seq("long_col")), + UnresolvedAttribute(Seq("int_col")), + UnresolvedAttribute(Seq("double_col")), + UnresolvedAttribute(Seq("float_col")), + UnresolvedAttribute(Seq("timestamp_col")), + UnresolvedAttribute(Seq("date_col")), + UnresolvedFunction( + "to_json", + Seq(UnresolvedAttribute(Seq("struct_col"))), + isDistinct = false), + UnresolvedAttribute(Seq("struct_col", "subfield2"))), + isDistinct = false)), + isDistinct = false)) resultDf.select(ID_COLUMN).distinct().count() shouldBe 1 } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 7774eb2fb..838eddf21 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -18,10 +18,9 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.{Attribute, ConcatWs, Literal, Sha1} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.IntervalUtils -import org.apache.spark.sql.types.StringType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -118,7 +117,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { | FROM $testTable | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val options = 
Map("watermark_delay" -> "30 Seconds", "id_expression" -> "") + val options = Map("watermark_delay" -> "30 Seconds") withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( @@ -143,7 +142,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { | WHERE age > 30 | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "") + val options = Map("watermark_delay" -> "30 Seconds") withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( @@ -208,7 +207,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) } - test("build batch should not have ID column if non-aggregated") { + test("build batch should not have ID column if not provided") { val testMvQuery = s"SELECT time, name FROM $testTable" val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) @@ -216,62 +215,6 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { batchDf.idColumn() shouldBe None } - test("build batch should have ID column if aggregated") { - val mv = FlintSparkMaterializedView( - testMvName, - s""" SELECT time, name, AVG(age) AS avg - | FROM $testTable - | GROUP BY time, name""".stripMargin, - Array.empty, - Map.empty) - - val batchDf = mv.build(spark, None) - batchDf.idColumn() shouldBe Some( - Sha1( - ConcatWs( - Seq( - Literal(UTF8String.fromString("\0"), StringType), - UnresolvedAttribute(Seq("time")), - UnresolvedAttribute(Seq("name")), - UnresolvedAttribute(Seq("avg")))))) - } - - test("build batch should not have ID column if aggregated with ID expression empty") { - val mv = FlintSparkMaterializedView( - testMvName, - s""" SELECT time, name, AVG(age) AS avg - | FROM $testTable - | GROUP BY time, name""".stripMargin, - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("id_expression" -> ""))) - - val batchDf = mv.build(spark, None) - batchDf.idColumn() shouldBe None - } - - test("build batch should have ID column if aggregated join") { - val mv = FlintSparkMaterializedView( - testMvName, - s""" SELECT t1.time, t1.name, AVG(t1.age) AS avg - | FROM $testTable AS t1 - | JOIN $testTable AS t2 - | ON t1.time = t2.time - | GROUP BY t1.time, t1.name""".stripMargin, - Array.empty, - Map.empty) - - val batchDf = mv.build(spark, None) - batchDf.idColumn() shouldBe Some( - Sha1( - ConcatWs( - Seq( - Literal(UTF8String.fromString("\0"), StringType), - UnresolvedAttribute(Seq("time")), - UnresolvedAttribute(Seq("name")), - UnresolvedAttribute(Seq("avg")))))) - } - test("build stream with ID expression option") { val mv = FlintSparkMaterializedView( testMvName, @@ -284,7 +227,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) } - test("build stream should not have ID column if non-aggregated") { + test("build stream should not have ID column if not provided") { val mv = FlintSparkMaterializedView( testMvName, s"SELECT time, name FROM $testTable", @@ -296,53 +239,6 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { streamDf.idColumn() shouldBe None } - test("build stream should have ID column if aggregated") { - val testMvQuery = - s""" - | SELECT - | window.start AS startTime, - | COUNT(*) AS count - | FROM $testTable - | GROUP BY TUMBLE(time, '1 Minute') - |""".stripMargin - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, 
- FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) - - val streamDf = mv.buildStream(spark) - streamDf.idColumn() shouldBe Some( - Sha1( - ConcatWs( - Seq( - Literal(UTF8String.fromString("\0"), StringType), - UnresolvedAttribute(Seq("startTime")), - UnresolvedAttribute(Seq("count")))))) - } - - test("build stream should not have ID column if aggregated with ID expression empty") { - val testMvQuery = - s""" - | SELECT - | window.start AS startTime, - | COUNT(*) AS count - | FROM $testTable - | GROUP BY TUMBLE(time, '1 Minute') - |""".stripMargin - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, - FlintSparkIndexOptions( - Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds", "id_expression" -> ""))) - - val streamDf = mv.buildStream(spark) - streamDf.idColumn() shouldBe None - } - private def withAggregateMaterializedView( query: String, sourceTables: Array[String], diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index 0e15e5f8a..7dcd83897 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -163,7 +163,7 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { | auto_refresh = true, | checkpoint_location = '${checkpointDir.getAbsolutePath}', | watermark_delay = '1 Second', - | id_expression = 'count' + | id_expression = "sha1(concat_ws('\0',startTime))" | ) |""".stripMargin) @@ -174,8 +174,7 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { job.get.processAllAvailable() } - // 1 row missing due to ID conflict intentionally - flint.queryIndex(testFlintIndex).count() shouldBe 2 + flint.queryIndex(testFlintIndex).count() shouldBe 3 } } From 01b48facc6de8702905df3652982b328e8f3f730 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 20 Dec 2024 14:43:41 -0800 Subject: [PATCH 10/10] Update user manual and scaladoc Signed-off-by: Chen Dai --- docs/index.md | 4 ++-- .../org/opensearch/flint/spark/FlintSparkIndex.scala | 7 +------ .../src/test/scala/org/apache/spark/FlintSuite.scala | 12 ++++++++++++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/docs/index.md b/docs/index.md index 1c02f0cb0..684ba7da6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -394,7 +394,7 @@ User can provide the following options in `WITH` clause of create statement: + `watermark_delay`: a string as time expression for how late data can come and still be processed, e.g. 1 minute, 10 seconds. This is required by auto and incremental refresh on materialized view if it has aggregation in the query. + `output_mode`: a mode string that describes how data will be written to streaming sink. If unspecified, default append mode will be applied. + `index_settings`: a JSON string as index settings for OpenSearch index that will be created. Please follow the format in OpenSearch documentation. If unspecified, default OpenSearch index settings will be applied. -+ `id_expression`: an expression string that generates an ID column to avoid duplicate data when index refresh job restart or any retry attempt during an index refresh. If an empty string is provided, no ID column will be generated. 
++ `id_expression`: an expression string that generates an ID column to guarantee idempotency when index refresh job restart or any retry attempt during an index refresh. If an empty string is provided, no ID column will be generated. + `extra_options`: a JSON string as extra options that can be passed to Spark streaming source and sink API directly. Use qualified source table name (because there could be multiple) and "sink", e.g. '{"sink": "{key: val}", "table1": {key: val}}' Note that the index option name is case-sensitive. Here is an example: @@ -407,7 +407,7 @@ WITH ( watermark_delay = '1 Second', output_mode = 'complete', index_settings = '{"number_of_shards": 2, "number_of_replicas": 3}', - id_expression = 'uuid()', + id_expression = "sha1(concat_ws('\0',startTime,status))", extra_options = '{"spark_catalog.default.alb_logs": {"maxFilesPerTrigger": "1"}}' ) ``` diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 32988a5b2..300233777 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -120,12 +120,7 @@ object FlintSparkIndex extends Logging { } /** - * Generate an ID column in the precedence below: - * ``` - * 1. Use ID expression provided in the index option; - * 2. SHA-1 based on all output columns if aggregated; - * 3. Otherwise, no ID column generated. - * ``` + * Generate an ID column using ID expression provided in the index option. * * @param df * which DataFrame to generate ID column diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala index a6d771534..78debda35 100644 --- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala +++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala @@ -72,8 +72,20 @@ trait FlintSuite extends SharedSparkSession { } } + /** + * Implicit class to extend DataFrame functionality with additional utilities. + * + * @param df + * the DataFrame to which the additional methods are added + */ protected implicit class DataFrameExtensions(val df: DataFrame) { + /** + * Retrieves the ID column expression from the logical plan of the DataFrame, if it exists. + * + * @return + * an `Option` containing the `Expression` for the ID column if present, or `None` otherwise + */ def idColumn(): Option[Expression] = { df.queryExecution.logical.collectFirst { case Project(projectList, _) => projectList.collectFirst { case Alias(child, ID_COLUMN) =>
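// ---------------------------------------------------------------------------
// A minimal usage sketch, assuming only what the diffs above introduce:
// addIdColumn/ID_COLUMN in FlintSparkIndex, the id_expression index option,
// and the idColumn() helper added to FlintSuite. It is not part of the patch
// series itself; the class and test names below are hypothetical. After the
// auto-generation logic is removed (patch 09), an ID column is added only
// when a non-empty id_expression is supplied; otherwise the DataFrame is
// returned unchanged.
// ---------------------------------------------------------------------------
import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, ID_COLUMN}
import org.opensearch.flint.spark.FlintSparkIndexOptions
import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite
import org.apache.spark.sql.QueryTest

class IdExpressionUsageSketch extends QueryTest with FlintSuite with Matchers {

  test("ID column is derived solely from the id_expression option") {
    val df = spark
      .createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice")))
      .toDF("id", "name")

    // Hypothetical option value for illustration; it mirrors the SHA-1 pattern
    // shown in the docs/index.md example above.
    val withId = addIdColumn(
      df,
      FlintSparkIndexOptions(Map("id_expression" -> "sha1(concat_ws('\0', id, name))")))

    // idColumn() (the DataFrameExtensions helper above) exposes the generated
    // expression for assertions, and every row gets a distinct ID here.
    withId.idColumn() shouldBe defined
    withId.select(ID_COLUMN).distinct().count() shouldBe 3

    // Without the option, no ID column is generated.
    val withoutId = addIdColumn(df, FlintSparkIndexOptions.empty)
    withoutId.columns should not contain ID_COLUMN
  }
}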