diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala index b848f47b4..0bac6ac73 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/package.scala @@ -5,6 +5,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog._ /** @@ -12,6 +13,20 @@ import org.apache.spark.sql.connector.catalog._ */ package object flint { + /** + * Convert the given logical plan to Spark data frame. + * + * @param spark + * Spark session + * @param logicalPlan + * logical plan + * @return + * data frame + */ + def logicalPlanToDataFrame(spark: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, logicalPlan) + } + /** * Qualify a given table name. * diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala index ee4775b6a..9c78a07f8 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala @@ -10,26 +10,19 @@ import scala.collection.JavaConverters._ import org.json4s.{Formats, NoTypeHints} import org.json4s.native.Serialization import org.opensearch.flint.core.{FlintClient, FlintClientBuilder} -import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode} -import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN +import org.opensearch.flint.spark.FlintSparkIndex.{ID_COLUMN, StreamingRefresh} import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex -import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex -import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE -import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer} -import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET} -import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy -import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy -import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.SaveMode._ import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE import org.apache.spark.sql.flint.config.FlintSparkConf import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN} -import org.apache.spark.sql.streaming.OutputMode.Append -import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger} /** * Flint Spark integration API entrypoint. 
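// A minimal sketch of the new logicalPlanToDataFrame helper in use. Dataset.ofRows is
// package-private to org.apache.spark.sql, which appears to be why the wrapper lives in the
// org.apache.spark.sql.flint package object; the query and method name below are illustrative.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.flint.logicalPlanToDataFrame

def planRoundTrip(spark: SparkSession): Unit = {
  // Grab an unresolved logical plan from any query, then turn it back into a DataFrame
  val plan: LogicalPlan = spark.sql("SELECT 1 AS one").queryExecution.logical
  val df = logicalPlanToDataFrame(spark, plan) // equivalent to Dataset.ofRows(spark, plan)
  df.show()
}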
@@ -69,6 +62,16 @@ class FlintSpark(val spark: SparkSession) { new FlintSparkCoveringIndex.Builder(this) } + /** + * Create materialized view builder for creating mv with fluent API. + * + * @return + * mv builder + */ + def materializedView(): FlintSparkMaterializedView.Builder = { + new FlintSparkMaterializedView.Builder(this) + } + /** * Create the given index with metadata. * @@ -102,12 +105,13 @@ class FlintSpark(val spark: SparkSession) { def refreshIndex(indexName: String, mode: RefreshMode): Option[String] = { val index = describeIndex(indexName) .getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist")) + val options = index.options val tableName = index.metadata().source - // Write Flint index data to Flint data source (shared by both refresh modes for now) - def writeFlintIndex(df: DataFrame): Unit = { + // Batch refresh Flint index from the given source data frame + def batchRefresh(df: Option[DataFrame] = None): Unit = { index - .build(df) + .build(spark, df) .write .format(FLINT_DATASOURCE) .options(flintSparkConf.properties) @@ -119,36 +123,37 @@ class FlintSpark(val spark: SparkSession) { case FULL if isIncrementalRefreshing(indexName) => throw new IllegalStateException( s"Index $indexName is incremental refreshing and cannot be manual refreshed") + case FULL => - writeFlintIndex( - spark.read - .table(tableName)) + batchRefresh() None + // Flint index has specialized logic and capability for incremental refresh + case INCREMENTAL if index.isInstanceOf[StreamingRefresh] => + val job = + index + .asInstanceOf[StreamingRefresh] + .buildStream(spark) + .writeStream + .queryName(indexName) + .format(FLINT_DATASOURCE) + .options(flintSparkConf.properties) + .addIndexOptions(options) + .start(indexName) + Some(job.id.toString) + + // Otherwise, fall back to foreachBatch + batch refresh case INCREMENTAL => - // TODO: Use Foreach sink for now. Need to move this to FlintSparkSkippingIndex - // once finalized. Otherwise, covering index/MV may have different logic. 
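// Hedged usage sketch of the new materializedView() builder together with refreshIndex().
// Assumes a SparkSession named `spark` is in scope; the table, MV name, query and options
// are made up for illustration.
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexOptions}
import org.opensearch.flint.spark.FlintSpark.RefreshMode

val flint = new FlintSpark(spark)

flint
  .materializedView()
  .name("spark_catalog.default.example_mv") // must be fully qualified: catalog.database.mv
  .query("SELECT window.start AS startTime, COUNT(*) AS count " +
    "FROM spark_catalog.default.events GROUP BY TUMBLE(time, '10 Minutes')")
  .options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))
  .create()

// INCREMENTAL starts a streaming job and returns its id; FULL runs a one-shot batch refresh.
val jobId: Option[String] =
  flint.refreshIndex("flint_spark_catalog_default_example_mv", RefreshMode.INCREMENTAL)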
val job = spark.readStream .table(tableName) .writeStream .queryName(indexName) - .outputMode(Append()) - - index.options - .checkpointLocation() - .foreach(location => job.option("checkpointLocation", location)) - index.options - .refreshInterval() - .foreach(interval => job.trigger(Trigger.ProcessingTime(interval))) - - val jobId = - job - .foreachBatch { (batchDF: DataFrame, _: Long) => - writeFlintIndex(batchDF) - } - .start() - .id - Some(jobId.toString) + .addIndexOptions(options) + .foreachBatch { (batchDF: DataFrame, _: Long) => + batchRefresh(Some(batchDF)) + } + .start() + Some(job.id.toString) } } @@ -161,7 +166,10 @@ class FlintSpark(val spark: SparkSession) { * Flint index list */ def describeIndexes(indexNamePattern: String): Seq[FlintSparkIndex] = { - flintClient.getAllIndexMetadata(indexNamePattern).asScala.map(deserialize) + flintClient + .getAllIndexMetadata(indexNamePattern) + .asScala + .map(FlintSparkIndexFactory.create) } /** @@ -175,7 +183,8 @@ class FlintSpark(val spark: SparkSession) { def describeIndex(indexName: String): Option[FlintSparkIndex] = { if (flintClient.exists(indexName)) { val metadata = flintClient.getIndexMetadata(indexName) - Some(deserialize(metadata)) + val index = FlintSparkIndexFactory.create(metadata) + Some(index) } else { Option.empty } @@ -221,42 +230,30 @@ class FlintSpark(val spark: SparkSession) { } } - private def deserialize(metadata: FlintMetadata): FlintSparkIndex = { - val indexOptions = FlintSparkIndexOptions( - metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) + // Using Scala implicit class to avoid breaking method chaining of Spark data frame fluent API + private implicit class FlintDataStreamWriter(val dataStream: DataStreamWriter[Row]) { - metadata.kind match { - case SKIPPING_INDEX_TYPE => - val strategies = metadata.indexedColumns.map { colInfo => - val skippingKind = SkippingKind.withName(getString(colInfo, "kind")) - val columnName = getString(colInfo, "columnName") - val columnType = getString(colInfo, "columnType") + def addIndexOptions(options: FlintSparkIndexOptions): DataStreamWriter[Row] = { + dataStream + .addCheckpointLocation(options.checkpointLocation()) + .addRefreshInterval(options.refreshInterval()) + } - skippingKind match { - case PARTITION => - PartitionSkippingStrategy(columnName = columnName, columnType = columnType) - case VALUE_SET => - ValueSetSkippingStrategy(columnName = columnName, columnType = columnType) - case MIN_MAX => - MinMaxSkippingStrategy(columnName = columnName, columnType = columnType) - case other => - throw new IllegalStateException(s"Unknown skipping strategy: $other") - } - } - new FlintSparkSkippingIndex(metadata.source, strategies, indexOptions) - case COVERING_INDEX_TYPE => - new FlintSparkCoveringIndex( - metadata.name, - metadata.source, - metadata.indexedColumns.map { colInfo => - getString(colInfo, "columnName") -> getString(colInfo, "columnType") - }.toMap, - indexOptions) + def addCheckpointLocation(checkpointLocation: Option[String]): DataStreamWriter[Row] = { + if (checkpointLocation.isDefined) { + dataStream.option("checkpointLocation", checkpointLocation.get) + } else { + dataStream + } } - } - private def getString(map: java.util.Map[String, AnyRef], key: String): String = { - map.get(key).asInstanceOf[String] + def addRefreshInterval(refreshInterval: Option[String]): DataStreamWriter[Row] = { + if (refreshInterval.isDefined) { + dataStream.trigger(Trigger.ProcessingTime(refreshInterval.get)) + } else { + dataStream + } + } } } diff --git 
a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala index a130821bd..0586bfc49 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndex.scala @@ -9,7 +9,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.core.metadata.FlintMetadata -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.flint.datatype.FlintDataType import org.apache.spark.sql.types.StructType @@ -44,16 +44,36 @@ trait FlintSparkIndex { * Build a data frame to represent index data computation logic. Upper level code decides how to * use this, ex. batch or streaming, fully or incremental refresh. * + * @param spark + * Spark session for implementation class to use as needed * @param df - * data frame to append building logic + * data frame to append building logic. If none, implementation class create source data frame + * on its own * @return * index building data frame */ - def build(df: DataFrame): DataFrame + def build(spark: SparkSession, df: Option[DataFrame]): DataFrame } object FlintSparkIndex { + /** + * Interface indicates a Flint index has custom streaming refresh capability other than foreach + * batch streaming. + */ + trait StreamingRefresh { + + /** + * Build streaming refresh data frame. + * + * @param spark + * Spark session + * @return + * data frame represents streaming logic + */ + def buildStream(spark: SparkSession): DataFrame + } + /** * ID column name. */ diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala new file mode 100644 index 000000000..cda11405c --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import scala.collection.JavaConverters.mapAsScalaMapConverter + +import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex +import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex +import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind +import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET} +import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy +import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy +import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy + +/** + * Flint Spark index factory that encapsulates specific Flint index instance creation. This is for + * internal code use instead of user facing API. + */ +object FlintSparkIndexFactory { + + /** + * Creates Flint index from generic Flint metadata. 
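// Sketch of an index type satisfying the revised contract; all names here are hypothetical.
// build() must work with a caller-supplied micro-batch frame *or* create its own source frame,
// and mixing in StreamingRefresh lets refreshIndex() call buildStream() instead of falling back
// to foreachBatch.
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.FlintSparkIndex.StreamingRefresh

trait ExampleRefreshable extends FlintSparkIndex with StreamingRefresh {
  def tableName: String

  override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame =
    df.getOrElse(spark.read.table(tableName)) // full-refresh path when no frame is passed in

  override def buildStream(spark: SparkSession): DataFrame =
    spark.readStream.table(tableName) // custom streaming plan, bypassing foreachBatch
}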
+ * + * @param metadata + * Flint metadata + * @return + * Flint index + */ + def create(metadata: FlintMetadata): FlintSparkIndex = { + val indexOptions = FlintSparkIndexOptions( + metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap) + + // Convert generic Map[String,AnyRef] in metadata to specific data structure in Flint index + metadata.kind match { + case SKIPPING_INDEX_TYPE => + val strategies = metadata.indexedColumns.map { colInfo => + val skippingKind = SkippingKind.withName(getString(colInfo, "kind")) + val columnName = getString(colInfo, "columnName") + val columnType = getString(colInfo, "columnType") + + skippingKind match { + case PARTITION => + PartitionSkippingStrategy(columnName = columnName, columnType = columnType) + case VALUE_SET => + ValueSetSkippingStrategy(columnName = columnName, columnType = columnType) + case MIN_MAX => + MinMaxSkippingStrategy(columnName = columnName, columnType = columnType) + case other => + throw new IllegalStateException(s"Unknown skipping strategy: $other") + } + } + FlintSparkSkippingIndex(metadata.source, strategies, indexOptions) + case COVERING_INDEX_TYPE => + FlintSparkCoveringIndex( + metadata.name, + metadata.source, + metadata.indexedColumns.map { colInfo => + getString(colInfo, "columnName") -> getString(colInfo, "columnType") + }.toMap, + indexOptions) + case MV_INDEX_TYPE => + FlintSparkMaterializedView( + metadata.name, + metadata.source, + metadata.indexedColumns.map { colInfo => + getString(colInfo, "columnName") -> getString(colInfo, "columnType") + }.toMap, + indexOptions) + } + } + + private def getString(map: java.util.Map[String, AnyRef], key: String): String = { + map.get(key).asInstanceOf[String] + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index b97c3fea3..e9c2b5be5 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -13,7 +13,7 @@ import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generat import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql._ /** * Flint covering index in Spark. 
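// Sketch of the factory round trip that describeIndex()/describeIndexes() now rely on.
// The inspect() helper, index name and pattern-match bodies are illustrative only.
import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexFactory}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

def inspect(flintClient: FlintClient, indexName: String): Unit = {
  // Generic metadata stored in OpenSearch comes back as a typed Flint index instance
  val index: FlintSparkIndex =
    FlintSparkIndexFactory.create(flintClient.getIndexMetadata(indexName))

  index match {
    case mv: FlintSparkMaterializedView => println(s"MV backed by query: ${mv.query}")
    case ci: FlintSparkCoveringIndex    => println(s"covering columns: ${ci.indexedColumns.keys}")
    case si: FlintSparkSkippingIndex    => println(s"skipping index on table: ${si.tableName}")
  }
}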
@@ -54,9 +54,10 @@ case class FlintSparkCoveringIndex( .build() } - override def build(df: DataFrame): DataFrame = { + override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { val colNames = indexedColumns.keys.toSeq - df.select(colNames.head, colNames.tail: _*) + df.getOrElse(spark.read.table(tableName)) + .select(colNames.head, colNames.tail: _*) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala new file mode 100644 index 000000000..ee58ec7f5 --- /dev/null +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedView.scala @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.mv + +import java.util.Locale + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.opensearch.flint.core.metadata.FlintMetadata +import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions} +import org.opensearch.flint.spark.FlintSparkIndex.{generateSchemaJSON, metadataBuilder, StreamingRefresh} +import org.opensearch.flint.spark.FlintSparkIndexOptions.empty +import org.opensearch.flint.spark.function.TumbleFunction +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE} + +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.flint.{logicalPlanToDataFrame, qualifyTableName} + +/** + * Flint materialized view in Spark. 
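// The two build() paths introduced above, shown for a covering index. The index name,
// source table, columns and micro-batch frame are placeholders for illustration.
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex

def buildPaths(spark: SparkSession, microBatch: DataFrame): Unit = {
  val ci = FlintSparkCoveringIndex(
    "name_age_idx",
    "spark_catalog.default.example",
    Map("name" -> "string", "age" -> "int"))

  val fullDf  = ci.build(spark, None)             // index reads its source table itself (full refresh)
  val batchDf = ci.build(spark, Some(microBatch)) // micro-batch frame handed in by foreachBatch
}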
+ * + * @param mvName + * MV name + * @param query + * source query that generates MV data + * @param outputSchema + * output schema + * @param options + * index options + */ +case class FlintSparkMaterializedView( + mvName: String, + query: String, + outputSchema: Map[String, String], + override val options: FlintSparkIndexOptions = empty) + extends FlintSparkIndex + with StreamingRefresh { + + /** TODO: add it to index option */ + private val watermarkDelay = "0 Minute" + + override val kind: String = MV_INDEX_TYPE + + override def name(): String = getFlintIndexName(mvName) + + override def metadata(): FlintMetadata = { + val indexColumnMaps = + outputSchema.map { case (colName, colType) => + Map[String, AnyRef]("columnName" -> colName, "columnType" -> colType).asJava + }.toArray + val schemaJson = generateSchemaJSON(outputSchema) + + metadataBuilder(this) + .name(mvName) + .source(query) + .indexedColumns(indexColumnMaps) + .schema(schemaJson) + .build() + } + + override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { + require(df.isEmpty, "materialized view doesn't support reading from other data frame") + + spark.sql(query) + } + + override def buildStream(spark: SparkSession): DataFrame = { + val batchPlan = spark.sql(query).queryExecution.logical + + /* + * Convert unresolved batch plan to streaming plan by: + * 1.Insert Watermark operator below Aggregate (required by Spark streaming) + * 2.Set isStreaming flag to true in Relation operator + */ + val streamingPlan = batchPlan transform { + case WindowingAggregate(agg, timeCol) => + agg.copy(child = watermark(timeCol, watermarkDelay, agg.child)) + + case relation: UnresolvedRelation if !relation.isStreaming => + relation.copy(isStreaming = true) + } + logicalPlanToDataFrame(spark, streamingPlan) + } + + private def watermark(timeCol: Attribute, delay: String, child: LogicalPlan) = { + EventTimeWatermark(timeCol, IntervalUtils.fromIntervalString(delay), child) + } + + /** + * Extractor that extract event time column out of Aggregate operator. + */ + private object WindowingAggregate { + + def unapply(agg: Aggregate): Option[(Aggregate, Attribute)] = { + val winFuncs = agg.groupingExpressions.collect { + case func: UnresolvedFunction if isWindowingFunction(func) => + func + } + + if (winFuncs.size != 1) { + throw new IllegalStateException( + "A windowing function is required for streaming aggregation") + } + + // Assume first aggregate item must be time column + val winFunc = winFuncs.head + val timeCol = winFunc.arguments.head.asInstanceOf[Attribute] + Some(agg, timeCol) + } + + private def isWindowingFunction(func: UnresolvedFunction): Boolean = { + val funcName = func.nameParts.mkString(".").toLowerCase(Locale.ROOT) + val funcIdent = FunctionIdentifier(funcName) + + // TODO: support other window functions + funcIdent == TumbleFunction.identifier + } + } +} + +object FlintSparkMaterializedView { + + /** MV index type name */ + val MV_INDEX_TYPE = "mv" + + /** + * Get index name following the convention "flint_" + qualified MV name (replace dot with + * underscore). 
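// Roughly what buildStream() yields for a TUMBLE query, written with the DataFrame API for
// comparison. Table and column names and the 1-minute window are illustrative; the watermark
// delay mirrors the hard-coded "0 Minute" above. Note the extractor requires exactly one
// windowing function in GROUP BY, with the event-time column as its first argument.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, window}

def streamingEquivalent(spark: SparkSession) =
  spark.readStream
    .table("spark_catalog.default.events")
    .withWatermark("time", "0 minutes")                 // EventTimeWatermark inserted below Aggregate
    .groupBy(window(col("time"), "1 minute"))           // TUMBLE(time, '1 Minute')
    .agg(count("*").as("count"))
    .select(col("window.start").as("startTime"), col("count"))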
+ * + * @param mvName + * MV name + * @return + * Flint index name + */ + def getFlintIndexName(mvName: String): String = { + require( + mvName.split("\\.").length >= 3, + "Qualified materialized view name catalog.database.mv is required") + + s"flint_${mvName.replace(".", "_")}" + } + + /** Builder class for MV build */ + class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { + private var mvName: String = "" + private var query: String = "" + + /** + * Set MV name. + * + * @param mvName + * MV name + * @return + * builder + */ + def name(mvName: String): Builder = { + this.mvName = qualifyTableName(flint.spark, mvName) + this + } + + /** + * Set MV query. + * + * @param query + * MV query + * @return + * builder + */ + def query(query: String): Builder = { + this.query = query + this + } + + override protected def buildIndex(): FlintSparkIndex = { + // TODO: change here and FlintDS class to support complex field type in future + val outputSchema = flint.spark + .sql(query) + .schema + .map { field => + field.name -> field.dataType.typeName + } + .toMap + FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions) + } + } +} diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala index ec213a3cd..eb2075b63 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala @@ -9,14 +9,14 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.core.metadata.FlintMetadata import org.opensearch.flint.spark._ -import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, ID_COLUMN} +import org.opensearch.flint.spark.FlintSparkIndex._ import org.opensearch.flint.spark.FlintSparkIndexOptions.empty import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE} import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression import org.apache.spark.sql.functions.{col, input_file_name, sha1} @@ -28,9 +28,9 @@ import org.apache.spark.sql.functions.{col, input_file_name, sha1} * @param indexedColumns * indexed column list */ -class FlintSparkSkippingIndex( +case class FlintSparkSkippingIndex( tableName: String, - val indexedColumns: Seq[FlintSparkSkippingStrategy], + indexedColumns: Seq[FlintSparkSkippingStrategy], override val options: FlintSparkIndexOptions = empty) extends FlintSparkIndex { @@ -67,7 +67,7 @@ class FlintSparkSkippingIndex( .build() } - override def build(df: DataFrame): DataFrame = { + override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { val outputNames = indexedColumns.flatMap(_.outputSchema().keys) val aggFuncs = indexedColumns.flatMap(_.getAggregators) @@ -77,7 +77,8 @@ class FlintSparkSkippingIndex( new Column(aggFunc.toAggregateExpression().as(name)) } - df.groupBy(input_file_name().as(FILE_PATH_COLUMN)) + df.getOrElse(spark.read.table(tableName)) 
+ .groupBy(input_file_name().as(FILE_PATH_COLUMN)) .agg(namedAggFuncs.head, namedAggFuncs.tail: _*) .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN))) } diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala new file mode 100644 index 000000000..c28495c69 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.mv + +import scala.collection.JavaConverters.mapAsJavaMapConverter + +import org.opensearch.flint.spark.FlintSparkIndexOptions +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} +import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} +import org.scalatestplus.mockito.MockitoSugar.mock + +import org.apache.spark.FlintSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.dsl.expressions.{count, intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String + +/** + * This UT include test cases for building API which make use of real SparkSession. This is + * because SparkSession.sessionState is private val and hard to mock but it's required in + * logicalPlanToDataFrame() -> DataRows.of(). 
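// Plain DataFrame equivalent of the skipping-index build above: one row per source file,
// keyed by a sha1 of the file path. The min/max aggregators stand in for whatever each
// FlintSparkSkippingStrategy actually contributes; table and column names are illustrative.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, input_file_name, max, min, sha1}
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COLUMN

def skippingEquivalent(spark: SparkSession) =
  spark.read
    .table("spark_catalog.default.example")
    .groupBy(input_file_name().as(FILE_PATH_COLUMN))
    .agg(min("age").as("age_min"), max("age").as("age_max")) // placeholder aggregators
    .withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN)))      // sha1(file path) used as doc _id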
+ */ +class FlintSparkMaterializedViewSuite extends FlintSuite { + + val testMvName = "spark_catalog.default.mv" + val testQuery = "SELECT 1" + + test("get name") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + mv.name() shouldBe "flint_spark_catalog_default_mv" + } + + test("should fail if get name with unqualified MV name") { + the[IllegalArgumentException] thrownBy + FlintSparkMaterializedView("mv", testQuery, Map.empty).name() + + the[IllegalArgumentException] thrownBy + FlintSparkMaterializedView("default.mv", testQuery, Map.empty).name() + } + + test("get metadata") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map("test_col" -> "integer")) + + val metadata = mv.metadata() + metadata.name shouldBe mv.mvName + metadata.kind shouldBe MV_INDEX_TYPE + metadata.source shouldBe "SELECT 1" + metadata.indexedColumns shouldBe Array( + Map("columnName" -> "test_col", "columnType" -> "integer").asJava) + metadata.schema shouldBe Map("test_col" -> Map("type" -> "integer").asJava).asJava + } + + test("get metadata with index options") { + val indexSettings = """{"number_of_shards": 2}""" + val indexOptions = + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "index_settings" -> indexSettings)) + val mv = FlintSparkMaterializedView( + testMvName, + testQuery, + Map("test_col" -> "integer"), + indexOptions) + + mv.metadata().options shouldBe Map( + "auto_refresh" -> "true", + "index_settings" -> indexSettings).asJava + mv.metadata().indexSettings shouldBe Some(indexSettings) + } + + test("build batch data frame") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + mv.build(spark, None).collect() shouldBe Array(Row(1)) + } + + test("should fail if build given other source data frame") { + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + the[IllegalArgumentException] thrownBy mv.build(spark, Some(mock[DataFrame])) + } + + test("build stream should insert watermark operator and replace batch relation") { + val testTable = "mv_build_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val testQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val actualPlan = mv.buildStream(spark).queryExecution.logical + assert( + actualPlan.sameSemantics( + streamingRelation(testTable) + .watermark($"time", "0 Minute") + .groupBy($"TUMBLE".function($"time", "1 Minute"))( + $"window.start" as "startTime", + count(1) as "count"))) + } + } + + test("build stream with filtering query") { + val testTable = "mv_build_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val testQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | WHERE age > 30 + | GROUP BY TUMBLE(time, '1 Minute') + |""".stripMargin + + val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) + val actualPlan = mv.buildStream(spark).queryExecution.logical + assert( + actualPlan.sameSemantics( + streamingRelation(testTable) + .where($"age" > 30) + .watermark($"time", "0 Minute") + .groupBy($"TUMBLE".function($"time", "1 Minute"))( + $"window.start" as "startTime", + count(1) as "count"))) + } + } + + test("build stream with non-aggregate query") { + val testTable = "mv_build_test" + 
withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, age FROM $testTable WHERE age > 30", + Map.empty) + val actualPlan = mv.buildStream(spark).queryExecution.logical + + assert( + actualPlan.sameSemantics( + streamingRelation(testTable) + .where($"age" > 30) + .select($"name", $"age"))) + } + } + + test("build stream should fail if there is aggregation but no windowing function") { + val testTable = "mv_build_test" + withTable(testTable) { + sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") + + val mv = FlintSparkMaterializedView( + testMvName, + s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + Map.empty) + + the[IllegalStateException] thrownBy + mv.buildStream(spark) + } + } +} + +/** + * Helper method that extends LogicalPlan with more methods by Scala implicit class. + */ +object FlintSparkMaterializedViewSuite { + + def streamingRelation(tableName: String): UnresolvedRelation = { + UnresolvedRelation( + TableIdentifier(tableName), + CaseInsensitiveStringMap.empty(), + isStreaming = true) + } + + implicit class StreamingDslLogicalPlan(val logicalPlan: LogicalPlan) { + + def watermark(colName: Attribute, interval: String): DslLogicalPlan = { + EventTimeWatermark( + colName, + IntervalUtils.stringToInterval(UTF8String.fromString(interval)), + logicalPlan) + } + } +} diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala index a3961bb51..d52c43842 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala @@ -56,7 +56,7 @@ class FlintSparkSkippingIndexSuite extends FlintSuite { val index = new FlintSparkSkippingIndex(testTable, Seq(indexCol)) val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age") - val indexDf = index.build(df) + val indexDf = index.build(spark, Some(df)) indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN) } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala new file mode 100644 index 000000000..29ab433c6 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkMaterializedViewITSuite.scala @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import java.sql.Timestamp + +import com.stephenn.scalatest.jsonassert.JsonMatchers.matchJson +import org.opensearch.flint.core.FlintVersion.current +import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL} +import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.getFlintIndexName +import org.scalatest.matchers.must.Matchers.defined +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper + +import org.apache.spark.sql.{DataFrame, Row} + +class FlintSparkMaterializedViewITSuite extends FlintSparkSuite { + + /** Test table, MV, index name and query */ + private val testTable = "spark_catalog.default.mv_test" + private val testMvName = 
"spark_catalog.default.mv_test_metrics" + private val testFlintIndex = getFlintIndexName(testMvName) + private val testQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '10 Minutes') + |""".stripMargin + + override def beforeAll(): Unit = { + super.beforeAll() + createTimeSeriesTable(testTable) + } + + override def afterEach(): Unit = { + super.afterEach() + flint.deleteIndex(testFlintIndex) + } + + test("create materialized view with metadata successfully") { + val indexOptions = + FlintSparkIndexOptions(Map("auto_refresh" -> "true", "checkpoint_location" -> "s3://test/")) + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .options(indexOptions) + .create() + + val index = flint.describeIndex(testFlintIndex) + index shouldBe defined + index.get.metadata().getContent should matchJson(s""" + | { + | "_meta": { + | "version": "${current()}", + | "name": "spark_catalog.default.mv_test_metrics", + | "kind": "mv", + | "source": "$testQuery", + | "indexedColumns": [ + | { + | "columnName": "startTime", + | "columnType": "timestamp" + | },{ + | "columnName": "count", + | "columnType": "long" + | }], + | "options": { + | "auto_refresh": "true", + | "checkpoint_location": "s3://test/" + | }, + | "properties": {} + | }, + | "properties": { + | "startTime": { + | "type": "date", + | "format": "strict_date_optional_time_nanos" + | }, + | "count": { + | "type": "long" + | } + | } + | } + |""".stripMargin) + } + + // TODO: fix this windowing function unable to be used in GROUP BY + ignore("full refresh materialized view") { + flint + .materializedView() + .name(testMvName) + .query(testQuery) + .create() + + flint.refreshIndex(testFlintIndex, FULL) + + val indexData = flint.queryIndex(testFlintIndex) + checkAnswer( + indexData, + Seq( + Row(timestamp("2023-10-01 00:00:00"), 1), + Row(timestamp("2023-10-01 00:10:00"), 2), + Row(timestamp("2023-10-01 01:00:00"), 1), + Row(timestamp("2023-10-01 02:00:00"), 1))) + } + + test("incremental refresh materialized view") { + withIncrementalMaterializedView(testQuery) { indexData => + checkAnswer( + indexData.select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 00:00:00"), 1), + Row(timestamp("2023-10-01 00:10:00"), 2), + Row(timestamp("2023-10-01 01:00:00"), 1) + /* + * The last row is pending to fire upon watermark + * Row(timestamp("2023-10-01 02:00:00"), 1) + */ + )) + } + } + + test("incremental refresh materialized view with larger window") { + val largeWindowQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | GROUP BY TUMBLE(time, '1 Hour') + |""".stripMargin + + withIncrementalMaterializedView(largeWindowQuery) { indexData => + checkAnswer( + indexData.select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 00:00:00"), 3), + Row(timestamp("2023-10-01 01:00:00"), 1) + /* + * The last row is pending to fire upon watermark + * Row(timestamp("2023-10-01 02:00:00"), 1) + */ + )) + } + } + + test("incremental refresh materialized view with filtering query") { + val filterQuery = + s""" + | SELECT + | window.start AS startTime, + | COUNT(*) AS count + | FROM $testTable + | WHERE address = 'Seattle' + | GROUP BY TUMBLE(time, '10 Minutes') + |""".stripMargin + + withIncrementalMaterializedView(filterQuery) { indexData => + checkAnswer( + indexData.select("startTime", "count"), + Seq( + Row(timestamp("2023-10-01 00:00:00"), 1) + /* + * The last row is pending to fire upon watermark + * 
Row(timestamp("2023-10-01 00:10:00"), 1) + */ + )) + } + } + + test("incremental refresh materialized view with non-aggregate query") { + val nonAggQuery = + s""" + | SELECT name, age + | FROM $testTable + | WHERE age <= 30 + |""".stripMargin + + withIncrementalMaterializedView(nonAggQuery) { indexData => + checkAnswer(indexData.select("name", "age"), Seq(Row("A", 30), Row("B", 20), Row("E", 15))) + } + } + + private def timestamp(ts: String): Timestamp = Timestamp.valueOf(ts) + + private def withIncrementalMaterializedView(query: String)( + codeBlock: DataFrame => Unit): Unit = { + withTempDir { checkpointDir => + val indexOptions = FlintSparkIndexOptions( + Map("auto_refresh" -> "true", "checkpoint_location" -> checkpointDir.getAbsolutePath)) + + flint + .materializedView() + .name(testMvName) + .query(query) + .options(indexOptions) + .create() + + flint + .refreshIndex(testFlintIndex, INCREMENTAL) + .map(awaitStreamingComplete) + .orElse(throw new RuntimeException) + + val indexData = flint.queryIndex(testFlintIndex) + + // Execute the code block + codeBlock(indexData) + } + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index edbf5935a..2b93ca12a 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -15,11 +15,7 @@ import org.apache.spark.sql.streaming.StreamTest /** * Flint Spark suite trait that initializes [[FlintSpark]] API instance. */ -trait FlintSparkSuite - extends QueryTest - with FlintSuite - with OpenSearchSuite - with StreamTest { +trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite with StreamTest { /** Flint Spark high level API being tested */ lazy protected val flint: FlintSpark = new FlintSpark(spark) @@ -32,9 +28,15 @@ trait FlintSparkSuite setFlintSparkConf(REFRESH_POLICY, "true") } + protected def awaitStreamingComplete(jobId: String): Unit = { + val job = spark.streams.get(jobId) + failAfter(streamingTimeout) { + job.processAllAvailable() + } + } + protected def createPartitionedTable(testTable: String): Unit = { - sql( - s""" + sql(s""" | CREATE TABLE $testTable | ( | name STRING, @@ -52,18 +54,39 @@ trait FlintSparkSuite | ) |""".stripMargin) - sql( - s""" + sql(s""" | INSERT INTO $testTable | PARTITION (year=2023, month=4) | VALUES ('Hello', 30, 'Seattle') | """.stripMargin) - sql( - s""" + sql(s""" | INSERT INTO $testTable | PARTITION (year=2023, month=5) | VALUES ('World', 25, 'Portland') | """.stripMargin) } + + protected def createTimeSeriesTable(testTable: String): Unit = { + sql(s""" + | CREATE TABLE $testTable + | ( + | time TIMESTAMP, + | name STRING, + | age INT, + | address STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + |""".stripMargin) + + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:01:00', 'A', 30, 'Seattle')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:10:00', 'B', 20, 'Seattle')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 00:15:00', 'C', 35, 'Portland')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 01:00:00', 'D', 40, 'Portland')") + sql(s"INSERT INTO $testTable VALUES (TIMESTAMP '2023-10-01 03:00:00', 'E', 15, 'Vancouver')") + } }