Refactor UT and doc
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Nov 22, 2024
1 parent da2d437 commit 962e25d
Showing 5 changed files with 190 additions and 228 deletions.

FlintSparkIndex.scala
@@ -155,7 +155,7 @@ object FlintSparkIndex {
}
}

// TODO: use only grouping columns
// TODO: 1) use only grouping columns; 2) ensure aggregation is on top level
df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))

case _ => df
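The hunk above touches Flint's synthetic document ID, which is derived by hashing every output column. A minimal standalone sketch of that technique, assuming a local SparkSession; `__id__` and the explicit string casts are illustrative stand-ins, not Flint's actual ID_COLUMN wiring:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, concat_ws, sha1}

object IdColumnSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("id-sketch").getOrCreate()
  import spark.implicits._

  val df = Seq((1, "Alice"), (2, "Bob")).toDF("id", "name")

  // Cast every output column to string, concatenate with a NUL separator,
  // then hash. Rows that agree on every column collide on purpose, which is
  // why the TODO above narrows the input to grouping columns for aggregations.
  val allOutputCols = df.columns.map(c => col(c).cast("string"))
  val withId = df.withColumn("__id__", sha1(concat_ws("\u0000", allOutputCols: _*)))
  withId.show(truncate = false)

  spark.stop()
}
```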

FlintSparkIndexOptionsSuite.scala
@@ -6,7 +6,6 @@
package org.opensearch.flint.spark

import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName._
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.SchedulerMode
import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite
@@ -22,6 +21,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
WATERMARK_DELAY.toString shouldBe "watermark_delay"
OUTPUT_MODE.toString shouldBe "output_mode"
INDEX_SETTINGS.toString shouldBe "index_settings"
ID_EXPRESSION.toString shouldBe "id_expression"
EXTRA_OPTIONS.toString shouldBe "extra_options"
}

@@ -36,6 +36,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
"watermark_delay" -> "30 Seconds",
"output_mode" -> "complete",
"index_settings" -> """{"number_of_shards": 3}""",
"id_expression" -> """sha1(col("timestamp"))""",
"extra_options" ->
""" {
| "alb_logs": {
@@ -55,6 +56,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
options.watermarkDelay() shouldBe Some("30 Seconds")
options.outputMode() shouldBe Some("complete")
options.indexSettings() shouldBe Some("""{"number_of_shards": 3}""")
options.idExpression() shouldBe Some("""sha1(col("timestamp"))""")
options.extraSourceOptions("alb_logs") shouldBe Map("opt1" -> "val1")
options.extraSinkOptions() shouldBe Map("opt2" -> "val2", "opt3" -> "val3")
}
@@ -83,6 +85,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
options.watermarkDelay() shouldBe empty
options.outputMode() shouldBe empty
options.indexSettings() shouldBe empty
options.idExpression() shouldBe empty
options.extraSourceOptions("alb_logs") shouldBe empty
options.extraSinkOptions() shouldBe empty
options.optionsWithDefault should contain("auto_refresh" -> "false")
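The new id_expression option exercised in this suite is a user-supplied expression string. A hedged sketch of how such a string can be applied with Spark's expr; IndexOptionsLike is a stand-in for Flint's FlintSparkIndexOptions, and the SQL-flavored expression replaces the test's col()-based value, which expr would not parse:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object IdExpressionSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("id-expr").getOrCreate()
  import spark.implicits._

  // Illustrative stand-in for an options map like the one in the test above.
  case class IndexOptionsLike(options: Map[String, String]) {
    def idExpression: Option[String] = options.get("id_expression").filter(_.nonEmpty)
  }

  val options = IndexOptionsLike(Map("id_expression" -> "sha1(timestamp)"))
  val df = Seq(("2023-01-01 00:00:00", 1)).toDF("timestamp", "id")

  // Evaluate the configured expression as a Spark SQL expression, or pass the
  // DataFrame through untouched when no ID expression is set.
  val withId = options.idExpression.fold(df)(e => df.withColumn("__id__", expr(e)))
  withId.show(truncate = false)

  spark.stop()
}
```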

FlintSparkIndexSuite.scala
@@ -24,6 +24,14 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12)))
}

test("should not generate ID column if ID expression is not provided") {
val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
val options = FlintSparkIndexOptions.empty

val resultDf = generateIdColumn(df, options)
resultDf.columns should not contain ID_COLUMN
}

test("should not generate ID column if ID expression is empty") {
val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
val options = FlintSparkIndexOptions.empty
@@ -41,45 +49,11 @@
val options = FlintSparkIndexOptions.empty

val resultDf = generateIdColumn(df, options)
resultDf.columns should contain(ID_COLUMN)
resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
}

test("should generate ID column for aggregated query using tumble function") {
val df = spark
.createDataFrame(
Seq(
(Timestamp.valueOf("2023-01-01 00:00:00"), 1, "Alice"),
(Timestamp.valueOf("2023-01-01 00:10:00"), 2, "Bob"),
(Timestamp.valueOf("2023-01-01 00:15:00"), 3, "Alice")))
.toDF("timestamp", "id", "name")
val groupByDf = df
.selectExpr("TUMBLE(timestamp, '10 minutes') as window", "name")
.groupBy("window", "name")
.count()
.select("window.start", "name", "count")
val options = FlintSparkIndexOptions.empty

val resultDf = generateIdColumn(groupByDf, options)
resultDf.select(ID_COLUMN).distinct().count() shouldBe 3
}

test("should not generate ID column for aggregated query if ID expression is empty") {
val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
val options = FlintSparkIndexOptions.empty

val resultDf = generateIdColumn(df, options)
resultDf.columns should not contain ID_COLUMN
}

test("should not generate ID column if ID expression is not provided") {
val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
val options = FlintSparkIndexOptions.empty

val resultDf = generateIdColumn(df, options)
resultDf.columns should not contain ID_COLUMN
}

test("should generate ID column for aggregated query with various column types") {
test("should generate ID column for various column types") {
val schema = StructType.fromDDL("""
boolean_col BOOLEAN,
string_col STRING,
@@ -117,9 +91,10 @@
"struct_col",
"struct_col.subfield2")
.count()

val options = FlintSparkIndexOptions.empty

val resultDf = generateIdColumn(aggregatedDf, options)
resultDf.columns should contain(ID_COLUMN)
resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
}
}
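The aggregation tests above assert exactly one distinct ID per group. A sketch of that scenario under the same assumptions as before (local SparkSession, illustrative `__id__`, explicit string casts):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, concat_ws, sha1}

object AggregatedIdSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("agg-id").getOrCreate()
  import spark.implicits._

  val df = Seq((1, "Alice"), (2, "Bob"), (3, "Alice")).toDF("id", "name")

  // Group and count, then derive the ID from the grouped output columns,
  // so each distinct group maps to exactly one ID.
  val grouped = df.groupBy("name").count()
  val withId = grouped.withColumn(
    "__id__",
    sha1(concat_ws("\u0000", grouped.columns.map(c => col(c).cast("string")): _*)))

  assert(withId.select("__id__").distinct().count() == 2)
  spark.stop()
}
```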

FlintSparkCoveringIndexSuite.scala
@@ -11,7 +11,9 @@ import org.scalatest.matchers.must.Matchers.contain
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.functions.{col, expr}

class FlintSparkCoveringIndexSuite extends FlintSuite {
@@ -71,31 +73,12 @@ class FlintSparkCoveringIndexSuite extends FlintSuite {
Map("name" -> "string"),
options = FlintSparkIndexOptions(Map("id_expression" -> "name")))

comparePlans(
index.build(spark, None).queryExecution.logical,
spark
.table(testTable)
.select(col("name"))
.withColumn(ID_COLUMN, expr("name"))
.queryExecution
.logical,
checkAnalysis = false)
}
}

test("build batch should not have ID column without ID expression option") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))

comparePlans(
index.build(spark, None).queryExecution.logical,
spark
.table(testTable)
.select(col("name"))
.queryExecution
.logical,
checkAnalysis = false)
val batchDf = index.build(spark, None)
batchDf.queryExecution.logical.collectFirst { case Project(projectList, _) =>
projectList.collectFirst { case Alias(child, ID_COLUMN) =>
child
}
}.flatten shouldBe Some(UnresolvedAttribute(Seq("name")))
}
}

@@ -109,35 +92,12 @@ class FlintSparkCoveringIndexSuite extends FlintSuite {
options =
FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name")))

comparePlans(
index.build(spark, Some(spark.table(testTable))).queryExecution.logical,
spark
.table(testTable)
.select("name")
.withColumn(ID_COLUMN, col("name"))
.queryExecution
.logical,
checkAnalysis = false)
}
}

test("build stream should not have ID column without ID expression option") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
val index = FlintSparkCoveringIndex(
"name_idx",
testTable,
Map("name" -> "string"),
options = FlintSparkIndexOptions(Map("auto_refresh" -> "true")))

comparePlans(
index.build(spark, Some(spark.table(testTable))).queryExecution.logical,
spark
.table(testTable)
.select(col("name"))
.queryExecution
.logical,
checkAnalysis = false)
val streamDf = index.build(spark, Some(spark.table(testTable)))
streamDf.queryExecution.logical.collectFirst { case Project(projectList, _) =>
projectList.collectFirst { case Alias(child, ID_COLUMN) =>
child
}
}.flatten shouldBe Some(UnresolvedAttribute(Seq("name")))
}
}
}
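The rewritten assertions above pattern-match the logical plan instead of comparing whole plans with comparePlans. The same collectFirst pattern in standalone form, assuming a local SparkSession, with `__id__` standing in for Flint's ID_COLUMN:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.functions.expr

object PlanInspectionSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("plan-inspect").getOrCreate()
  import spark.implicits._

  val df = Seq(("Alice", 30), ("Bob", 25))
    .toDF("name", "age")
    .withColumn("__id__", expr("name"))

  // Find the expression aliased as the ID column in the topmost Project node
  // of the unanalyzed logical plan.
  val idExpr = df.queryExecution.logical.collectFirst { case Project(projectList, _) =>
    projectList.collectFirst { case Alias(child, "__id__") => child }
  }.flatten

  println(idExpr) // typically Some('name), i.e. an unresolved "name" attribute
  spark.stop()
}
```

Matching only the Alias that defines the ID column keeps the test robust to incidental differences elsewhere in the plan, which is what this refactor trades the full comparePlans check for.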