diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index e87c2ebda..3aa4701d5 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -15,7 +15,7 @@ import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintInde import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.functions.{col, concat, input_file_name, sha1} +import org.apache.spark.sql.functions._ /** * Flint covering index in Spark. @@ -62,30 +62,38 @@ case class FlintSparkCoveringIndex( } override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { - val colNames = indexedColumns.keys.toSeq + var colNames = indexedColumns.keys.toSeq var job = df.getOrElse(spark.read.table(tableName)) + // Add optional ID column + if (options.idExpression().isDefined) { + val idExpr = options.idExpression().get + + logInfo(s"Generate ID column based on expression $idExpr") + job = job.withColumn(ID_COLUMN, expr(idExpr)) + colNames = colNames :+ ID_COLUMN + } else { + val idColNames = + spark + .table(tableName) + .columns + .toSet + .intersect(Set("timestamp", "@timestamp")) + + if (idColNames.isEmpty) { + logWarning("Cannot generate ID column which may cause duplicate data when restart") + } else { + logInfo(s"Generate ID column based on first column in $idColNames") + job = job.withColumn(ID_COLUMN, sha1(concat(input_file_name(), col(idColNames.head)))) + colNames = colNames :+ ID_COLUMN + } + } + // Add optional filtering condition - job = filterCondition + filterCondition .map(job.where) .getOrElse(job) .select(colNames.head, colNames.tail: _*) - - // Add optional ID column - val uniqueColNames = - spark - .table(tableName) - .columns - .toSet - .intersect(Set("timestamp", "@timestamp")) - - if (uniqueColNames.nonEmpty) { - logInfo(s"Generate ID column based on first column in $uniqueColNames") - job = job.withColumn(ID_COLUMN, sha1(concat(input_file_name(), col(uniqueColNames.head)))) - } else { - logWarning("Cannot generate ID column which may cause duplicate data when restart") - } - job } } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala index 8c144b46b..fe7df433b 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala @@ -5,15 +5,18 @@ package org.opensearch.flint.spark.covering -import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.scalatest.matchers.should.Matchers import org.apache.spark.FlintSuite +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, concat, input_file_name, sha1} -class FlintSparkCoveringIndexSuite extends FlintSuite { +class FlintSparkCoveringIndexSuite extends FlintSuite with Matchers { test("get covering index name") { val index = - new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string")) + FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string")) index.name() shouldBe "flint_spark_catalog_default_test_ci_index" } @@ -26,7 +29,26 @@ class FlintSparkCoveringIndexSuite extends FlintSuite { test("should fail if no indexed column given") { assertThrows[IllegalArgumentException] { - new FlintSparkCoveringIndex("ci", "default.test", Map.empty) + FlintSparkCoveringIndex("ci", "default.test", Map.empty) } } + + test("should generate id column based on timestamp column") { + val testTable = "spark_catalog.default.ci_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON") + val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string")) + + assertDataFrameEquals( + index.build(spark, None), + spark + .table(testTable) + .withColumn(ID_COLUMN, sha1(concat(input_file_name(), col("timestamp")))) + .select(col("name"), col(ID_COLUMN))) + } + } + + private def assertDataFrameEquals(df1: DataFrame, df2: DataFrame): Unit = { + comparePlans(df1.queryExecution.logical, df2.queryExecution.logical, checkAnalysis = false) + } }