diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index 130be61fc..9bddfaf22 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -155,7 +155,7 @@ object FlintSparkIndex { } } - // TODO: use only grouping columns + // TODO: 1) use only grouping columns; 2) ensure aggregation is on top level df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*))) case _ => df diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala index d7de6d29b..f752ae68a 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexOptionsSuite.scala @@ -6,7 +6,6 @@ package org.opensearch.flint.spark import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._ -import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite @@ -22,6 +21,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { WATERMARK_DELAY.toString shouldBe "watermark_delay" OUTPUT_MODE.toString shouldBe "output_mode" INDEX_SETTINGS.toString shouldBe "index_settings" + ID_EXPRESSION.toString shouldBe "id_expression" EXTRA_OPTIONS.toString shouldBe "extra_options" } @@ -36,6 +36,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { "watermark_delay" -> "30 Seconds", "output_mode" -> "complete", "index_settings" -> """{"number_of_shards": 3}""", + "id_expression" -> """sha1(col("timestamp"))""", "extra_options" -> """ { | "alb_logs": { @@ -55,6 +56,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe Some("30 Seconds") options.outputMode() shouldBe Some("complete") options.indexSettings() shouldBe Some("""{"number_of_shards": 3}""") + options.idExpression() shouldBe Some("""sha1(col("timestamp"))""") options.extraSourceOptions("alb_logs") shouldBe Map("opt1" -> "val1") options.extraSinkOptions() shouldBe Map("opt2" -> "val2", "opt3" -> "val3") } @@ -83,6 +85,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers { options.watermarkDelay() shouldBe empty options.outputMode() shouldBe empty options.indexSettings() shouldBe empty + options.idExpression() shouldBe empty options.extraSourceOptions("alb_logs") shouldBe empty options.extraSinkOptions() shouldBe empty options.optionsWithDefault should contain("auto_refresh" -> "false") diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala index e67818532..25e1ed591 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala @@ -24,6 +24,14 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12))) } + test("should not generate ID column if ID expression is not provided") { + val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") + val options = FlintSparkIndexOptions.empty + + val resultDf = generateIdColumn(df, options) + resultDf.columns should not contain ID_COLUMN + } + test("should not generate ID column if ID expression is empty") { val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") val options = FlintSparkIndexOptions.empty @@ -41,45 +49,11 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { val options = FlintSparkIndexOptions.empty val resultDf = generateIdColumn(df, options) + resultDf.columns should contain(ID_COLUMN) resultDf.select(ID_COLUMN).distinct().count() shouldBe 2 } - test("should generate ID column for aggregated query using tumble function") { - val df = spark - .createDataFrame( - Seq( - (Timestamp.valueOf("2023-01-01 00:00:00"), 1, "Alice"), - (Timestamp.valueOf("2023-01-01 00:10:00"), 2, "Bob"), - (Timestamp.valueOf("2023-01-01 00:15:00"), 3, "Alice"))) - .toDF("timestamp", "id", "name") - val groupByDf = df - .selectExpr("TUMBLE(timestamp, '10 minutes') as window", "name") - .groupBy("window", "name") - .count() - .select("window.start", "name", "count") - val options = FlintSparkIndexOptions.empty - - val resultDf = generateIdColumn(groupByDf, options) - resultDf.select(ID_COLUMN).distinct().count() shouldBe 3 - } - - test("should not generate ID column for aggregated query if ID expression is empty") { - val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") - val options = FlintSparkIndexOptions.empty - - val resultDf = generateIdColumn(df, options) - resultDf.columns should not contain ID_COLUMN - } - - test("should not generate ID column if ID expression is not provided") { - val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name") - val options = FlintSparkIndexOptions.empty - - val resultDf = generateIdColumn(df, options) - resultDf.columns should not contain ID_COLUMN - } - - test("should generate ID column for aggregated query with various column types") { + test("should generate ID column for various column types") { val schema = StructType.fromDDL(""" boolean_col BOOLEAN, string_col STRING, @@ -117,9 +91,10 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers { "struct_col", "struct_col.subfield2") .count() - val options = FlintSparkIndexOptions.empty + val resultDf = generateIdColumn(aggregatedDf, options) + resultDf.columns should contain(ID_COLUMN) resultDf.select(ID_COLUMN).distinct().count() shouldBe 1 } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index e032ac122..75a219fb6 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -11,7 +11,9 @@ import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.apache.spark.FlintSuite -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.{col, expr} class FlintSparkCoveringIndexSuite extends FlintSuite { @@ -71,31 +73,12 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { Map("name" -> "string"), options = FlintSparkIndexOptions(Map("id_expression" -> "name"))) - comparePlans( - index.build(spark, None).queryExecution.logical, - spark - .table(testTable) - .select(col("name")) - .withColumn(ID_COLUMN, expr("name")) - .queryExecution - .logical, - checkAnalysis = false) - } - } - - test("build batch should not have ID column without ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") - val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string")) - - comparePlans( - index.build(spark, None).queryExecution.logical, - spark - .table(testTable) - .select(col("name")) - .queryExecution - .logical, - checkAnalysis = false) + val batchDf = index.build(spark, None) + batchDf.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten shouldBe Some(UnresolvedAttribute(Seq("name"))) } } @@ -109,35 +92,12 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { options = FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) - comparePlans( - index.build(spark, Some(spark.table(testTable))).queryExecution.logical, - spark - .table(testTable) - .select("name") - .withColumn(ID_COLUMN, col("name")) - .queryExecution - .logical, - checkAnalysis = false) - } - } - - test("build stream should not have ID column without ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON") - val index = FlintSparkCoveringIndex( - "name_idx", - testTable, - Map("name" -> "string"), - options = FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) - - comparePlans( - index.build(spark, Some(spark.table(testTable))).queryExecution.logical, - spark - .table(testTable) - .select(col("name")) - .queryExecution - .logical, - checkAnalysis = false) + val streamDf = index.build(spark, Some(spark.table(testTable))) + streamDf.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten shouldBe Some(UnresolvedAttribute(Seq("name"))) } } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 4cc06a1b6..200efbe97 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -10,7 +10,7 @@ import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConv import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE -import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} +import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, DataFrameIdColumnExtractor, StreamingDslLogicalPlan} import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar.mock @@ -19,10 +19,9 @@ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Literal, Sha1} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Expression, Literal, Sha1} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.IntervalUtils -import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1} import org.apache.spark.sql.types.StringType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -39,6 +38,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val testMvName = "spark_catalog.default.mv" val testQuery = "SELECT 1" + override def beforeAll(): Unit = { + super.beforeAll() + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + } + + override def afterAll(): Unit = { + sql(s"DROP TABLE $testTable") + super.afterAll() + } + test("get mv name") { val mv = FlintSparkMaterializedView(testMvName, testQuery, Array.empty, Map.empty) mv.name() shouldBe "flint_spark_catalog_default_mv" @@ -177,155 +186,162 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { } test("build stream should fail if there is aggregation but no windowing function") { - val testTable = "mv_build_test" - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", - Array(testTable), - Map.empty) + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Array(testTable), + Map.empty) - the[IllegalStateException] thrownBy - mv.buildStream(spark) - } + the[IllegalStateException] thrownBy + mv.buildStream(spark) } test("build batch with ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val testMvQuery = s"SELECT time, name FROM $testTable" - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("id_expression" -> "time"))) + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> "time"))) - comparePlans( - mv.build(spark, None).queryExecution.logical, - spark - .sql(testMvQuery) - .withColumn(ID_COLUMN, expr("time")) - .queryExecution - .logical, - checkAnalysis = false) - } + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) } test("build batch should not have ID column if non-aggregated") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val testMvQuery = s"SELECT time, name FROM $testTable" - val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) + val testMvQuery = s"SELECT time, name FROM $testTable" + val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty) - comparePlans( - mv.build(spark, None).queryExecution.logical, - spark.sql(testMvQuery).queryExecution.logical, - checkAnalysis = false) - } + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe None } test("build batch should have ID column if aggregated") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = FlintSparkMaterializedView( - testMvName, - s""" SELECT time, name, AVG(age) + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT time, name, AVG(age) AS avg | FROM $testTable | GROUP BY time, name""".stripMargin, - Array.empty, - Map.empty) + Array.empty, + Map.empty) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some( + Sha1( + ConcatWs( + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("time")), + UnresolvedAttribute(Seq("name")), + UnresolvedAttribute(Seq("avg")))))) + } - comparePlans( - mv.build(spark, None).queryExecution.logical, - spark - .table(testTable) - .groupBy("time", "name") - .avg("age") - .withColumn(ID_COLUMN, sha1(concat_ws("\0", col("time"), col("name"), col("avg(age)")))) - .queryExecution - .logical, - checkAnalysis = false) - } + test("build batch should not have ID column if aggregated with ID expression empty") { + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT time, name, AVG(age) AS avg + | FROM $testTable + | GROUP BY time, name""".stripMargin, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("id_expression" -> ""))) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe None + } + + test("build batch should have ID column if aggregated join") { + val mv = FlintSparkMaterializedView( + testMvName, + s""" SELECT t1.time, t1.name, AVG(t1.age) AS avg + | FROM $testTable AS t1 + | JOIN $testTable AS t2 + | ON t1.time = t2.time + | GROUP BY t1.time, t1.name""".stripMargin, + Array.empty, + Map.empty) + + val batchDf = mv.build(spark, None) + batchDf.idColumn() shouldBe Some( + Sha1( + ConcatWs( + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("time")), + UnresolvedAttribute(Seq("name")), + UnresolvedAttribute(Seq("avg")))))) } test("build stream with ID expression option") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT time, name FROM $testTable", - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name"))) - - mv.buildStream(spark).queryExecution.logical.exists { - case Project(projectList, _) => - projectList.exists { - case Alias(UnresolvedAttribute(Seq("name")), ID_COLUMN) => true - case _ => false - } - case _ => false - } shouldBe true - } + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "time"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) } test("build stream should not have ID column if non-aggregated") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = FlintSparkMaterializedView( - testMvName, - s"SELECT time, name FROM $testTable", - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT time, name FROM $testTable", + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true"))) - mv.buildStream(spark).queryExecution.logical.exists { - case Project(projectList, _) => - projectList.forall(_.name != ID_COLUMN) - case _ => false - } shouldBe true - } + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe None } test("build stream should have ID column if aggregated") { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val testMvQuery = - s""" + val testMvQuery = + s""" | SELECT | window.start AS startTime, | COUNT(*) AS count | FROM $testTable | GROUP BY TUMBLE(time, '1 Minute') |""".stripMargin - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) - - mv.buildStream(spark).queryExecution.logical.exists { - case Project(projectList, _) => - val asciiNull = UTF8String.fromString("\0") - projectList.exists { - case Alias( - Sha1( - ConcatWs( - Seq( - Literal(`asciiNull`, StringType), - UnresolvedAttribute(Seq("startTime")), - UnresolvedAttribute(Seq("count"))))), - ID_COLUMN) => - true - case _ => false - } - case _ => false - } shouldBe true - } + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe Some( + Sha1( + ConcatWs( + Seq( + Literal(UTF8String.fromString("\0"), StringType), + UnresolvedAttribute(Seq("startTime")), + UnresolvedAttribute(Seq("count")))))) + } + + test("build stream should not have ID column if aggregated with ID expression empty") { + val testMvQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + val mv = FlintSparkMaterializedView( + testMvName, + testMvQuery, + Array.empty, + Map.empty, + FlintSparkIndexOptions( + Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds", "id_expression" -> ""))) + + val streamDf = mv.buildStream(spark) + streamDf.idColumn() shouldBe None } private def withAggregateMaterializedView( @@ -333,19 +349,16 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { sourceTables: Array[String], options: Map[String, String])(codeBlock: LogicalPlan => Unit): Unit = { - withTable(testTable) { - sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") - val mv = - FlintSparkMaterializedView( - testMvName, - query, - sourceTables, - Map.empty, - FlintSparkIndexOptions(options)) - - val actualPlan = mv.buildStream(spark).queryExecution.logical - codeBlock(actualPlan) - } + val mv = + FlintSparkMaterializedView( + testMvName, + query, + sourceTables, + Map.empty, + FlintSparkIndexOptions(options)) + + val actualPlan = mv.buildStream(spark).queryExecution.logical + codeBlock(actualPlan) } } @@ -372,4 +385,15 @@ object FlintSparkMaterializedViewSuite { logicalPlan) } } + + implicit class DataFrameIdColumnExtractor(val df: DataFrame) { + + def idColumn(): Option[Expression] = { + df.queryExecution.logical.collectFirst { case Project(projectList, _) => + projectList.collectFirst { case Alias(child, ID_COLUMN) => + child + } + }.flatten + } + } }