Add UT for CV and MV
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Nov 15, 2024
1 parent 5e0168b commit 21d09d0
Showing 3 changed files with 240 additions and 15 deletions.
FlintSparkIndex.scala
@@ -136,18 +136,18 @@ object FlintSparkIndex {
    */
   def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
     // Assume output rows must be unique if a simple query plan has aggregate operator
-    def isAggregated: Boolean = {
+    def isAggregated: Boolean =
       df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])
-    }
 
-    val idExpr = options.idExpression()
-    if (idExpr.exists(_.nonEmpty)) {
-      df.withColumn(ID_COLUMN, expr(idExpr.get))
-    } else if (isAggregated) {
-      val allOutputCols = df.columns.map(col)
-      df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))
-    } else {
-      df
+    options.idExpression() match {
+      case Some(idExpr) if idExpr.nonEmpty =>
+        df.withColumn(ID_COLUMN, expr(idExpr))
+
+      case None if isAggregated =>
+        val allOutputCols = df.columns.map(col)
+        df.withColumn(ID_COLUMN, sha1(concat_ws("\0", allOutputCols: _*)))
+
+      case _ => df
     }
   }
 
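Note on the rewrite above: the match changes one edge case relative to the old if/else. An explicitly empty id_expression ("") now falls through to case _ => df even for aggregated plans, where the old code would still have generated the SHA-1 ID; this is why the two watermark tests further down add "id_expression" -> "" to opt out. A minimal usage sketch of the three cases, assuming a DataFrame df over a table with a name column (the names are illustrative only, not taken from this change):

    // 1. Non-empty id_expression: ID_COLUMN is that expression.
    generateIdColumn(df, FlintSparkIndexOptions(Map("id_expression" -> "name")))
    // 2. No id_expression and the plan aggregates: ID_COLUMN becomes
    //    sha1(concat_ws("\0", <all output columns>)).
    generateIdColumn(df.groupBy("name").count(), FlintSparkIndexOptions(Map.empty))
    // 3. Anything else (including id_expression = ""): df returned unchanged.
    generateIdColumn(df, FlintSparkIndexOptions(Map("id_expression" -> "")))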
FlintSparkCoveringIndexSuite.scala
@@ -5,13 +5,19 @@
 
 package org.opensearch.flint.spark.covering
 
+import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
+import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.scalatest.matchers.must.Matchers.contain
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 import org.apache.spark.FlintSuite
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, expr}
 
 class FlintSparkCoveringIndexSuite extends FlintSuite {
 
+  private val testTable = "spark_catalog.default.ci_test"
+
   test("get covering index name") {
     val index =
       new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string"))
@@ -54,4 +60,84 @@ class FlintSparkCoveringIndexSuite extends FlintSuite {
       new FlintSparkCoveringIndex("ci", "default.test", Map.empty)
     }
   }
+
+  test("build batch with ID expression option") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
+      val index =
+        FlintSparkCoveringIndex(
+          "name_idx",
+          testTable,
+          Map("name" -> "string"),
+          options = FlintSparkIndexOptions(Map("id_expression" -> "name")))
+
+      comparePlans(
+        index.build(spark, None).queryExecution.logical,
+        spark
+          .table(testTable)
+          .select(col("name"))
+          .withColumn(ID_COLUMN, expr("name"))
+          .queryExecution
+          .logical,
+        checkAnalysis = false)
+    }
+  }
+
+  test("build batch should not have ID column without ID expression option") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
+      val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))
+
+      comparePlans(
+        index.build(spark, None).queryExecution.logical,
+        spark
+          .table(testTable)
+          .select(col("name"))
+          .queryExecution
+          .logical,
+        checkAnalysis = false)
+    }
+  }
+
+  test("build stream with ID expression option") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
+      val index = FlintSparkCoveringIndex(
+        "name_idx",
+        testTable,
+        Map("name" -> "string"),
+        options =
+          FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name")))
+
+      comparePlans(
+        index.build(spark, Some(spark.table(testTable))).queryExecution.logical,
+        spark
+          .table(testTable)
+          .select("name")
+          .withColumn(ID_COLUMN, col("name"))
+          .queryExecution
+          .logical,
+        checkAnalysis = false)
+    }
+  }
+
+  test("build stream should not have ID column without ID expression option") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
+      val index = FlintSparkCoveringIndex(
+        "name_idx",
+        testTable,
+        Map("name" -> "string"),
+        options = FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
+
+      comparePlans(
+        index.build(spark, Some(spark.table(testTable))).queryExecution.logical,
+        spark
+          .table(testTable)
+          .select(col("name"))
+          .queryExecution
+          .logical,
+        checkAnalysis = false)
+    }
+  }
 }
FlintSparkMaterializedViewSuite.scala
@@ -7,6 +7,7 @@ package org.opensearch.flint.spark.mv
 
 import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter}
 
+import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
 import org.opensearch.flint.spark.FlintSparkIndexOptions
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan}
@@ -15,12 +16,14 @@ import org.scalatestplus.mockito.MockitoSugar.mock
 
 import org.apache.spark.FlintSuite
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper}
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Literal, Sha1}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
+import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1}
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -107,7 +110,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
          | FROM $testTable
          | GROUP BY TUMBLE(time, '1 Minute')
          |""".stripMargin
-    val options = Map("watermark_delay" -> "30 Seconds")
+    val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "")
 
     withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan =>
       comparePlans(
@@ -132,7 +135,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
          | WHERE age > 30
          | GROUP BY TUMBLE(time, '1 Minute')
          |""".stripMargin
-    val options = Map("watermark_delay" -> "30 Seconds")
+    val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "")
 
     withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan =>
       comparePlans(
@@ -189,6 +192,142 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
     }
   }
 
+  test("build batch with ID expression option") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV")
+      val testMvQuery = s"SELECT time, name FROM $testTable"
+      val mv = FlintSparkMaterializedView(
+        testMvName,
+        testMvQuery,
+        Array.empty,
+        Map.empty,
+        FlintSparkIndexOptions(Map("id_expression" -> "time")))
+
+      comparePlans(
+        mv.build(spark, None).queryExecution.logical,
+        spark
+          .sql(testMvQuery)
+          .withColumn(ID_COLUMN, expr("time"))
+          .queryExecution
+          .logical,
+        checkAnalysis = false)
+    }
+  }
+
+  test("build batch should not have ID column if non-aggregated") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV")
+      val testMvQuery = s"SELECT time, name FROM $testTable"
+      val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty)
+
+      comparePlans(
+        mv.build(spark, None).queryExecution.logical,
+        spark.sql(testMvQuery).queryExecution.logical,
+        checkAnalysis = false)
+    }
+  }
+
+  test("build batch should have ID column if aggregated") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV")
+      val mv = FlintSparkMaterializedView(
+        testMvName,
+        s""" SELECT time, name, AVG(age)
+           | FROM $testTable
+           | GROUP BY time, name""".stripMargin,
+        Array.empty,
+        Map.empty)
+
+      comparePlans(
+        mv.build(spark, None).queryExecution.logical,
+        spark
+          .table(testTable)
+          .groupBy("time", "name")
+          .avg("age")
+          .withColumn(ID_COLUMN, sha1(concat_ws("\0", col("time"), col("name"), col("avg(age)"))))
+          .queryExecution
+          .logical,
+        checkAnalysis = false)
+    }
+  }
+
+  test("build stream with ID expression option") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV")
+      val mv = FlintSparkMaterializedView(
+        testMvName,
+        s"SELECT time, name FROM $testTable",
+        Array.empty,
+        Map.empty,
+        FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name")))
+
+      mv.buildStream(spark).queryExecution.logical.exists {
+        case Project(projectList, _) =>
+          projectList.exists {
+            case Alias(UnresolvedAttribute(Seq("name")), ID_COLUMN) => true
+            case _ => false
+          }
+        case _ => false
+      } shouldBe true
+    }
+  }
+
+  test("build stream should not have ID column if non-aggregated") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV")
+      val mv = FlintSparkMaterializedView(
+        testMvName,
+        s"SELECT time, name FROM $testTable",
+        Array.empty,
+        Map.empty,
+        FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
+
+      mv.buildStream(spark).queryExecution.logical.exists {
+        case Project(projectList, _) =>
+          projectList.forall(_.name != ID_COLUMN)
+        case _ => false
+      } shouldBe true
+    }
+  }
+
+  test("build stream should have ID column if aggregated") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV")
+      val testMvQuery =
+        s"""
+           | SELECT
+           |   window.start AS startTime,
+           |   COUNT(*) AS count
+           | FROM $testTable
+           | GROUP BY TUMBLE(time, '1 Minute')
+           |""".stripMargin
+      val mv = FlintSparkMaterializedView(
+        testMvName,
+        testMvQuery,
+        Array.empty,
+        Map.empty,
+        FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds")))
+
+      mv.buildStream(spark).queryExecution.logical.exists {
+        case Project(projectList, _) =>
+          val asciiNull = UTF8String.fromString("\0")
+          projectList.exists {
+            case Alias(
+                  Sha1(
+                    ConcatWs(
+                      Seq(
+                        Literal(`asciiNull`, StringType),
+                        UnresolvedAttribute(Seq("startTime")),
+                        UnresolvedAttribute(Seq("count"))))),
+                  ID_COLUMN) =>
+              true
+            case _ => false
+          }
+        case _ => false
+      } shouldBe true
+    }
+  }
+
   private def withAggregateMaterializedView(
       query: String,
       sourceTables: Array[String],
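For reference, the streaming-aggregate assertion above pins down how the generated ID is built: SHA-1 over all output columns joined with the NUL character, presumably so adjacent column values cannot merge ambiguously. A minimal sketch of the equivalent DataFrame-side expression, with column names taken from the test query and ID_COLUMN as imported above:

    import org.apache.spark.sql.functions.{col, concat_ws, sha1}
    // "\u0000" is the same NUL separator the index builder writes as "\0".
    val idCol = sha1(concat_ws("\u0000", col("startTime"), col("count")))
    // aggDf.withColumn(ID_COLUMN, idCol) yields the projected alias that the
    // Alias(Sha1(ConcatWs(Literal(NUL) +: columns)), ID_COLUMN) pattern matches.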
