diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index 609fd7b4c..32988a5b2 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -13,11 +13,9 @@ import org.opensearch.flint.core.metadata.FlintJsonHelper._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
-import org.apache.spark.sql.catalyst.util.quoteIfNeeded
 import org.apache.spark.sql.flint.datatype.FlintDataType
-import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
-import org.apache.spark.sql.types.{MapType, StructType}
+import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.types.StructType
 
 /**
  * Flint index interface in Spark.
@@ -137,31 +135,10 @@ object FlintSparkIndex extends Logging {
    *         DataFrame with/without ID column
    */
   def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
-    def isAggregated: Boolean =
-      df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])
-
     options.idExpression() match {
       case Some(idExpr) if idExpr.nonEmpty =>
         logInfo(s"Using user-provided ID expression: $idExpr")
         df.withColumn(ID_COLUMN, expr(idExpr))
-
-      case None if isAggregated =>
-        // Since concat doesn't support struct or map type, convert these to json which is more
-        // deterministic than casting to string, as its format may vary across Spark versions.
-        val allOutputCols = df.schema.fields.map { field =>
-          field.dataType match {
-            case _: StructType | _: MapType =>
-              to_json(col(quoteIfNeeded(field.name)))
-            case _ =>
-              col(quoteIfNeeded(field.name))
-          }
-        }
-
-        // TODO: 1) use only grouping columns; 2) ensure aggregation is on top level
-        val idCol = sha1(concat_ws("\0", allOutputCols: _*))
-        logInfo(s"Generated ID column for aggregated query: $idCol")
-        df.withColumn(ID_COLUMN, idCol)
-
       case _ => df
     }
   }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
index 6d0d972f6..8ec4bec40 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
@@ -10,9 +10,8 @@ import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.FlintSuite
 import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson}
-import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -43,49 +42,6 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     resultDf.columns should not contain ID_COLUMN
   }
 
-  test("should add ID column for aggregated query") {
-    val df = spark
-      .createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice")))
-      .toDF("id", "name")
-      .groupBy("name")
-      .count()
-    val options = FlintSparkIndexOptions.empty
-
-    val resultDf = addIdColumn(df, options)
-    resultDf.idColumn() shouldBe Some(
-      Sha1(
-        ConcatWs(
-          Seq(
-            Literal(UTF8String.fromString("\0"), StringType),
-            UnresolvedAttribute(Seq("name")),
-            UnresolvedAttribute(Seq("count"))))))
-    resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
-  }
-
-  test("should add ID column for aggregated query with quoted alias") {
-    val df = spark
-      .createDataFrame(
-        sparkContext.parallelize(
-          Seq(
-            Row(1, "Alice", Row("WA", "Seattle")),
-            Row(2, "Bob", Row("OR", "Portland")),
-            Row(3, "Alice", Row("WA", "Seattle")))),
-        StructType.fromDDL("id INT, name STRING, address STRUCT<state: STRING, city: STRING>"))
-      .toDF("id", "name", "address")
-      .groupBy(col("name").as("test.name"), col("address").as("test.address"))
-      .count()
-    val options = FlintSparkIndexOptions.empty
-
-    val resultDf = addIdColumn(df, options)
-    resultDf.idColumn() shouldBe Some(
-      Sha1(ConcatWs(Seq(
-        Literal(UTF8String.fromString("\0"), StringType),
-        UnresolvedAttribute(Seq("test.name")),
-        new StructsToJson(UnresolvedAttribute(Seq("test.address"))),
-        UnresolvedAttribute(Seq("count"))))))
-    resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
-  }
-
   test("should generate ID column for various column types") {
     val schema = StructType.fromDDL("""
         boolean_col BOOLEAN,
         string_col STRING,
         long_col BIGINT,
         int_col INT,
         double_col DOUBLE,
         float_col FLOAT,
         timestamp_col TIMESTAMP,
         date_col DATE,
         struct_col STRUCT<subfield1: STRING, subfield2: INT>
@@ -124,23 +80,32 @@
         "struct_col",
         "struct_col.subfield2")
       .count()
-    val options = FlintSparkIndexOptions.empty
+    val options = FlintSparkIndexOptions(Map("id_expression" ->
+      "sha1(concat_ws('\0',boolean_col,string_col,long_col,int_col,double_col,float_col,timestamp_col,date_col,to_json(struct_col),struct_col.subfield2))"))
 
     val resultDf = addIdColumn(aggregatedDf, options)
     resultDf.idColumn() shouldBe Some(
-      Sha1(ConcatWs(Seq(
-        Literal(UTF8String.fromString("\0"), StringType),
-        UnresolvedAttribute(Seq("boolean_col")),
-        UnresolvedAttribute(Seq("string_col")),
-        UnresolvedAttribute(Seq("long_col")),
-        UnresolvedAttribute(Seq("int_col")),
-        UnresolvedAttribute(Seq("double_col")),
-        UnresolvedAttribute(Seq("float_col")),
-        UnresolvedAttribute(Seq("timestamp_col")),
-        UnresolvedAttribute(Seq("date_col")),
-        new StructsToJson(UnresolvedAttribute(Seq("struct_col"))),
-        UnresolvedAttribute(Seq("subfield2")),
-        UnresolvedAttribute(Seq("count"))))))
+      UnresolvedFunction(
+        "sha1",
+        Seq(UnresolvedFunction(
+          "concat_ws",
+          Seq(
+            Literal(UTF8String.fromString("\0"), StringType),
+            UnresolvedAttribute(Seq("boolean_col")),
+            UnresolvedAttribute(Seq("string_col")),
+            UnresolvedAttribute(Seq("long_col")),
+            UnresolvedAttribute(Seq("int_col")),
+            UnresolvedAttribute(Seq("double_col")),
+            UnresolvedAttribute(Seq("float_col")),
+            UnresolvedAttribute(Seq("timestamp_col")),
+            UnresolvedAttribute(Seq("date_col")),
+            UnresolvedFunction(
+              "to_json",
+              Seq(UnresolvedAttribute(Seq("struct_col"))),
+              isDistinct = false),
+            UnresolvedAttribute(Seq("struct_col", "subfield2"))),
+          isDistinct = false)),
+        isDistinct = false))
     resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
   }
 }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
index 7774eb2fb..838eddf21 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala
@@ -18,10 +18,9 @@ import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper}
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.{Attribute, ConcatWs, Literal, Sha1}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
-import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -118,7 +117,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
       | FROM $testTable
       | GROUP BY TUMBLE(time, '1 Minute')
       |""".stripMargin
-    val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "")
+    val options = Map("watermark_delay" -> "30 Seconds")
 
     withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan =>
       comparePlans(
@@ -143,7 +142,7 @@
       | WHERE age > 30
       | GROUP BY TUMBLE(time, '1 Minute')
       |""".stripMargin
-    val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "")
+    val options = Map("watermark_delay" -> "30 Seconds")
 
     withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan =>
       comparePlans(
@@ -208,7 +207,7 @@
     batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time")))
   }
 
-  test("build batch should not have ID column if non-aggregated") {
+  test("build batch should not have ID column if not provided") {
     val testMvQuery = s"SELECT time, name FROM $testTable"
     val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty)
 
@@ -216,62 +215,6 @@
     batchDf.idColumn() shouldBe None
   }
 
-  test("build batch should have ID column if aggregated") {
-    val mv = FlintSparkMaterializedView(
-      testMvName,
-      s""" SELECT time, name, AVG(age) AS avg
-         | FROM $testTable
-         | GROUP BY time, name""".stripMargin,
-      Array.empty,
-      Map.empty)
-
-    val batchDf = mv.build(spark, None)
-    batchDf.idColumn() shouldBe Some(
-      Sha1(
-        ConcatWs(
-          Seq(
-            Literal(UTF8String.fromString("\0"), StringType),
-            UnresolvedAttribute(Seq("time")),
-            UnresolvedAttribute(Seq("name")),
-            UnresolvedAttribute(Seq("avg"))))))
-  }
-
-  test("build batch should not have ID column if aggregated with ID expression empty") {
-    val mv = FlintSparkMaterializedView(
-      testMvName,
-      s""" SELECT time, name, AVG(age) AS avg
-         | FROM $testTable
-         | GROUP BY time, name""".stripMargin,
-      Array.empty,
-      Map.empty,
-      FlintSparkIndexOptions(Map("id_expression" -> "")))
-
-    val batchDf = mv.build(spark, None)
-    batchDf.idColumn() shouldBe None
-  }
-
-  test("build batch should have ID column if aggregated join") {
-    val mv = FlintSparkMaterializedView(
-      testMvName,
-      s""" SELECT t1.time, t1.name, AVG(t1.age) AS avg
-         | FROM $testTable AS t1
-         | JOIN $testTable AS t2
-         | ON t1.time = t2.time
-         | GROUP BY t1.time, t1.name""".stripMargin,
-      Array.empty,
-      Map.empty)
-
-    val batchDf = mv.build(spark, None)
-    batchDf.idColumn() shouldBe Some(
-      Sha1(
-        ConcatWs(
-          Seq(
-            Literal(UTF8String.fromString("\0"), StringType),
-            UnresolvedAttribute(Seq("time")),
-            UnresolvedAttribute(Seq("name")),
-            UnresolvedAttribute(Seq("avg"))))))
-  }
-
test("build stream with ID expression option") { val mv = FlintSparkMaterializedView( testMvName, @@ -284,7 +227,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time"))) } - test("build stream should not have ID column if non-aggregated") { + test("build stream should not have ID column if not provided") { val mv = FlintSparkMaterializedView( testMvName, s"SELECT time, name FROM $testTable", @@ -296,53 +239,6 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { streamDf.idColumn() shouldBe None } - test("build stream should have ID column if aggregated") { - val testMvQuery = - s""" - | SELECT - | window.start AS startTime, - | COUNT(*) AS count - | FROM $testTable - | GROUP BY TUMBLE(time, '1 Minute') - |""".stripMargin - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, - FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds"))) - - val streamDf = mv.buildStream(spark) - streamDf.idColumn() shouldBe Some( - Sha1( - ConcatWs( - Seq( - Literal(UTF8String.fromString("\0"), StringType), - UnresolvedAttribute(Seq("startTime")), - UnresolvedAttribute(Seq("count")))))) - } - - test("build stream should not have ID column if aggregated with ID expression empty") { - val testMvQuery = - s""" - | SELECT - | window.start AS startTime, - | COUNT(*) AS count - | FROM $testTable - | GROUP BY TUMBLE(time, '1 Minute') - |""".stripMargin - val mv = FlintSparkMaterializedView( - testMvName, - testMvQuery, - Array.empty, - Map.empty, - FlintSparkIndexOptions( - Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds", "id_expression" -> ""))) - - val streamDf = mv.buildStream(spark) - streamDf.idColumn() shouldBe None - } - private def withAggregateMaterializedView( query: String, sourceTables: Array[String], diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala index 0e15e5f8a..7dcd83897 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala @@ -163,7 +163,7 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { | auto_refresh = true, | checkpoint_location = '${checkpointDir.getAbsolutePath}', | watermark_delay = '1 Second', - | id_expression = 'count' + | id_expression = "sha1(concat_ws('\0',startTime))" | ) |""".stripMargin) @@ -174,8 +174,7 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite { job.get.processAllAvailable() } - // 1 row missing due to ID conflict intentionally - flint.queryIndex(testFlintIndex).count() shouldBe 2 + flint.queryIndex(testFlintIndex).count() shouldBe 3 } }