Remove auto gen logic for MV
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Dec 20, 2024
1 parent b1fc848 commit 15ed31b
Showing 4 changed files with 33 additions and 196 deletions.
@@ -13,11 +13,9 @@ import org.opensearch.flint.core.metadata.FlintJsonHelper._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
import org.apache.spark.sql.types.{MapType, StructType}
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.StructType

/**
* Flint index interface in Spark.
@@ -137,31 +135,10 @@ object FlintSparkIndex extends Logging {
* DataFrame with/without ID column
*/
def addIdColumn(df: DataFrame, options: FlintSparkIndexOptions): DataFrame = {
def isAggregated: Boolean =
df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])

options.idExpression() match {
case Some(idExpr) if idExpr.nonEmpty =>
logInfo(s"Using user-provided ID expression: $idExpr")
df.withColumn(ID_COLUMN, expr(idExpr))

case None if isAggregated =>
// Since concat doesn't support struct or map type, convert these to json which is more
// deterministic than casting to string, as its format may vary across Spark versions.
val allOutputCols = df.schema.fields.map { field =>
field.dataType match {
case _: StructType | _: MapType =>
to_json(col(quoteIfNeeded(field.name)))
case _ =>
col(quoteIfNeeded(field.name))
}
}

// TODO: 1) use only grouping columns; 2) ensure aggregation is on top level
val idCol = sha1(concat_ws("\0", allOutputCols: _*))
logInfo(s"Generated ID column for aggregated query: $idCol")
df.withColumn(ID_COLUMN, idCol)

case _ => df
}
}
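With the auto-generation branch removed, addIdColumn honors only an explicit, non-empty id_expression option: aggregated queries no longer get a generated sha1 ID, and an empty expression disables the column entirely. A minimal sketch of the new contract (hedged: the imports and local SparkSession setup are assumptions for illustration; the option keys mirror the tests below):

import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.FlintSparkIndexOptions // assumed package
import org.opensearch.flint.spark.FlintSparkIndex.addIdColumn // assumed package

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// An aggregated DataFrame that previously triggered auto ID generation.
val df = Seq((1, "Alice"), (2, "Bob"), (3, "Alice"))
  .toDF("id", "name")
  .groupBy("name")
  .count()

// No option set: the DataFrame now comes back unchanged.
addIdColumn(df, FlintSparkIndexOptions.empty).columns // Array(name, count)

// Explicit expression: reproduces what the removed auto-gen logic used to build.
val withId = addIdColumn(
  df,
  FlintSparkIndexOptions(Map("id_expression" -> "sha1(concat_ws('\0', name, count))")))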
@@ -10,9 +10,8 @@ import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

@@ -43,49 +42,6 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
resultDf.columns should not contain ID_COLUMN
}

test("should add ID column for aggregated query") {
val df = spark
.createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice")))
.toDF("id", "name")
.groupBy("name")
.count()
val options = FlintSparkIndexOptions.empty

val resultDf = addIdColumn(df, options)
resultDf.idColumn() shouldBe Some(
Sha1(
ConcatWs(
Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("name")),
UnresolvedAttribute(Seq("count"))))))
resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
}

test("should add ID column for aggregated query with quoted alias") {
val df = spark
.createDataFrame(
sparkContext.parallelize(
Seq(
Row(1, "Alice", Row("WA", "Seattle")),
Row(2, "Bob", Row("OR", "Portland")),
Row(3, "Alice", Row("WA", "Seattle")))),
StructType.fromDDL("id INT, name STRING, address STRUCT<state: STRING, city: String>"))
.toDF("id", "name", "address")
.groupBy(col("name").as("test.name"), col("address").as("test.address"))
.count()
val options = FlintSparkIndexOptions.empty

val resultDf = addIdColumn(df, options)
resultDf.idColumn() shouldBe Some(
Sha1(ConcatWs(Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("test.name")),
new StructsToJson(UnresolvedAttribute(Seq("test.address"))),
UnresolvedAttribute(Seq("count"))))))
resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
}

test("should generate ID column for various column types") {
val schema = StructType.fromDDL("""
boolean_col BOOLEAN,
@@ -124,23 +80,32 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
"struct_col",
"struct_col.subfield2")
.count()
val options = FlintSparkIndexOptions.empty
val options = FlintSparkIndexOptions(Map("id_expression" ->
"sha1(concat_ws('\0',boolean_col,string_col,long_col,int_col,double_col,float_col,timestamp_col,date_col,to_json(struct_col),struct_col.subfield2))"))

val resultDf = addIdColumn(aggregatedDf, options)
resultDf.idColumn() shouldBe Some(
Sha1(ConcatWs(Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("boolean_col")),
UnresolvedAttribute(Seq("string_col")),
UnresolvedAttribute(Seq("long_col")),
UnresolvedAttribute(Seq("int_col")),
UnresolvedAttribute(Seq("double_col")),
UnresolvedAttribute(Seq("float_col")),
UnresolvedAttribute(Seq("timestamp_col")),
UnresolvedAttribute(Seq("date_col")),
new StructsToJson(UnresolvedAttribute(Seq("struct_col"))),
UnresolvedAttribute(Seq("subfield2")),
UnresolvedAttribute(Seq("count"))))))
UnresolvedFunction(
"sha1",
Seq(UnresolvedFunction(
"concat_ws",
Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("boolean_col")),
UnresolvedAttribute(Seq("string_col")),
UnresolvedAttribute(Seq("long_col")),
UnresolvedAttribute(Seq("int_col")),
UnresolvedAttribute(Seq("double_col")),
UnresolvedAttribute(Seq("float_col")),
UnresolvedAttribute(Seq("timestamp_col")),
UnresolvedAttribute(Seq("date_col")),
UnresolvedFunction(
"to_json",
Seq(UnresolvedAttribute(Seq("struct_col"))),
isDistinct = false),
UnresolvedAttribute(Seq("struct_col", "subfield2"))),
isDistinct = false)),
isDistinct = false))
resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
}
}
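Why the expected tree changed shape: the removed auto-gen path built typed Catalyst nodes (Sha1, ConcatWs, StructsToJson) directly, whereas the id_expression path parses a SQL string via functions.expr, which yields unresolved nodes until the analyzer runs. A small illustration of that standard Spark behavior:

import org.apache.spark.sql.functions.expr

// Column.expr exposes the parsed-but-unresolved expression, hence the
// UnresolvedFunction("sha1", ..., isDistinct = false) shape asserted above.
val parsed = expr("sha1(concat_ws('\0', name, count))").expr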
@@ -18,10 +18,9 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper}
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.{Attribute, ConcatWs, Literal, Sha1}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String

@@ -118,7 +117,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
| FROM $testTable
| GROUP BY TUMBLE(time, '1 Minute')
|""".stripMargin
val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "")
val options = Map("watermark_delay" -> "30 Seconds")

withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan =>
comparePlans(
@@ -143,7 +142,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
| WHERE age > 30
| GROUP BY TUMBLE(time, '1 Minute')
|""".stripMargin
val options = Map("watermark_delay" -> "30 Seconds", "id_expression" -> "")
val options = Map("watermark_delay" -> "30 Seconds")

withAggregateMaterializedView(testQuery, Array(testTable), options) { actualPlan =>
comparePlans(
@@ -208,70 +207,14 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
batchDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time")))
}

test("build batch should not have ID column if non-aggregated") {
test("build batch should not have ID column if not provided") {
val testMvQuery = s"SELECT time, name FROM $testTable"
val mv = FlintSparkMaterializedView(testMvName, testMvQuery, Array.empty, Map.empty)

val batchDf = mv.build(spark, None)
batchDf.idColumn() shouldBe None
}

test("build batch should have ID column if aggregated") {
val mv = FlintSparkMaterializedView(
testMvName,
s""" SELECT time, name, AVG(age) AS avg
| FROM $testTable
| GROUP BY time, name""".stripMargin,
Array.empty,
Map.empty)

val batchDf = mv.build(spark, None)
batchDf.idColumn() shouldBe Some(
Sha1(
ConcatWs(
Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("time")),
UnresolvedAttribute(Seq("name")),
UnresolvedAttribute(Seq("avg"))))))
}

test("build batch should not have ID column if aggregated with ID expression empty") {
val mv = FlintSparkMaterializedView(
testMvName,
s""" SELECT time, name, AVG(age) AS avg
| FROM $testTable
| GROUP BY time, name""".stripMargin,
Array.empty,
Map.empty,
FlintSparkIndexOptions(Map("id_expression" -> "")))

val batchDf = mv.build(spark, None)
batchDf.idColumn() shouldBe None
}

test("build batch should have ID column if aggregated join") {
val mv = FlintSparkMaterializedView(
testMvName,
s""" SELECT t1.time, t1.name, AVG(t1.age) AS avg
| FROM $testTable AS t1
| JOIN $testTable AS t2
| ON t1.time = t2.time
| GROUP BY t1.time, t1.name""".stripMargin,
Array.empty,
Map.empty)

val batchDf = mv.build(spark, None)
batchDf.idColumn() shouldBe Some(
Sha1(
ConcatWs(
Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("time")),
UnresolvedAttribute(Seq("name")),
UnresolvedAttribute(Seq("avg"))))))
}

test("build stream with ID expression option") {
val mv = FlintSparkMaterializedView(
testMvName,
@@ -284,7 +227,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
streamDf.idColumn() shouldBe Some(UnresolvedAttribute(Seq("time")))
}

test("build stream should not have ID column if non-aggregated") {
test("build stream should not have ID column if not provided") {
val mv = FlintSparkMaterializedView(
testMvName,
s"SELECT time, name FROM $testTable",
Expand All @@ -296,53 +239,6 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
streamDf.idColumn() shouldBe None
}

test("build stream should have ID column if aggregated") {
val testMvQuery =
s"""
| SELECT
| window.start AS startTime,
| COUNT(*) AS count
| FROM $testTable
| GROUP BY TUMBLE(time, '1 Minute')
|""".stripMargin
val mv = FlintSparkMaterializedView(
testMvName,
testMvQuery,
Array.empty,
Map.empty,
FlintSparkIndexOptions(Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds")))

val streamDf = mv.buildStream(spark)
streamDf.idColumn() shouldBe Some(
Sha1(
ConcatWs(
Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("startTime")),
UnresolvedAttribute(Seq("count"))))))
}

test("build stream should not have ID column if aggregated with ID expression empty") {
val testMvQuery =
s"""
| SELECT
| window.start AS startTime,
| COUNT(*) AS count
| FROM $testTable
| GROUP BY TUMBLE(time, '1 Minute')
|""".stripMargin
val mv = FlintSparkMaterializedView(
testMvName,
testMvQuery,
Array.empty,
Map.empty,
FlintSparkIndexOptions(
Map("auto_refresh" -> "true", "watermark_delay" -> "10 Seconds", "id_expression" -> "")))

val streamDf = mv.buildStream(spark)
streamDf.idColumn() shouldBe None
}

private def withAggregateMaterializedView(
query: String,
sourceTables: Array[String],
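For materialized views the same rule now applies to build and buildStream alike: an ID column appears only when id_expression is configured. A hedged sketch reusing the constructor shape and the idColumn() test helper from the suite above:

val mv = FlintSparkMaterializedView(
  testMvName,
  s"SELECT time, name FROM $testTable",
  Array.empty,
  Map.empty,
  FlintSparkIndexOptions(Map("id_expression" -> "time")))

// Batch build: the ID column is whatever the user expression resolves to.
mv.build(spark, None).idColumn() // Some(UnresolvedAttribute(Seq("time")))

// Without the option, idColumn() is None for batch and stream builds alike.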
@@ -163,7 +163,7 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite {
| auto_refresh = true,
| checkpoint_location = '${checkpointDir.getAbsolutePath}',
| watermark_delay = '1 Second',
| id_expression = 'count'
| id_expression = "sha1(concat_ws('\0',startTime))"
| )
|""".stripMargin)

@@ -174,8 +174,7 @@ class FlintSparkMaterializedViewSqlITSuite extends FlintSparkSuite {
job.get.processAllAvailable()
}

// 1 row missing due to ID conflict intentionally
flint.queryIndex(testFlintIndex).count() shouldBe 2
flint.queryIndex(testFlintIndex).count() shouldBe 3
}
}
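In the SQL layer the option flows through the WITH clause of CREATE MATERIALIZED VIEW. Keying the ID on startTime alone keeps all three windowed rows distinct, which is why the assertion moves from 2 (one row was previously dropped by the intentional ID conflict on count) to 3. A hypothetical statement in the shape this IT suite exercises (the view/table fixtures and window size are assumptions):

sql(s"""
   | CREATE MATERIALIZED VIEW $testMvName
   | AS
   | SELECT window.start AS startTime, COUNT(*) AS count
   | FROM $testTable
   | GROUP BY TUMBLE(time, '1 Minute')
   | WITH (
   |   auto_refresh = true,
   |   checkpoint_location = '${checkpointDir.getAbsolutePath}',
   |   watermark_delay = '1 Second',
   |   id_expression = "sha1(concat_ws('\0',startTime))"
   | )
   |""".stripMargin)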

Expand Down

0 comments on commit 15ed31b

Please sign in to comment.