From 4bc67a6119edb015c542a7c6ce11b76ee9efe82b Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Wed, 1 Nov 2023 15:03:28 -0700
Subject: [PATCH] Refactor build logic

Signed-off-by: Chen Dai
---
 .../covering/FlintSparkCoveringIndex.scala    | 28 ++++++++++---------
 .../FlintSparkCoveringIndexSuite.scala        | 19 +++++++++++++
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index 27d971fc0..802f4d818 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -66,21 +66,19 @@ case class FlintSparkCoveringIndex(
     var job = df.getOrElse(spark.read.table(tableName))
 
     // Add optional ID column
-    if (options.idExpression().isDefined) {
-      val idExpr = options.idExpression().get
-      logInfo(s"Generate ID column based on expression $idExpr")
-
-      job = job.withColumn(ID_COLUMN, expr(idExpr))
+    val idColumn =
+      options
+        .idExpression()
+        .map(idExpr => Some(expr(idExpr)))
+        .getOrElse(findTimestampColumn(job)
+          .map(tsCol => sha1(concat(input_file_name(), col(tsCol)))))
+
+    if (idColumn.isDefined) {
+      logInfo(s"Generate ID column based on expression $idColumn")
       colNames = colNames :+ ID_COLUMN
+      job = job.withColumn(ID_COLUMN, idColumn.get)
     } else {
-      val idColNames = job.columns.toSet.intersect(Set("timestamp", "@timestamp"))
-      if (idColNames.isEmpty) {
-        logWarning("Cannot generate ID column which may cause duplicate data when restart")
-      } else {
-        logInfo(s"Generate ID column based on first column in $idColNames")
-        job = job.withColumn(ID_COLUMN, sha1(concat(input_file_name(), col(idColNames.head))))
-        colNames = colNames :+ ID_COLUMN
-      }
+      logWarning("Cannot generate ID column which may cause duplicate data when restart")
     }
 
     // Add optional filtering condition
@@ -89,6 +87,10 @@ case class FlintSparkCoveringIndex(
       .getOrElse(job)
       .select(colNames.head, colNames.tail: _*)
   }
+
+  private def findTimestampColumn(df: DataFrame): Option[String] = {
+    df.columns.toSet.intersect(Set("timestamp", "@timestamp")).headOption
+  }
 }
 
 object FlintSparkCoveringIndex {
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
index f4cacd385..f1f10fe26 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
@@ -121,6 +121,25 @@ class FlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
     }
   }
 
+  test("should build with filtering condition") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
+      val index = FlintSparkCoveringIndex(
+        "name_idx",
+        testTable,
+        Map("name" -> "string"),
+        Some("name = 'test'"))
+
+      assertDataFrameEquals(
+        index.build(spark, None),
+        spark
+          .table(testTable)
+          .withColumn(ID_COLUMN, sha1(concat(input_file_name(), col("timestamp"))))
+          .where("name = 'test'")
+          .select(col("name"), col(ID_COLUMN)))
+    }
+  }
+
   /* Assert unresolved logical plan in DataFrame equals without semantic analysis */
   private def assertDataFrameEquals(df1: DataFrame, df2: DataFrame): Unit = {
     comparePlans(df1.queryExecution.logical, df2.queryExecution.logical, checkAnalysis = false)
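
Note: the refactor replaces the nested if/else in build() with an Option
chain — use the caller's ID expression if present, otherwise hash
input_file_name() plus a conventional timestamp column. Below is a minimal
standalone sketch of that selection logic, assuming a hypothetical
`idExpression: Option[String]` parameter in place of options.idExpression();
expr, sha1, concat, input_file_name and col are the real
org.apache.spark.sql.functions API:

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{col, concat, expr, input_file_name, sha1}

    object IdColumnSketch {
      // Same intent as findTimestampColumn in the patch: pick a conventional
      // timestamp column if the DataFrame has one.
      def findTimestampColumn(df: DataFrame): Option[String] =
        df.columns.toSet.intersect(Set("timestamp", "@timestamp")).headOption

      // A user-supplied ID expression wins; otherwise fall back to hashing the
      // source file name concatenated with the timestamp column, if any.
      // .map(expr).orElse(...) is equivalent to the patch's
      // .map(idExpr => Some(expr(idExpr))).getOrElse(...).
      def idColumn(df: DataFrame, idExpression: Option[String]): Option[Column] =
        idExpression
          .map(expr)
          .orElse(findTimestampColumn(df)
            .map(tsCol => sha1(concat(input_file_name(), col(tsCol)))))
    }

With the selection centralized in one Option value, build() appends ID_COLUMN
only when that Option is defined, which keeps the duplicate-data warning in a
single else branch instead of two nested ones.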