
Commit

Add logging and more IT
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Nov 22, 2024
1 parent 962e25d commit 11d76cb
Showing 7 changed files with 33 additions and 25 deletions.
@@ -11,8 +11,9 @@ import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.core.metadata.FlintJsonHelper._

-import org.apache.spark.sql.{Column, DataFrame, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
import org.apache.spark.sql.types.{MapType, StructType}
@@ -64,7 +65,7 @@ trait FlintSparkIndex {
def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
}

-object FlintSparkIndex {
+object FlintSparkIndex extends Logging {

/**
* Interface indicates a Flint index has custom streaming refresh capability other than foreach
@@ -134,13 +135,13 @@
* @return
* DataFrame with/without ID column
*/
-def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
-// Assume output rows must be unique if a simple query plan has aggregate operator
+def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
def isAggregated: Boolean =
df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])

options.idExpression() match {
case Some(idExpr) if idExpr.nonEmpty =>
+logInfo(s"Using user-provided ID expression: $idExpr")
df.withColumn(ID_COLUMN, expr(idExpr))

case None if isAggregated =>
@@ -156,7 +157,9 @@
}

// TODO: 1) use only grouping columns; 2) ensure aggregation is on top level
-df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))
+val idCol = sha1(concat_ws("\0", allOutputCols: _*))
+logInfo(s"Generated ID column for aggregated query: $idCol")
+df.withColumn(ID_COLUMN, idCol)

case _ => df
}
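The heart of this change is the helper renamed from generateIdColumn to addIdColumn in FlintSparkIndex: it attaches a document ID column either from a user-supplied id_expression option or, when the query plan contains an Aggregate, from a SHA-1 hash over all output columns, and it now logs which path was taken. The snippet below is a minimal, self-contained sketch of those two paths for illustration only (assuming Spark 3.3+ for TreeNode.exists); the object name, the "__id__" value, and the omission of the struct/map-to-JSON handling implied by the to_json/MapType imports are assumptions, not the actual Flint implementation.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1}

object AddIdColumnSketch {
  // Assumed column name for illustration; the real constant is FlintSparkIndex.ID_COLUMN
  val ID_COLUMN = "__id__"

  def addIdColumn(df: DataFrame, idExpression: Option[String]): DataFrame = {
    // An aggregated plan means each output row is unique per group, so hashing
    // every output column yields an ID that is stable across refresh runs
    def isAggregated: Boolean =
      df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])

    idExpression match {
      case Some(idExpr) if idExpr.nonEmpty =>
        // User-supplied expression, e.g. "id + 10" as in FlintSparkIndexSuite below
        df.withColumn(ID_COLUMN, expr(idExpr))
      case None if isAggregated =>
        // NUL-separated concatenation of all output columns, hashed with SHA-1
        df.withColumn(ID_COLUMN, sha1(concat_ws("\u0000", df.columns.map(col): _*)))
      case _ => df
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("id-column-sketch").getOrCreate()
    import spark.implicits._

    val grouped = Seq((1, "Alice"), (2, "Bob"), (3, "Alice"))
      .toDF("id", "name").groupBy("name").count()
    addIdColumn(grouped, None).show(false) // two distinct IDs, one per name group
    spark.stop()
  }
}
```

Because the generated hash depends only on the output row values, re-running the same refresh reproduces the same IDs, which is exactly what the new integration tests further down verify.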
@@ -10,7 +10,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark._
-import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, quotedTableName, ID_COLUMN}
+import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, quotedTableName}
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}

@@ -77,7 +77,7 @@ case class FlintSparkCoveringIndex(
.getOrElse(job)
.select(colNames.head, colNames.tail: _*)

-generateIdColumn(batchDf, options)
+addIdColumn(batchDf, options)
}
}

@@ -13,7 +13,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala`
import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions}
-import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateIdColumn, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh}
+import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, flintIndexNamePrefix, generateSchema, metadataBuilder, ID_COLUMN, StreamingRefresh}
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.function.TumbleFunction
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE}
@@ -82,7 +82,7 @@ case class FlintSparkMaterializedView(
require(df.isEmpty, "materialized view doesn't support reading from other data frame")

val batchDf = spark.sql(query)
-generateIdColumn(batchDf, options)
+addIdColumn(batchDf, options)
}

override def buildStream(spark: SparkSession): DataFrame = {
@@ -102,7 +102,7 @@
}

val streamingDf = logicalPlanToDataFrame(spark, streamingPlan)
-generateIdColumn(streamingDf, options)
+addIdColumn(streamingDf, options)
}

private def watermark(timeCol: Attribute, child: LogicalPlan) = {
@@ -5,9 +5,7 @@

package org.opensearch.flint.spark

-import java.sql.Timestamp
-
-import org.opensearch.flint.spark.FlintSparkIndex.{generateIdColumn, ID_COLUMN}
+import org.opensearch.flint.spark.FlintSparkIndex.{addIdColumn, ID_COLUMN}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite
@@ -16,39 +14,39 @@ import org.apache.spark.sql.types.StructType

class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {

-test("should generate ID column if ID expression is provided") {
+test("should add ID column if ID expression is provided") {
val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10"))

-val resultDf = generateIdColumn(df, options)
+val resultDf = addIdColumn(df, options)
checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12)))
}

-test("should not generate ID column if ID expression is not provided") {
+test("should not add ID column if ID expression is not provided") {
val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
val options = FlintSparkIndexOptions.empty

-val resultDf = generateIdColumn(df, options)
+val resultDf = addIdColumn(df, options)
resultDf.columns should not contain ID_COLUMN
}

-test("should not generate ID column if ID expression is empty") {
+test("should not add ID column if ID expression is empty") {
val df = spark.createDataFrame(Seq((1, "Alice"), (2, "Bob"))).toDF("id", "name")
val options = FlintSparkIndexOptions.empty

-val resultDf = generateIdColumn(df, options)
+val resultDf = addIdColumn(df, options)
resultDf.columns should not contain ID_COLUMN
}

-test("should generate ID column for aggregated query") {
+test("should add ID column for aggregated query") {
val df = spark
.createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice")))
.toDF("id", "name")
.groupBy("name")
.count()
val options = FlintSparkIndexOptions.empty

-val resultDf = generateIdColumn(df, options)
+val resultDf = addIdColumn(df, options)
resultDf.columns should contain(ID_COLUMN)
resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
}
@@ -93,7 +91,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
.count()
val options = FlintSparkIndexOptions.empty

-val resultDf = generateIdColumn(aggregatedDf, options)
+val resultDf = addIdColumn(aggregatedDf, options)
resultDf.columns should contain(ID_COLUMN)
resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
}
@@ -14,7 +14,6 @@ import org.apache.spark.FlintSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.functions.{col, expr}

class FlintSparkCoveringIndexSuite extends FlintSuite {

@@ -119,6 +119,10 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {

val indexData = flint.queryIndex(testFlintTimeSeriesIndex)
indexData.count() shouldBe 3 // only 3 rows left due to same ID

+// Rerun should not generate duplicate data
+sql(s"REFRESH INDEX $testIndex ON $testTimeSeriesTable")
+indexData.count() shouldBe 3
}

test("create covering index with index settings") {
@@ -191,7 +191,11 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite {
sql(s"REFRESH MATERIALIZED VIEW $testMvName")

// 2 rows missing due to ID conflict intentionally
-val indexData = spark.read.format(FLINT_DATASOURCE).load(testFlintIndex)
+val indexData = flint.queryIndex(testFlintIndex)
indexData.count() shouldBe 2

+// Rerun should not generate duplicate data
+sql(s"REFRESH MATERIALIZED VIEW $testMvName")
+indexData.count() shouldBe 2
}

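Both new integration-test additions (the "more IT" in the commit title) assert the same property from opposite ends: after an intentional ID conflict shrinks the indexed row count, a second REFRESH regenerates the same document IDs and therefore leaves the count unchanged. The fragment below is a hedged sketch of how such a test could be set up end to end, written in the style of the surrounding suite; the CREATE statement is not part of the visible hunks, so the id_expression WITH option, the view/table names, and the source-data assumption (four rows, two distinct names) are illustrative, not the actual test code.

```scala
// Hedged sketch, not the real test: assumes the suite's fixtures ($testMvName,
// $testTable, $testFlintIndex, flint) and that id_expression is accepted as a
// WITH option; the source table is assumed to hold 4 rows with 2 distinct names.
sql(s"""
     | CREATE MATERIALIZED VIEW $testMvName
     | AS SELECT time, name FROM $testTable
     | WITH (id_expression = 'name')
     |""".stripMargin)

sql(s"REFRESH MATERIALIZED VIEW $testMvName")
val indexData = flint.queryIndex(testFlintIndex)
indexData.count() shouldBe 2 // rows sharing an ID overwrite each other

// Rerunning the refresh writes the same IDs again, so nothing is duplicated
sql(s"REFRESH MATERIALIZED VIEW $testMvName")
indexData.count() shouldBe 2
```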
