From 24f578bdeb44f6d41a68d56ded2ebf57745d2f1c Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Wed, 1 Nov 2023 14:31:06 -0700
Subject: [PATCH] Add more UT

Signed-off-by: Chen Dai
---
 .../covering/FlintSparkCoveringIndex.scala    | 10 +--
 .../FlintSparkCoveringIndexSuite.scala        | 80 ++++++++++++++++++-
 2 files changed, 79 insertions(+), 11 deletions(-)

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index 3aa4701d5..27d971fc0 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -68,18 +68,12 @@ case class FlintSparkCoveringIndex(
 
     // Add optional ID column
     if (options.idExpression().isDefined) {
       val idExpr = options.idExpression().get
-      logInfo(s"Generate ID column based on expression $idExpr")
+
       job = job.withColumn(ID_COLUMN, expr(idExpr))
       colNames = colNames :+ ID_COLUMN
     } else {
-      val idColNames =
-        spark
-          .table(tableName)
-          .columns
-          .toSet
-          .intersect(Set("timestamp", "@timestamp"))
-
+      val idColNames = job.columns.toSet.intersect(Set("timestamp", "@timestamp"))
       if (idColNames.isEmpty) {
         logWarning("Cannot generate ID column which may cause duplicate data when restart")
       } else {
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
index fe7df433b..f4cacd385 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
@@ -6,14 +6,18 @@
 package org.opensearch.flint.spark.covering
 
 import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
+import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.FlintSuite
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions.{col, concat, input_file_name, sha1}
+import org.apache.spark.sql.functions._
 
 class FlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
 
+  /** Test table name */
+  val testTable = "spark_catalog.default.ci_test"
+
   test("get covering index name") {
     val index =
       FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string"))
@@ -33,21 +37,91 @@ class FlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
     }
   }
 
-  test("should generate id column based on timestamp column") {
-    val testTable = "spark_catalog.default.ci_test"
+  test("should generate id column based on ID expression in index options") {
     withTable(testTable) {
       sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
+      val index =
+        FlintSparkCoveringIndex(
+          "name_idx",
+          testTable,
+          Map("name" -> "string"),
+          options = FlintSparkIndexOptions(Map("id_expression" -> "now()")))
+
+      assertDataFrameEquals(
+        index.build(spark, None),
+        spark
+          .table(testTable)
+          .withColumn(ID_COLUMN, expr("now()"))
+          .select(col("name"), col(ID_COLUMN)))
+    }
+  }
+
+  test("should generate id column based on timestamp column if found") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
+      val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))
+
+      assertDataFrameEquals(
+        index.build(spark, None),
+        spark
+          .table(testTable)
+          .withColumn(ID_COLUMN, sha1(concat(input_file_name(), col("timestamp"))))
+          .select(col("name"), col(ID_COLUMN)))
+    }
+  }
+
+  test("should generate id column based on @timestamp column if found") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (`@timestamp` TIMESTAMP, name STRING) USING JSON")
+      val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))
+
+      assertDataFrameEquals(
+        index.build(spark, None),
+        spark
+          .table(testTable)
+          .withColumn(ID_COLUMN, sha1(concat(input_file_name(), col("@timestamp"))))
+          .select(col("name"), col(ID_COLUMN)))
+    }
+  }
+
+  test("should not generate id column if no ID expression or timestamp column") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
       val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))
 
       assertDataFrameEquals(
         index.build(spark, None),
         spark
           .table(testTable)
+          .select(col("name")))
+    }
+  }
+
+  test("should generate id column if micro batch has timestamp column") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
+      val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))
+      val batch = spark.read.table(testTable).select("timestamp", "name")
+
+      assertDataFrameEquals(
+        index.build(spark, Some(batch)),
+        batch
           .withColumn(ID_COLUMN, sha1(concat(input_file_name(), col("timestamp"))))
           .select(col("name"), col(ID_COLUMN)))
     }
   }
 
+  test("should not generate id column if micro batch doesn't have timestamp column") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
+      val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))
+      val batch = spark.read.table(testTable).select("name")
+
+      assertDataFrameEquals(index.build(spark, Some(batch)), batch.select(col("name")))
+    }
+  }
+
+  /* Assert unresolved logical plan in DataFrame equals without semantic analysis */
   private def assertDataFrameEquals(df1: DataFrame, df2: DataFrame): Unit = {
     comparePlans(df1.queryExecution.logical, df2.queryExecution.logical, checkAnalysis = false)
   }
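
For context on what these tests pin down: build() prefers an explicit id_expression from the index options and otherwise falls back to a timestamp or @timestamp column of the incoming DataFrame (the micro batch itself, not the source table, which is the behavior change in FlintSparkCoveringIndex above). A minimal sketch of that rule follows, assuming ID_COLUMN stands in for the constant from FlintSparkIndex; the object name, the addIdColumn helper, and the "__id__" value are hypothetical illustrations, not the committed API:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, concat, expr, input_file_name, sha1}

object IdColumnSketch {
  // Assumption: stand-in for FlintSparkIndex.ID_COLUMN
  val ID_COLUMN = "__id__"

  /** Hypothetical helper mirroring the ID-generation rule under test. */
  def addIdColumn(job: DataFrame, idExpression: Option[String]): DataFrame =
    idExpression match {
      // An explicit ID expression from index options always wins, e.g. "now()"
      case Some(idExpr) => job.withColumn(ID_COLUMN, expr(idExpr))
      case None =>
        // Fall back to a timestamp-like column; the check runs on `job` itself
        // (the micro batch), which is what the last two tests verify
        job.columns.toSet.intersect(Set("timestamp", "@timestamp")).headOption match {
          case Some(tsCol) =>
            job.withColumn(ID_COLUMN, sha1(concat(input_file_name(), col(tsCol))))
          case None => job // no ID column: a restarted stream may write duplicates
        }
    }
}
```

Note that the suite compares only unresolved logical plans (comparePlans(..., checkAnalysis = false) via assertDataFrameEquals), so each test asserts the shape of the generated query without executing it against the JSON tables.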