From fd64e6c772c2f179ce21e9371cf40103ef116907 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Thu, 12 Oct 2023 16:50:21 -0700
Subject: [PATCH] Refactor build API with optional StreamingRefresh interface

Signed-off-by: Chen Dai
---
 .../opensearch/flint/spark/FlintSpark.scala   | 51 ++++++++++++++-----
 .../flint/spark/FlintSparkIndex.scala         | 18 +++++--
 .../covering/FlintSparkCoveringIndex.scala    | 24 ++-------
 .../spark/mv/FlintSparkMaterializedView.scala | 20 ++++----
 .../skipping/FlintSparkSkippingIndex.scala    | 30 ++---------
 .../FlintSparkSkippingIndexSuite.scala        |  2 +-
 6 files changed, 71 insertions(+), 74 deletions(-)

diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
index a5c99752f..844b9f3b1 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala
@@ -12,7 +12,7 @@ import org.json4s.native.Serialization
 import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
 import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
-import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
+import org.opensearch.flint.spark.FlintSparkIndex.{ID_COLUMN, StreamingRefresh}
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
@@ -117,6 +117,7 @@ class FlintSpark(val spark: SparkSession) {
     val tableName = index.metadata().source
 
     // Write Flint index data to Flint data source (shared by both refresh modes for now)
+    /*
     def writeFlintIndex(df: DataFrame): Unit = {
       index
         .build(df)
         .write
         .format(FLINT_DATASOURCE)
         .options(flintSparkConf.properties)
         .mode(Overwrite)
         .save(indexName)
     }
+    */
+
+    def batchRefresh(df: Option[DataFrame] = None): Unit = {
+      index
+        .build(spark, df)
+        .write
+        .format(FLINT_DATASOURCE)
+        .options(flintSparkConf.properties)
+        .mode(Overwrite)
+        .save(indexName)
+    }
 
     mode match {
       case FULL if isIncrementalRefreshing(indexName) =>
         throw new IllegalStateException(
           s"Index $indexName is incremental refreshing and cannot be manual refreshed")
+
       case FULL =>
-        writeFlintIndex(
-          spark.read
-            .table(tableName))
+        batchRefresh()
         None
-      case INCREMENTAL =>
-        // TODO: Use Foreach sink for now. Need to move this to FlintSparkSkippingIndex
-        //  once finalized. Otherwise, covering index/MV may have different logic.
+      case INCREMENTAL if index.isInstanceOf[StreamingRefresh] =>
         val job =
-          index.buildStream(spark)
+          index
+            .asInstanceOf[StreamingRefresh]
+            .build(spark)
+            .writeStream
             .queryName(indexName)
             .outputMode(Append())
             .format(FLINT_DATASOURCE)
@@ -154,17 +166,30 @@ class FlintSpark(val spark: SparkSession) {
           .refreshInterval()
           .foreach(interval => job.trigger(Trigger.ProcessingTime(interval)))
 
-        /*
+        val jobId = job.start(indexName).id
+        Some(jobId.toString)
+
+      case INCREMENTAL =>
+        val job = spark.readStream
+          .table(tableName)
+          .writeStream
+          .queryName(indexName)
+          .outputMode(Append())
+
+        index.options
+          .checkpointLocation()
+          .foreach(location => job.option("checkpointLocation", location))
+        index.options
+          .refreshInterval()
+          .foreach(interval => job.trigger(Trigger.ProcessingTime(interval)))
+
         val jobId = job
           .foreachBatch { (batchDF: DataFrame, _: Long) =>
-            writeFlintIndex(batchDF)
+            batchRefresh(Some(batchDF))
           }
           .start()
           .id
-        */
-
-        val jobId = job.start(indexName).id
         Some(jobId.toString)
     }
   }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
index 4fa2bcb25..eaf25021b 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala
@@ -8,16 +8,16 @@ package org.opensearch.flint.spark
 import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.flint.core.metadata.FlintMetadata
+import org.opensearch.flint.spark.FlintSparkIndex.BatchRefresh
 
-import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.flint.datatype.FlintDataType
-import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types.StructType
 
 /**
  * Flint index interface in Spark.
  */
-trait FlintSparkIndex {
+trait FlintSparkIndex extends BatchRefresh {
 
   /**
    * Index type
@@ -50,15 +50,27 @@ trait FlintSparkIndex {
    * @return
    *   index building data frame
    */
+  /*
   def build(df: DataFrame): DataFrame
 
   def buildBatch(spark: SparkSession): DataFrameWriter[Row]
 
   def buildStream(spark: SparkSession): DataStreamWriter[Row]
+   */
 }
 
 object FlintSparkIndex {
 
+  trait BatchRefresh {
+
+    def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
+  }
+
+  trait StreamingRefresh {
+
+    def build(spark: SparkSession): DataFrame
+  }
+
   /**
    * ID column name.
    */
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index 541e6f82c..e9c2b5be5 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -14,8 +14,6 @@ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
-import org.apache.spark.sql.streaming.DataStreamWriter
 
 /**
  * Flint covering index in Spark.
@@ -56,26 +54,10 @@ case class FlintSparkCoveringIndex(
       .build()
   }
 
-  override def build(df: DataFrame): DataFrame = {
+  override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
     val colNames = indexedColumns.keys.toSeq
-    df.select(colNames.head, colNames.tail: _*)
-  }
-
-  override def buildBatch(spark: SparkSession): DataFrameWriter[Row] = {
-    build(spark.read.table(tableName)).write
-  }
-
-  override def buildStream(spark: SparkSession): DataStreamWriter[Row] = {
-    spark.readStream
-      .table(tableName)
-      .writeStream
-      .foreachBatch { (batch: DataFrame, _: Long) =>
-        build(batch).write
-          .format(FLINT_DATASOURCE)
-          // .options(flint.flintSparkConf.properties)
-          .mode(SaveMode.Overwrite)
-          .save(name())
-      }
+    df.getOrElse(spark.read.table(tableName))
+      .select(colNames.head, colNames.tail: _*)
   }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
index e8265daa6..a057e814d 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala
@@ -9,18 +9,17 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions}
-import org.opensearch.flint.spark.FlintSparkIndex.{generateSchemaJSON, metadataBuilder}
+import org.opensearch.flint.spark.FlintSparkIndex.{generateSchemaJSON, metadataBuilder, StreamingRefresh}
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.function.TumbleFunction
 import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE}
 
-import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, EventTimeWatermark}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.flint.logicalPlanToDataFrame
-import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.unsafe.types.UTF8String
 
 case class FlintSparkMaterializedView(
@@ -28,7 +27,8 @@ case class FlintSparkMaterializedView(
     query: String,
     outputSchema: Map[String, String],
     override val options: FlintSparkIndexOptions = empty)
-    extends FlintSparkIndex {
+    extends FlintSparkIndex
+    with StreamingRefresh {
 
   /** TODO: add it to index option */
   private val watermarkDelay = UTF8String.fromString("0 Minute")
@@ -52,15 +52,13 @@ case class FlintSparkMaterializedView(
       .build()
   }
 
-  override def build(df: DataFrame): DataFrame = {
-    null
-  }
+  override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
+    require(df.isEmpty, "materialized view doesn't support reading from other table")
 
-  override def buildBatch(spark: SparkSession): DataFrameWriter[Row] = {
-    spark.sql(query).write
+    spark.sql(query)
   }
 
-  override def buildStream(spark: SparkSession): DataStreamWriter[Row] = {
+  override def build(spark: SparkSession): DataFrame = {
     val batchPlan = spark.sql(query).queryExecution.logical
     val streamingPlan = batchPlan
       transform {
@@ -89,7 +87,7 @@ case class FlintSparkMaterializedView(
         UnresolvedRelation(multipartIdentifier, options, isStreaming = true)
     }
 
-    logicalPlanToDataFrame(spark, streamingPlan).writeStream
+    logicalPlanToDataFrame(spark, streamingPlan)
   }
 }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index 0c45e2155..4fe4dd1f6 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -9,7 +9,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark._
-import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, ID_COLUMN}
+import org.opensearch.flint.spark.FlintSparkIndex._
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
 import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
@@ -18,9 +18,7 @@ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression
-import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
 import org.apache.spark.sql.functions.{col, input_file_name, sha1}
-import org.apache.spark.sql.streaming.DataStreamWriter
 
 /**
  * Flint skipping index in Spark.
@@ -69,7 +67,7 @@ class FlintSparkSkippingIndex(
       .build()
   }
 
-  override def build(df: DataFrame): DataFrame = {
+  override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
     val outputNames = indexedColumns.flatMap(_.outputSchema().keys)
     val aggFuncs = indexedColumns.flatMap(_.getAggregators)
 
@@ -79,28 +77,10 @@ class FlintSparkSkippingIndex(
       new Column(aggFunc.toAggregateExpression().as(name))
     }
 
-    df.groupBy(input_file_name().as(FILE_PATH_COLUMN))
+    df.getOrElse(spark.read.table(tableName))
+      .groupBy(input_file_name().as(FILE_PATH_COLUMN))
       .agg(namedAggFuncs.head, namedAggFuncs.tail: _*)
-      .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN)))
-  }
-
-  override def buildBatch(spark: SparkSession): DataFrameWriter[Row] = {
-    build(spark.read.table(tableName)).write
-  }
-
-  override def buildStream(spark: SparkSession): DataStreamWriter[Row] = {
-    spark.readStream
-      .table(tableName)
-      .writeStream
-      .foreachBatch { (batch: DataFrame, _: Long) =>
-        build(batch)
-          .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN)))
-          .write
-          .format(FLINT_DATASOURCE)
-          // .options(flint.flintSparkConf.properties)
-          .mode(SaveMode.Overwrite)
-          .save(name())
-      }
+      .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN))) // TODO: no impact to just add it?
   }
 }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
index a3961bb51..d52c43842 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
@@ -56,7 +56,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
     val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol))
     val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
 
-    val indexDf = index.build(df)
+    val indexDf = index.build(spark, Some(df))
     indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
   }
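
Usage sketch: after this refactoring, every index gets the batch-oriented build(spark, df) from BatchRefresh (pulled in by FlintSparkIndex), and an index opts into source-level streaming refresh by also mixing in StreamingRefresh, as FlintSparkMaterializedView does. Only the two trait signatures from FlintSparkIndex.scala in this patch are assumed below; the ExampleIndex class, its source table, and its column list are hypothetical.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.opensearch.flint.spark.FlintSparkIndex.{BatchRefresh, StreamingRefresh}

// Hypothetical index over a hypothetical source table, shown only to illustrate
// the two build paths introduced by this patch. A real index would extend
// FlintSparkIndex and implement its other members as well.
class ExampleIndex(tableName: String) extends BatchRefresh with StreamingRefresh {

  // Batch path: reuse the micro-batch DataFrame when one is passed in
  // (the foreachBatch fallback in FlintSpark.refreshIndex), otherwise scan
  // the source table directly (FULL refresh).
  override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame =
    df.getOrElse(spark.read.table(tableName))
      .select("time", "status")

  // Streaming path: return a streaming DataFrame; FlintSpark attaches
  // writeStream, the Flint sink, checkpoint location and trigger options.
  override def build(spark: SparkSession): DataFrame =
    spark.readStream
      .table(tableName)
      .select("time", "status")
}

With an index like this, refreshIndex takes the streaming branch only when index.isInstanceOf[StreamingRefresh]; any other index refreshed in INCREMENTAL mode falls back to the foreachBatch branch, which feeds each micro-batch into batchRefresh(Some(batchDF)).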