diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index fe5329739..220b11a1e 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -9,14 +9,16 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.flint.core.metadata.FlintMetadata
 
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import org.apache.spark.sql.flint.datatype.FlintDataType
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StructType
 
 /**
  * Flint index interface in Spark.
  */
-trait FlintSparkIndex {
+trait FlintSparkIndex extends Logging {
 
   /**
    * Index type
@@ -55,7 +57,7 @@ trait FlintSparkIndex {
   def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
 }
 
-object FlintSparkIndex {
+object FlintSparkIndex extends Logging {
 
   /**
    * Interface indicates a Flint index has custom streaming refresh capability other than foreach
@@ -79,6 +81,32 @@ object FlintSparkIndex {
    */
  val ID_COLUMN: String = "__id__"
 
+  /**
+   * Generate an ID column in the precedence below: (1) Use ID expression directly if provided in
+   * index option; (2) SHA-1 based on all aggregated columns if found; (3) SHA-1 based on source
+   * file path and timestamp column; 4) No ID column generated
+   *
+   * @param df
+   *   data frame to generate ID column for
+   * @param idExpr
+   *   ID expression option
+   * @return
+   *   optional ID column expression
+   */
+  def generateIdColumn(df: DataFrame, idExpr: Option[String]): Option[Column] = {
+    def timestampColumn: Option[String] = {
+      df.columns.toSet.intersect(Set("timestamp", "@timestamp")).headOption
+    }
+
+    if (idExpr.isDefined) {
+      Some(expr(idExpr.get))
+    } else if (timestampColumn.isDefined) {
+      Some(sha1(concat(input_file_name(), col(timestampColumn.get))))
+    } else {
+      None
+    }
+  }
+
   /**
    * Common prefix of Flint index name which is "flint_database_table_"
    *
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index 802f4d818..5fbf581ce 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -9,13 +9,11 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark._
-import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, ID_COLUMN}
+import org.opensearch.flint.spark.FlintSparkIndex._
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
-import org.apache.spark.sql.functions._
 
 /**
  * Flint covering index in Spark.
@@ -33,8 +31,7 @@ case class FlintSparkCoveringIndex(
     indexedColumns: Map[String, String],
     filterCondition: Option[String] = None,
     override val options: FlintSparkIndexOptions = empty)
-    extends FlintSparkIndex
-    with Logging {
+    extends FlintSparkIndex {
 
   require(indexedColumns.nonEmpty, "indexed columns must not be empty")
 
@@ -66,13 +63,7 @@ case class FlintSparkCoveringIndex(
     var job = df.getOrElse(spark.read.table(tableName))
 
     // Add optional ID column
-    val idColumn =
-      options
-        .idExpression()
-        .map(idExpr => Some(expr(idExpr)))
-        .getOrElse(findTimestampColumn(job)
-          .map(tsCol => sha1(concat(input_file_name(), col(tsCol)))))
-
+    val idColumn = generateIdColumn(job, options.idExpression())
     if (idColumn.isDefined) {
       logInfo(s"Generate ID column based on expression $idColumn")
       colNames = colNames :+ ID_COLUMN
@@ -87,10 +78,6 @@ case class FlintSparkCoveringIndex(
       .getOrElse(job)
       .select(colNames.head, colNames.tail: _*)
   }
-
-  private def findTimestampColumn(df: DataFrame): Option[String] = {
-    df.columns.toSet.intersect(Set("timestamp", "@timestamp")).headOption
-  }
 }
 
 object FlintSparkCoveringIndex {
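
A minimal usage sketch of the FlintSparkIndex.generateIdColumn helper introduced above, following the precedence numbering in its doc comment. The local SparkSession and the toy DataFrame are assumptions for illustration only and are not part of this change.

// Sketch: exercising the ID column precedence of generateIdColumn.
// The sample data below is hypothetical.
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSparkIndex.generateIdColumn

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark
  .createDataFrame(Seq(("2023-10-01T00:00:00Z", "ok")))
  .toDF("timestamp", "status")

// (1) An explicit ID expression from the index options wins.
val byExpr = generateIdColumn(df, Some("sha1(concat(timestamp, status))"))

// (3) Otherwise a timestamp/@timestamp column triggers SHA-1 over the
//     source file path concatenated with that column.
val byTimestamp = generateIdColumn(df, None)

// (4) No expression and no timestamp column: no ID column is generated.
val none = generateIdColumn(df.drop("timestamp"), None)

Note that input_file_name() only resolves to a real path for file-based sources, so the timestamp-based ID is meaningful when the index job reads from files rather than an in-memory DataFrame.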