Handle struct type in tumble function
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Nov 22, 2024
1 parent 0fb985a commit da2d437
Showing 2 changed files with 38 additions and 6 deletions.
FlintSparkIndex.scala
@@ -12,10 +12,10 @@ import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.core.metadata.FlintJsonHelper._
 
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
 import org.apache.spark.sql.flint.datatype.FlintDataType
-import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
+import org.apache.spark.sql.types.{MapType, StructType}
 
 /**
  * Flint index interface in Spark.
@@ -144,7 +144,18 @@ object FlintSparkIndex {
         df.withColumn(ID_COLUMN, expr(idExpr))
 
       case None if isAggregated =>
-        val allOutputCols = df.columns.map(col)
+        // Since concat doesn't support struct or map type, convert these to json which is more
+        // deterministic than casting to string, as its format may vary across Spark versions.
+        val allOutputCols = df.schema.fields.map { field =>
+          field.dataType match {
+            case _: StructType | _: MapType =>
+              to_json(col(field.name))
+            case _ =>
+              col(field.name)
+          }
+        }
+
         // TODO: use only grouping columns
         df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))
 
       case _ => df
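Why the change is needed: concat_ws accepts only string and array-of-string arguments, so the old df.columns.map(col) would fail analysis as soon as the aggregate output contained a struct or map column. A minimal spark-shell sketch of the technique the new code applies (the DataFrame and column names below are illustrative, not taken from the Flint codebase):

    import org.apache.spark.sql.functions.{col, concat_ws, sha1, to_json}
    import org.apache.spark.sql.types.{MapType, StructType}

    val df = spark.sql(
      "SELECT named_struct('start', timestamp'2023-01-01 00:00:00') AS window, 'Alice' AS name")

    // concat_ws cannot take the struct column directly, so struct/map columns are
    // serialized to JSON first, making every argument to concat_ws a string.
    val normalized = df.schema.fields.map { f =>
      f.dataType match {
        case _: StructType | _: MapType => to_json(col(f.name))
        case _ => col(f.name)
      }
    }
    df.withColumn("id", sha1(concat_ws("\0", normalized: _*))).show(false)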
FlintSparkIndexSuite.scala
@@ -5,6 +5,8 @@
 
 package org.opensearch.flint.spark
 
+import java.sql.Timestamp
+
 import org.opensearch.flint.spark.FlintSparkIndex.{generateIdColumn, ID_COLUMN}
 import org.scalatest.matchers.should.Matchers
 
@@ -42,6 +44,25 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
   }
 
+  test("should generate ID column for aggregated query using tumble function") {
+    val df = spark
+      .createDataFrame(
+        Seq(
+          (Timestamp.valueOf("2023-01-01 00:00:00"), 1, "Alice"),
+          (Timestamp.valueOf("2023-01-01 00:10:00"), 2, "Bob"),
+          (Timestamp.valueOf("2023-01-01 00:15:00"), 3, "Alice")))
+      .toDF("timestamp", "id", "name")
+    val groupByDf = df
+      .selectExpr("TUMBLE(timestamp, '10 minutes') as window", "name")
+      .groupBy("window", "name")
+      .count()
+      .select("window.start", "name", "count")
+    val options = FlintSparkIndexOptions.empty
+
+    val resultDf = generateIdColumn(groupByDf, options)
+    resultDf.select(ID_COLUMN).distinct().count() shouldBe 3
+  }
+
   test("should not generate ID column for aggregated query if ID expression is empty") {
     val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
     val options = FlintSparkIndexOptions.empty
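For context on why the tumble function surfaces this: TUMBLE buckets rows into fixed windows and returns a window struct holding each bucket's start and end timestamps, so the grouped output carries a struct column until the final projection narrows it to window.start. A sketch with Spark's built-in window() function, which yields the same struct<start, end> shape (treating Flint's TUMBLE as analogous to window() is an assumption here):

    import java.sql.Timestamp
    import org.apache.spark.sql.functions.{col, window}

    val events = Seq(
      (Timestamp.valueOf("2023-01-01 00:00:00"), "Alice"),
      (Timestamp.valueOf("2023-01-01 00:15:00"), "Alice"))
      .toDF("timestamp", "name")  // needs spark.implicits._ in scope

    val grouped = events.groupBy(window(col("timestamp"), "10 minutes"), col("name")).count()
    grouped.printSchema()
    // Schema, abbreviated:
    //   window: struct<start: timestamp, end: timestamp>
    //   name:   string
    //   count:  long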
@@ -58,7 +79,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
     resultDf.columns should not contain ID_COLUMN
   }
 
-  test("should generate ID column for aggregated query with multiple columns") {
+  test("should generate ID column for aggregated query with various column types") {
     val schema = StructType.fromDDL("""
         boolean_col BOOLEAN,
         string_col STRING,
@@ -93,7 +114,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
         "float_col",
         "timestamp_col",
         "date_col",
-        "struct_col.subfield1",
+        "struct_col",
         "struct_col.subfield2")
       .count()
 
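A closing note on the "more deterministic than casting to string" rationale in the main change: casting a struct to string yields a bracketed rendering whose format has varied across Spark releases (Spark 3.0 rendered [1, x], Spark 3.1+ renders {1, x}), whereas to_json is a documented, stable JSON encoding. A quick comparison (output rendering approximate):

    spark.sql("SELECT named_struct('a', 1, 'b', 'x') AS s")
      .selectExpr("CAST(s AS STRING) AS casted", "to_json(s) AS json")
      .show(false)
    // casted: {1, x}           (bracket style depends on the Spark version)
    // json:   {"a":1,"b":"x"}  (stable across versions)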