From da2d437df1a571a81ff418a54634b89d14d79a6a Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Wed, 20 Nov 2024 15:10:16 -0800
Subject: [PATCH] Handle struct type in tumble function

Signed-off-by: Chen Dai
---
 .../flint/spark/FlintSparkIndex.scala         | 19 +++++++++++---
 .../flint/spark/FlintSparkIndexSuite.scala    | 25 +++++++++++++++++--
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index e2a60f319..130be61fc 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -12,10 +12,10 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.core.metadata.FlintJsonHelper._
 
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
 import org.apache.spark.sql.flint.datatype.FlintDataType
-import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
+import org.apache.spark.sql.types.{MapType, StructType}
 
 /**
  * Flint index interface in Spark.
@@ -144,7 +144,18 @@ object FlintSparkIndex {
         df.withColumn(ID_COLUMN, expr(idExpr))
 
       case None if isAggregated =>
-        val allOutputCols = df.columns.map(col)
+        // Since concat doesn't support struct or map type, convert these to JSON, which is more
+        // deterministic than casting to string, as the string format may vary across Spark versions.
+        val allOutputCols = df.schema.fields.map { field =>
+          field.dataType match {
+            case _: StructType | _: MapType =>
+              to_json(col(field.name))
+            case _ =>
+              col(field.name)
+          }
+        }
+
+        // TODO: use only grouping columns
         df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))
 
       case _ => df
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
index 1ab0f8e02..e67818532 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
@@ -5,6 +5,8 @@
 
 package org.opensearch.flint.spark
 
+import java.sql.Timestamp
+
 import org.opensearch.flint.spark.FlintSparkIndex.{generateIdColumn, ID_COLUMN}
 import org.scalatest.matchers.should.Matchers
 
@@ -42,6 +44,25 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
   }
 
+  test("should generate ID column for aggregated query using tumble function") {
+    val df = spark
+      .createDataFrame(
+        Seq(
+          (Timestamp.valueOf("2023-01-01 00:00:00"), 1, "Alice"),
+          (Timestamp.valueOf("2023-01-01 00:10:00"), 2, "Bob"),
+          (Timestamp.valueOf("2023-01-01 00:15:00"), 3, "Alice")))
+      .toDF("timestamp", "id", "name")
+    val groupByDf = df
+      .selectExpr("TUMBLE(timestamp, '10 minutes') as window", "name")
+      .groupBy("window", "name")
+      .count()
+      .select("window.start", "name", "count")
+    val options = FlintSparkIndexOptions.empty
+
+    val resultDf = generateIdColumn(groupByDf, options)
+    resultDf.select(ID_COLUMN).distinct().count() shouldBe 3
+  }
+
   test("should not generate ID column for aggregated query if ID expression is empty") {
     val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
     val options = FlintSparkIndexOptions.empty
@@ -58,7 +79,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     resultDf.columns should not contain ID_COLUMN
   }
 
-  test("should generate ID column for aggregated query with multiple columns") {
+  test("should generate ID column for aggregated query with various column types") {
     val schema = StructType.fromDDL("""
         boolean_col BOOLEAN,
         string_col STRING,
@@ -93,7 +114,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
         "float_col",
         "timestamp_col",
         "date_col",
-        "struct_col.subfield1",
+        "struct_col",
         "struct_col.subfield2")
       .count()
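
Note (not part of the patch): the snippet below is a minimal, standalone sketch of the ID-generation approach this change introduces, serializing struct and map columns with to_json before concat_ws/sha1, and exercising it against a windowed aggregation similar to the TUMBLE test above. The object name, ID column name, and sample data are illustrative, not taken from the Flint codebase.

import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, concat_ws, sha1, to_json, window}
import org.apache.spark.sql.types.{MapType, StructType}

object IdColumnSketch {

  // Same idea as the patch: concat_ws cannot take struct or map columns directly,
  // so serialize them to JSON before hashing all output columns into a single ID.
  def withIdColumn(df: DataFrame, idColumnName: String = "generated_id"): DataFrame = {
    val allOutputCols: Array[Column] = df.schema.fields.map { field =>
      field.dataType match {
        case _: StructType | _: MapType => to_json(col(field.name))
        case _ => col(field.name)
      }
    }
    df.withColumn(idColumnName, sha1(concat_ws("\u0000", allOutputCols: _*)))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("id-column-sketch").getOrCreate()
    import spark.implicits._

    // window() yields a struct column (start, end), analogous to TUMBLE in the test above.
    val agg = Seq(("2023-01-01 00:00:00", "Alice"), ("2023-01-01 00:10:00", "Bob"))
      .toDF("ts", "name")
      .groupBy(window($"ts".cast("timestamp"), "10 minutes"), $"name")
      .count()

    withIdColumn(agg).show(truncate = false)
    spark.stop()
  }
}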