diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 9e200a38d..e2a60f319 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -136,18 +136,18 @@ object FlintSparkIndex { */ def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = { // Assume output rows must be unique if a simple query plan has aggregate operator - def isAggregated: Boolean = { + def isAggregated: Boolean = df.queryExecution.logical.exists(_.isInstanceOf[Aggregate]) - } - val idExpr = options.idExpression() - if (idExpr.exists(_.nonEmpty)) { - df.withColumn(ID_COLUMN, expr(idExpr.get)) - } else if (isAggregated) { - val allOutputCols = df.columns.map(col) - df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) - } else { - df + options.idExpression() match { + case Some(idExpr) if idExpr.nonEmpty => + df.withColumn(ID_COLUMN, expr(idExpr)) + + case None if isAggregated => + val allOutputCols = df.columns.map(col) + df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) + + case _ => df } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index 1cce47d1a..e032ac122 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -5,13 +5,19 @@ package org.opensearch.flint.spark.covering +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.opensearch.flint.spark.FlintSparkIndexOptions import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, expr} class FlintSparkCoveringIndexSuite extends FlintSuite { + private val testTable = "spark_catalog.default.ci_test" + test("get covering index name") { val index = new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string")) @@ -54,4 +60,84 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { new FlintSparkCoveringIndex("ci", "default.test", Map.empty) } } + + test("build batch with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON") + val index = + FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = FlintSparkIndexOptions(Map("id_expression" -> "name"))) + + comparePlans( + index.build(spark, None).queryExecution.logical, + spark + .table(testTable) + .select(col("name")) + .withColumn(ID_COLUMN, expr("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build batch should not have ID column without ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string")) + + comparePlans( + index.build(spark, None).queryExecution.logical, + spark + .table(testTable) + 
.select(col("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build stream with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) + + comparePlans( + index.build(spark, Some(spark.table(testTable))).queryExecution.logical, + spark + .table(testTable) + .select("name") + .withColumn(ID_COLUMN, col("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build stream should not have ID column without ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") + val index = FlintSparkCoveringIndex( + "name_idx", + testTable, + Map("name" -> "string"), + options = FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + + comparePlans( + index.build(spark, Some(spark.table(testTable))).queryExecution.logical, + spark + .table(testTable) + .select(col("name")) + .queryExecution + .logical, + checkAnalysis = false) + } + } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 78d2eb09e..4cc06a1b6 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -7,6 +7,7 @@ package org.opensearch.flint.spark.mv import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter} +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} @@ -15,12 +16,14 @@ import org.scalatestplus.mockito.MockitoSugar.mock import org.apache.spark.FlintSuite import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Literal, Sha1} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1} +import org.apache.spark.sql.types.StringType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -107,7 +110,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { | FROM $testTable | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val options = Map("watermark_delay" -> "30 Seconds") + val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> 
"") withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( @@ -132,7 +135,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { | WHERE age > 30 | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val options = Map("watermark_delay" -> "30 Seconds") + val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "") withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan => comparePlans( @@ -189,6 +192,142 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } } + test("build batch with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> "time"))) + + comparePlans( + mv.build(spark, None).queryExecution.logical, + spark + .sql(testMvQuery) + .withColumn(ID_COLUMN, expr("time")) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build batch should not have ID column if non-aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) + + comparePlans( + mv.build(spark, None).queryExecution.logical, + spark.sql(testMvQuery).queryExecution.logical, + checkAnalysis = false) + } + } + + test("build batch should have ID column if aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT time, name, AVG(age) + | FROM $testTable + | GROUP BY time, name""".stripMargin, + Array.empty, + Map.empty) + + comparePlans( + mv.build(spark, None).queryExecution.logical, + spark + .table(testTable) + .groupBy("time", "name") + .avg("age") + .withColumn(ID_COLUMN, sha1(concat_ws("\0", col("time"), col("name"), col("avg(age)")))) + .queryExecution + .logical, + checkAnalysis = false) + } + } + + test("build stream with ID expression option") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) + + mv.buildStream(spark).queryExecution.logical.exists { + case Project(projectList, _) => + projectList.exists { + case Alias(UnresolvedAttribute(Seq("name")), ID_COLUMN) => true + case _ => false + } + case _ => false + } shouldBe true + } + } + + test("build stream should not have ID column if non-aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + + mv.buildStream(spark).queryExecution.logical.exists { + case Project(projectList, _) => + projectList.forall(_.name != ID_COLUMN) + case _ => false + } shouldBe true + } + } + + test("build stream should have ID column if aggregated") { + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING 
CSV") + val testMvQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) + + mv.buildStream(spark).queryExecution.logical.exists { + case Project(projectList, _) => + val asciiNull = UTF8String.fromString("\0") + projectList.exists { + case Alias( + Sha1( + ConcatWs( + Seq( + Literal(`asciiNull`, StringType), + UnresolvedAttribute(Seq("startTime")), + UnresolvedAttribute(Seq("count"))))), + ID_COLUMN) => + true + case _ => false + } + case _ => false + } shouldBe true + } + } + private def withAggregateMaterializedView( query: String, sourceTables: Array[String],