Skip to content

Commit

Permalink
Refactor MV build stream
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Oct 16, 2023
1 parent 47d9ef5 commit 0a4132f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.flint.spark.mv

import java.util.Locale

import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata
Expand All @@ -15,9 +17,10 @@ import org.opensearch.flint.spark.function.TumbleFunction
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE}

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, EventTimeWatermark}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.flint.logicalPlanToDataFrame
import org.apache.spark.unsafe.types.UTF8String
Expand All @@ -43,7 +46,7 @@ case class FlintSparkMaterializedView(
with StreamingRefresh {

/** TODO: make the watermark delay configurable through index options */
private val watermarkDelay = UTF8String.fromString("0 Minute")
private val watermarkDelay = "0 Minute"

override val kind: String = MV_INDEX_TYPE

Expand All @@ -65,41 +68,60 @@ case class FlintSparkMaterializedView(
}

/**
 * Builds the batch data frame of this materialized view by running its query.
 *
 * @param spark
 *   Spark session used to run the query
 * @param df
 *   must be empty, otherwise `require` fails with an IllegalArgumentException
 * @return
 *   data frame returned by `spark.sql(query)`
 */
override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
// NOTE(review): the two checks below test the same condition and differ only in message, so
// the second one can never fail once the first has passed -- presumably these are the removed
// and added versions of one line; confirm which message should remain.
require(df.isEmpty, "materialized view doesn't support reading from other table")
require(df.isEmpty, "materialized view doesn't support reading from other data frame")

spark.sql(query)
}

/**
 * Builds the streaming data frame of this materialized view. The logical plan of the query is
 * taken from `spark.sql(query)`, rewritten by the `transform` below and converted back to a data
 * frame with `logicalPlanToDataFrame`.
 *
 * @param spark
 *   Spark session used to parse the query and create the resulting data frame
 * @return
 *   data frame created from the rewritten plan
 */
override def buildStream(spark: SparkSession): DataFrame = {
val batchPlan = spark.sql(query).queryExecution.logical

// Convert unresolved batch plan to streaming plan by:
// 1. Insert Watermark operator below Aggregate (required by Spark streaming)
// 2. Set isStreaming flag to true in Relation operator
val streamingPlan = batchPlan transform {
// Put a watermark operator between the aggregate and its child
case WindowingAggregate(agg, timeCol) =>
agg.copy(child = watermark(timeCol, watermarkDelay, agg.child))

// Insert watermark operator between Aggregate and its child
// NOTE(review): this branch only collects `timeCol` and produces no plan; it looks like a
// leftover of the previous implementation that WindowingAggregate replaces -- confirm that
// it should be removed.
case Aggregate(grouping, agg, child) =>
val timeCol = grouping.collect {
case UnresolvedFunction(identifier, args, _, _, _)
if identifier.mkString(".").equalsIgnoreCase(TumbleFunction.identifier.funcName) =>
args.head
}
// Only relations that are not already marked as streaming are copied
case relation: UnresolvedRelation if !relation.isStreaming =>
relation.copy(isStreaming = true)
}
logicalPlanToDataFrame(spark, streamingPlan)
}

if (timeCol.isEmpty) {
throw new IllegalStateException(
"Windowing function is required for streaming aggregation")
}
Aggregate(
grouping,
agg,
EventTimeWatermark(
timeCol.head.asInstanceOf[Attribute],
IntervalUtils.stringToInterval(watermarkDelay),
child))

// Reset isStreaming flag in relation to true
case UnresolvedRelation(multipartIdentifier, options, _) =>
UnresolvedRelation(multipartIdentifier, options, isStreaming = true)
/**
 * Wraps the given plan in an event-time watermark operator.
 *
 * @param timeCol
 *   column the watermark is defined on
 * @param delay
 *   watermark delay as an interval string, e.g. "0 Minute"
 * @param child
 *   plan placed underneath the watermark operator
 * @return
 *   [[EventTimeWatermark]] operator on top of `child`
 */
private def watermark(timeCol: Attribute, delay: String, child: LogicalPlan) = {
  EventTimeWatermark(
    timeCol,
    // Parse the delay supplied by the caller; the parameter used to be ignored in favour of
    // the class-level watermarkDelay field
    IntervalUtils.stringToInterval(UTF8String.fromString(delay)),
    child)
}

private object WindowingAggregate {

/**
 * Extracts an aggregate together with the time column of its windowing function.
 *
 * @param agg
 *   aggregate operator whose grouping expressions are inspected
 * @return
 *   the aggregate itself and the first argument of its windowing function
 * @throws IllegalStateException
 *   if the grouping expressions do not contain exactly one windowing function, or if the first
 *   argument of that function is missing or is not an [[Attribute]]
 */
def unapply(agg: Aggregate): Option[(Aggregate, Attribute)] = {
  val winFuncs = agg.groupingExpressions.collect {
    case func: UnresolvedFunction if isWindowingFunction(func) => func
  }

  if (winFuncs.size != 1) {
    throw new IllegalStateException(
      "A windowing function is required for streaming aggregation")
  }

  // The time column is taken from the first argument of the windowing function. Match on it
  // instead of head/asInstanceOf so a bad query fails with a descriptive error.
  val timeCol = winFuncs.head.arguments.headOption match {
    case Some(attr: Attribute) => attr
    case Some(expr) =>
      throw new IllegalStateException(
        s"First argument of windowing function must be a time column but found: $expr")
    case None =>
      throw new IllegalStateException(
        "Windowing function requires a time column as its first argument")
  }
  Some((agg, timeCol))
}

logicalPlanToDataFrame(spark, streamingPlan)
/**
 * Tells whether the given unresolved function is a windowing function.
 *
 * @param func
 *   unresolved function whose name parts are checked
 * @return
 *   true if the dot-joined, lower-cased name forms the same identifier as
 *   `TumbleFunction.identifier`
 */
private def isWindowingFunction(func: UnresolvedFunction): Boolean = {
  val qualifiedName = func.nameParts.mkString(".")

  // TODO: support other window functions
  FunctionIdentifier(qualifiedName.toLowerCase(Locale.ROOT)) == TumbleFunction.identifier
}
}
}

Expand Down Expand Up @@ -155,7 +177,7 @@ object FlintSparkMaterializedView {
}

override protected def buildIndex(): FlintSparkIndex = {
// TODO: need to change this and Flint DS to support complex field type
// TODO: change here and FlintDS class to support complex field type in future
val outputSchema = flint.spark
.sql(query)
.schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class FlintSparkMaterializedViewSuite extends FlintSuite {
}
}

test("build stream should fail if there is aggregation without windowing function") {
test("build stream should fail if there is aggregation but no windowing function") {
val testTable = "mv_build_test"
withTable(testTable) {
sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV")
Expand Down

0 comments on commit 0a4132f

Please sign in to comment.