diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index 9bddfaf22..56a928aa5 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -11,8 +11,9 @@ import org.opensearch.flint.common.metadata.FlintMetadata
 import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.core.metadata.FlintJsonHelper._
 
-import org.apache.spark.sql.{Column, DataFrame, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.flint.datatype.FlintDataType
 import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
 import org.apache.spark.sql.types.{MapType, StructType}
@@ -64,7 +65,7 @@ trait FlintSparkIndex {
   def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
 }
 
-object FlintSparkIndex {
+object FlintSparkIndex extends Logging {
 
   /**
    * Interface indicates a Flint index has custom streaming refresh capability other than foreach
@@ -134,13 +135,13 @@ object FlintSparkIndex {
    * @return
    *   DataFrame with/without ID column
    */
-  def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
-    // Assume output rows must be unique if a simple query plan has aggregate operator
+  def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
     def isAggregated: Boolean =
       df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])
 
     options.idExpression() match {
       case Some(idExpr) if idExpr.nonEmpty =>
+        logInfo(s"Using user-provided ID expression: $idExpr")
         df.withColumn(ID_COLUMN, expr(idExpr))
 
       case None if isAggregated =>
@@ -156,7 +157,9 @@ object FlintSparkIndex {
         }
 
         // TODO: 1) use only grouping columns; 2) ensure aggregation is on top level
-        df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))
+        val idCol = sha1(concat_ws("\0", allOutputCols: _*))
+        logInfo(s"Generated ID column for aggregated query: $idCol")
+        df.withColumn(ID_COLUMN, idCol)
 
       case _ => df
     }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index 66cb59869..901c3006c 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 import org.opensearch.flint.common.metadata.FlintMetadata
 import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.spark._
-import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, quotedTableName, ID_COLUMN}
+import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName}
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}
 
@@ -77,7 +77,7 @@ case class FlintSparkCoveringIndex(
         .getOrElse(job)
         .select(colNames.head, colNames.tail: _*)
 
-    generateIdColumn(batchDf, options)
+    addIdColumn(batchDf, options)
   }
 }
 
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
index 3892ccb82..e3b09661a 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
@@ -13,7 +13,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala`
 import org.opensearch.flint.common.metadata.FlintMetadata
 import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions}
-import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh}
+import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh}
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.function.TumbleFunction
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE}
@@ -82,7 +82,7 @@ case class FlintSparkMaterializedView(
     require(df.isEmpty, "materialized view doesn't support reading from other data frame")
 
     val batchDf = spark.sql(query)
-    generateIdColumn(batchDf, options)
+    addIdColumn(batchDf, options)
   }
 
   override def buildStream(spark: SparkSession): DataFrame = {
@@ -102,7 +102,7 @@ case class FlintSparkMaterializedView(
     }
 
     val streamingDf = logicalPlanToDataFrame(spark, streamingPlan)
-    generateIdColumn(streamingDf, options)
+    addIdColumn(streamingDf, options)
   }
 
   private def watermark(timeCol: Attribute, child: LogicalPlan) = {
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
index 25e1ed591..8415613e8 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkIndexSuite.scala
@@ -5,9 +5,7 @@
 
 package org.opensearch.flint.spark
 
-import java.sql.Timestamp
-
-import org.opensearch.flint.spark.FlintSparkIndex.{generateIdColumn, ID_COLUMN}
+import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, ID_COLUMN}
 import org.scalatest.matchers.should.Matchers
 
 import org.apache.spark.FlintSuite
@@ -16,31 +14,31 @@ import org.apache.spark.sql.types.StructType
 
 class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
 
-  test("should generate ID column if ID expression is provided") {
+  test("should add ID column if ID expression is provided") {
     val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
     val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10"))
 
-    val resultDf = generateIdColumn(df, options)
+    val resultDf = addIdColumn(df, options)
     checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12)))
   }
 
-  test("should not generate ID column if ID expression is not provided") {
+  test("should not add ID column if ID expression is not provided") {
     val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
     val options = FlintSparkIndexOptions.empty
 
-    val resultDf = generateIdColumn(df, options)
+    val resultDf = addIdColumn(df, options)
     resultDf.columns should not contain ID_COLUMN
   }
 
-  test("should not generate ID column if ID expression is empty") {
+  test("should not add ID column if ID expression is empty") {
     val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
     val options = FlintSparkIndexOptions.empty
 
-    val resultDf = generateIdColumn(df, options)
+    val resultDf = addIdColumn(df, options)
     resultDf.columns should not contain ID_COLUMN
   }
 
-  test("should generate ID column for aggregated query") {
+  test("should add ID column for aggregated query") {
     val df = spark
       .createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice")))
       .toDF("id", "name")
@@ -48,7 +46,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
       .count()
     val options = FlintSparkIndexOptions.empty
 
-    val resultDf = generateIdColumn(df, options)
+    val resultDf = addIdColumn(df, options)
     resultDf.columns should contain(ID_COLUMN)
     resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
   }
@@ -93,7 +91,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
       .count()
     val options = FlintSparkIndexOptions.empty
 
-    val resultDf = generateIdColumn(aggregatedDf, options)
+    val resultDf = addIdColumn(aggregatedDf, options)
     resultDf.columns should contain(ID_COLUMN)
     resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
   }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
index 75a219fb6..eb255e493 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
@@ -14,7 +14,6 @@ import org.apache.spark.FlintSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.Alias
 import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.functions.{col, expr}
 
 class FlintSparkCoveringIndexSuite extends FlintSuite {
 
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala
index 9acef8c25..0791f9b7a 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala
@@ -119,6 +119,10 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
 
     val indexData = flint.queryIndex(testFlintTimeSeriesIndex)
     indexData.count() shouldBe 3 // only 3 rows left due to same ID
+
+    // Rerun should not generate duplicate data
+    sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable")
+    indexData.count() shouldBe 3
   }
 
   test("create covering index with index settings") {
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala
index 14a50e81e..dc8b3e1a5 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewSqlITSuite.scala
@@ -191,7 +191,11 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite {
     sql(s"REFRESH MATERIALIZED VIEW $testMvName")
 
     // 2 rows missing due to ID conflict intentionally
-    val indexData = spark.read.format(FLINT_DATASOURCE).load(testFlintIndex)
+    val indexData = flint.queryIndex(testFlintIndex)
+    indexData.count() shouldBe 2
+
+    // Rerun should not generate duplicate data
+    sql(s"REFRESH MATERIALIZED VIEW $testMvName")
     indexData.count() shouldBe 2
   }
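Note: the renamed addIdColumn helper either applies the user-supplied id_expression option or, for aggregated queries, hashes every output column into a deterministic ID so that a repeated refresh overwrites documents instead of duplicating them. The snippet below is a minimal standalone sketch of that hashing idea only, not the Flint implementation; the object name IdColumnSketch, the helper withGeneratedId, and the "__id__" column name are illustrative placeholders.

import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, concat_ws, sha1, to_json}
import org.apache.spark.sql.types.{MapType, StructType}

object IdColumnSketch {

  // Hash all output columns into one deterministic ID. Struct/map columns are
  // converted to JSON first because concat_ws cannot concatenate complex types.
  def withGeneratedId(df: DataFrame, idColumnName: String = "__id__"): DataFrame = {
    val allOutputCols: Array[Column] = df.schema.fields.map { field =>
      field.dataType match {
        case _: StructType | _: MapType => to_json(col(field.name))
        case _ => col(field.name)
      }
    }
    // "\u0000" is the same NUL separator the diff writes as "\0"
    df.withColumn(idColumnName, sha1(concat_ws("\u0000", allOutputCols: _*)))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("id-column-sketch").getOrCreate()
    import spark.implicits._

    // Aggregated output has one row per group, so the hashed ID stays stable
    // across reruns over the same data.
    val aggregated = Seq((1, "Alice"), (2, "Bob"), (3, "Alice"))
      .toDF("id", "name")
      .groupBy("name")
      .count()

    withGeneratedId(aggregated).show(truncate = false)
    spark.stop()
  }
}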