diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index f3066ec9f..a5c99752f 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -140,11 +140,12 @@ class FlintSpark(val spark: SparkSession) {
       case INCREMENTAL =>
         // TODO: Use Foreach sink for now. Need to move this to FlintSparkSkippingIndex
         //  once finalized. Otherwise, covering index/MV may have different logic.
-        val job = spark.readStream
-          .table(tableName)
-          .writeStream
-          .queryName(indexName)
-          .outputMode(Append())
+        val job =
+          index.buildStream(spark)
+            .queryName(indexName)
+            .outputMode(Append())
+            .format(FLINT_DATASOURCE)
+            .options(flintSparkConf.properties)
 
         index.options
           .checkpointLocation()
@@ -153,6 +154,7 @@ class FlintSpark(val spark: SparkSession) {
           .refreshInterval()
           .foreach(interval => job.trigger(Trigger.ProcessingTime(interval)))
 
+        /*
         val jobId = job
           .foreachBatch { (batchDF: DataFrame, _: Long) =>
           }
           .start()
           .id
+         */
+
+        val jobId = job.start(indexName).id
         Some(jobId.toString)
     }
   }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
index 633068311..e8265daa6 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
@@ -19,7 +19,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRel
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, EventTimeWatermark}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
-import org.apache.spark.sql.flint.{dataFrameToLogicalPlan, logicalPlanToDataFrame}
+import org.apache.spark.sql.flint.logicalPlanToDataFrame
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -61,14 +61,14 @@ case class FlintSparkMaterializedView(
   }
 
   override def buildStream(spark: SparkSession): DataStreamWriter[Row] = {
-    val batchPlan = dataFrameToLogicalPlan(spark.sql(query))
+    val batchPlan = spark.sql(query).queryExecution.logical
 
     val streamingPlan = batchPlan transform {
 
       // Insert watermark operator between Aggregate and its child
       case Aggregate(grouping, agg, child) =>
         val timeCol = grouping.collect {
           case UnresolvedFunction(identifier, args, _, _, _)
-              if identifier.mkString(".") == TumbleFunction.identifier.funcName =>
+              if identifier.mkString(".").equalsIgnoreCase(TumbleFunction.identifier.funcName) =>
             args.head
         }
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
index b645f2113..2fc5df96f 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala
@@ -5,16 +5,22 @@
 
 package org.opensearch.flint.spark
 
+import java.sql.Timestamp
+
 import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson
 import org.opensearch.flint.core.FlintVersion.current
+import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL}
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName
 import org.scalatest.matchers.must.Matchers.defined
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
+import org.apache.spark.sql.Row
+
 class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
 
-  private val testTable = "spark_catalog.default.ci_test"
-  private val testMvName = "spark_catalog.default.mv_test"
+  /** Test table, MV, index name and query */
+  private val testTable = "spark_catalog.default.mv_test"
+  private val testMvName = "spark_catalog.default.mv_test_metrics"
   private val testFlintIndex = getFlintIndexName(testMvName)
   private val testQuery =
     s"""
       |
@@ -27,18 +33,15 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-
     createTimeSeriesTable(testTable)
   }
 
   override def afterEach(): Unit = {
     super.afterEach()
-
-    // Delete all test indices
     flint.deleteIndex(testFlintIndex)
   }
 
-  test("create materialized view") {
+  test("create materialized view with metadata successfully") {
     flint
       .materializedView()
       .name(testMvName)
@@ -51,7 +54,7 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
        | {
       |   "_meta": {
       |     "version": "${current()}",
-      |     "name": "spark_catalog.default.mv_test",
+      |     "name": "spark_catalog.default.mv_test_metrics",
       |     "kind": "mv",
       |     "source": "$testQuery",
       |     "indexedColumns": [
@@ -76,6 +79,57 @@ class FlintSparkMaterializedViewITSuite extends FlintSparkSuite {
       |    }
       |  }
       |""".stripMargin)
+  }
+
+  ignore("full refresh materialized view") {
+    flint
+      .materializedView()
+      .name(testMvName)
+      .query(testQuery)
+      .create()
+
+    flint.refreshIndex(testFlintIndex, FULL)
+
+    val indexData = flint.queryIndex(testFlintIndex)
+    checkAnswer(
+      indexData,
+      Seq(
+        Row(timestamp("2023-10-01 00:00:00"), 1),
+        Row(timestamp("2023-10-01 00:10:00"), 2),
+        Row(timestamp("2023-10-01 01:00:00"), 1),
+        Row(timestamp("2023-10-01 02:00:00"), 1)))
+  }
+
+  test("incremental refresh materialized view") {
+    withTempDir { checkpointDir =>
+      val checkpointOption =
+        FlintSparkIndexOptions(Map("checkpoint_location" -> checkpointDir.getAbsolutePath))
+      flint
+        .materializedView()
+        .name(testMvName)
+        .query(testQuery)
+        .options(checkpointOption)
+        .create()
+
+      flint
+        .refreshIndex(testFlintIndex, INCREMENTAL)
+        .map(awaitStreamingComplete)
+        .orElse(throw new RuntimeException)
+
+      val indexData = flint.queryIndex(testFlintIndex).select("startTime", "count")
+      checkAnswer(
+        indexData,
+        Seq(
+          Row(timestamp("2023-10-01 00:00:00"), 1),
+          Row(timestamp("2023-10-01 00:10:00"), 2),
+          Row(timestamp("2023-10-01 01:00:00"), 1)
+          /*
+           * The last row is pending to fire upon watermark
+           *   Row(timestamp("2023-10-01 02:00:00"), 1)
+           */
+        ))
+    }
   }
+
+  private def timestamp(ts: String): Timestamp = Timestamp.valueOf(ts)
 }
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
index 47230374d..2b93ca12a 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
@@ -28,6 +28,13 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
     setFlintSparkConf(REFRESH_POLICY, "true")
   }
 
+  protected def awaitStreamingComplete(jobId: String): Unit = {
+    val job = spark.streams.get(jobId)
+    failAfter(streamingTimeout) {
+      job.processAllAvailable()
+    }
+  }
+
   protected def createPartitionedTable(testTable: String): Unit = {
     sql(s"""
         | CREATE TABLE $testTable