Add materialized view in Flint Spark API (opensearch-project#71)
* Implement MV metadata on refactored Flint metadata

Signed-off-by: Chen Dai <[email protected]>

* Split build API and add IT for MV

Signed-off-by: Chen Dai <[email protected]>

* Add IT for incremental refresh

Signed-off-by: Chen Dai <[email protected]>

* Refactor build API with optional StreamingRefresh interface

Signed-off-by: Chen Dai <[email protected]>

* Add javadoc and remove useless BatchRefresh interface

Signed-off-by: Chen Dai <[email protected]>

* Fluent data frame API chain

Signed-off-by: Chen Dai <[email protected]>

* Add more javadoc

Signed-off-by: Chen Dai <[email protected]>

* Add UT for build function

Signed-off-by: Chen Dai <[email protected]>

* Add UT for build stream function

Signed-off-by: Chen Dai <[email protected]>

* More readability by implicit class

Signed-off-by: Chen Dai <[email protected]>

* Add more IT

Signed-off-by: Chen Dai <[email protected]>

* Refactor MV build stream

Signed-off-by: Chen Dai <[email protected]>

* Add more javadoc and comment

Signed-off-by: Chen Dai <[email protected]>

* Move remaining deserialize logic to new Factory class

Signed-off-by: Chen Dai <[email protected]>

* Add implicit class for options

Signed-off-by: Chen Dai <[email protected]>

* Qualify MV name

Signed-off-by: Chen Dai <[email protected]>

* Fix qualified mv name check

Signed-off-by: Chen Dai <[email protected]>

---------

Signed-off-by: Chen Dai <[email protected]>
dai-chen authored Oct 18, 2023
1 parent 14b4033 commit 3fcf926
Showing 11 changed files with 841 additions and 95 deletions.
@@ -5,13 +5,28 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog._

/**
* Flint utility methods that rely on access to private code in Spark SQL package.
*/
package object flint {

/**
* Convert the given logical plan to Spark data frame.
*
* @param spark
* Spark session
* @param logicalPlan
* logical plan
* @return
* data frame
*/
def logicalPlanToDataFrame(spark: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
Dataset.ofRows(spark, logicalPlan)
}
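
A brief usage sketch (hypothetical caller; the SQL string, table name, and `spark` session are illustrative assumptions): Dataset.ofRows is package-private to Spark SQL, which is why this helper lives under org.apache.spark.sql.

// Illustrative only: parse a defining query and materialize it as a DataFrame.
import org.apache.spark.sql.flint.logicalPlanToDataFrame

val query = "SELECT clientip, status FROM alb_logs WHERE status = 404"  // hypothetical query
val plan = spark.sessionState.sqlParser.parsePlan(query)                // assumes a SparkSession named `spark`
val df = logicalPlanToDataFrame(spark, plan)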

/**
* Qualify a given table name.
*
@@ -10,26 +10,19 @@ import scala.collection.JavaConverters._
import org.json4s.{Formats, NoTypeHints}
import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.FlintSparkIndex.{ID_COLUMN, StreamingRefresh}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.SaveMode._
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN}
import org.apache.spark.sql.streaming.OutputMode.Append
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger}

/**
* Flint Spark integration API entrypoint.
@@ -69,6 +62,16 @@ class FlintSpark(val spark: SparkSession) {
new FlintSparkCoveringIndex.Builder(this)
}

/**
* Create a materialized view builder for creating an MV with the fluent API.
*
* @return
* materialized view builder
*/
def materializedView(): FlintSparkMaterializedView.Builder = {
new FlintSparkMaterializedView.Builder(this)
}
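
A hedged sketch of the fluent builder chain this method exposes; the builder method names (name, query, create) and all identifiers are assumptions not shown in this hunk.

// Illustrative only: define and create an MV through the fluent API.
val flint = new FlintSpark(spark)

flint
  .materializedView()
  .name("spark_catalog.default.alb_logs_metrics")   // hypothetical, fully qualified MV name
  .query("SELECT status, COUNT(*) AS cnt FROM alb_logs GROUP BY status")
  .create()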

/**
* Create the given index with metadata.
*
Expand Down Expand Up @@ -102,12 +105,13 @@ class FlintSpark(val spark: SparkSession) {
def refreshIndex(indexName: String, mode: RefreshMode): Option[String] = {
val index = describeIndex(indexName)
.getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist"))
val options = index.options
val tableName = index.metadata().source

// Write Flint index data to Flint data source (shared by both refresh modes for now)
def writeFlintIndex(df: DataFrame): Unit = {
// Batch refresh Flint index from the given source data frame
def batchRefresh(df: Option[DataFrame] = None): Unit = {
index
.build(df)
.build(spark, df)
.write
.format(FLINT_DATASOURCE)
.options(flintSparkConf.properties)
@@ -119,36 +123,37 @@
case FULL if isIncrementalRefreshing(indexName) =>
throw new IllegalStateException(
s"Index $indexName is incrementally refreshing and cannot be manually refreshed")

case FULL =>
writeFlintIndex(
spark.read
.table(tableName))
batchRefresh()
None

// Flint index has specialized logic and capability for incremental refresh
case INCREMENTAL if index.isInstanceOf[StreamingRefresh] =>
val job =
index
.asInstanceOf[StreamingRefresh]
.buildStream(spark)
.writeStream
.queryName(indexName)
.format(FLINT_DATASOURCE)
.options(flintSparkConf.properties)
.addIndexOptions(options)
.start(indexName)
Some(job.id.toString)

// Otherwise, fall back to foreachBatch + batch refresh
case INCREMENTAL =>
// TODO: Use Foreach sink for now. Need to move this to FlintSparkSkippingIndex
// once finalized. Otherwise, covering index/MV may have different logic.
val job = spark.readStream
.table(tableName)
.writeStream
.queryName(indexName)
.outputMode(Append())

index.options
.checkpointLocation()
.foreach(location => job.option("checkpointLocation", location))
index.options
.refreshInterval()
.foreach(interval => job.trigger(Trigger.ProcessingTime(interval)))

val jobId =
job
.foreachBatch { (batchDF: DataFrame, _: Long) =>
writeFlintIndex(batchDF)
}
.start()
.id
Some(jobId.toString)
.addIndexOptions(options)
.foreachBatch { (batchDF: DataFrame, _: Long) =>
batchRefresh(Some(batchDF))
}
.start()
Some(job.id.toString)
}
}
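
A short sketch of how the two refresh modes surface to a caller; the Flint index name is hypothetical, and only INCREMENTAL mode returns a streaming job id.

// Illustrative caller of the refreshIndex API above.
import org.opensearch.flint.spark.FlintSpark.RefreshMode._

val indexName = "flint_spark_catalog_default_alb_logs_metrics"  // hypothetical Flint index name

flint.refreshIndex(indexName, FULL)                      // one-shot batch refresh, returns None
val jobId = flint.refreshIndex(indexName, INCREMENTAL)   // starts a streaming job, returns Some(id)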

@@ -161,7 +166,10 @@
* Flint index list
*/
def describeIndexes(indexNamePattern: String): Seq[FlintSparkIndex] = {
flintClient.getAllIndexMetadata(indexNamePattern).asScala.map(deserialize)
flintClient
.getAllIndexMetadata(indexNamePattern)
.asScala
.map(FlintSparkIndexFactory.create)
}

/**
@@ -175,7 +183,8 @@
def describeIndex(indexName: String): Option[FlintSparkIndex] = {
if (flintClient.exists(indexName)) {
val metadata = flintClient.getIndexMetadata(indexName)
Some(deserialize(metadata))
val index = FlintSparkIndexFactory.create(metadata)
Some(index)
} else {
Option.empty
}
@@ -221,42 +230,30 @@
}
}

private def deserialize(metadata: FlintMetadata): FlintSparkIndex = {
val indexOptions = FlintSparkIndexOptions(
metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap)
// Using Scala implicit class to avoid breaking method chaining of Spark data frame fluent API
private implicit class FlintDataStreamWriter(val dataStream: DataStreamWriter[Row]) {

metadata.kind match {
case SKIPPING_INDEX_TYPE =>
val strategies = metadata.indexedColumns.map { colInfo =>
val skippingKind = SkippingKind.withName(getString(colInfo, "kind"))
val columnName = getString(colInfo, "columnName")
val columnType = getString(colInfo, "columnType")
def addIndexOptions(options: FlintSparkIndexOptions): DataStreamWriter[Row] = {
dataStream
.addCheckpointLocation(options.checkpointLocation())
.addRefreshInterval(options.refreshInterval())
}

skippingKind match {
case PARTITION =>
PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
case VALUE_SET =>
ValueSetSkippingStrategy(columnName = columnName, columnType = columnType)
case MIN_MAX =>
MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
case other =>
throw new IllegalStateException(s"Unknown skipping strategy: $other")
}
}
new FlintSparkSkippingIndex(metadata.source, strategies, indexOptions)
case COVERING_INDEX_TYPE =>
new FlintSparkCoveringIndex(
metadata.name,
metadata.source,
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
indexOptions)
def addCheckpointLocation(checkpointLocation: Option[String]): DataStreamWriter[Row] = {
if (checkpointLocation.isDefined) {
dataStream.option("checkpointLocation", checkpointLocation.get)
} else {
dataStream
}
}
}

private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
def addRefreshInterval(refreshInterval: Option[String]): DataStreamWriter[Row] = {
if (refreshInterval.isDefined) {
dataStream.trigger(Trigger.ProcessingTime(refreshInterval.get))
} else {
dataStream
}
}
}
}

@@ -9,7 +9,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.types.StructType

@@ -44,16 +44,36 @@ trait FlintSparkIndex {
* Build a data frame to represent the index data computation logic. Upper-level code decides how
* to use this, e.g. batch or streaming, full or incremental refresh.
*
* @param spark
* Spark session for implementation class to use as needed
* @param df
* data frame to append building logic
* data frame to append building logic to. If None, the implementation class creates the source
* data frame on its own
* @return
* index building data frame
*/
def build(df: DataFrame): DataFrame
def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
}

object FlintSparkIndex {

/**
* Interface indicating that a Flint index has its own streaming refresh capability, as opposed
* to the default foreachBatch streaming.
*/
trait StreamingRefresh {

/**
* Build streaming refresh data frame.
*
* @param spark
* Spark session
* @return
* data frame that represents the streaming logic
*/
def buildStream(spark: SparkSession): DataFrame
}

/**
* ID column name.
*/
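
A minimal sketch of an index type opting into StreamingRefresh (the class, its source table, and its logic are hypothetical; the materialized view added in this commit is the real implementation, and imports are omitted here).

// Illustrative only: concrete FlintSparkIndex members (metadata, options, ...)
// are left abstract in this sketch.
abstract class MyStreamingIndex(sourceTable: String)
    extends FlintSparkIndex
    with StreamingRefresh {

  // Batch path: used by FULL refresh and by the foreachBatch fallback.
  override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame =
    df.getOrElse(spark.read.table(sourceTable))

  // Streaming path: picked up by refreshIndex when the mode is INCREMENTAL.
  override def buildStream(spark: SparkSession): DataFrame =
    spark.readStream.table(sourceTable)
}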
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import scala.collection.JavaConverters.mapAsScalaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

/**
* Flint Spark index factory that encapsulates creation of specific Flint index instances. This
* is for internal use rather than a user-facing API.
*/
object FlintSparkIndexFactory {

/**
* Creates Flint index from generic Flint metadata.
*
* @param metadata
* Flint metadata
* @return
* Flint index
*/
def create(metadata: FlintMetadata): FlintSparkIndex = {
val indexOptions = FlintSparkIndexOptions(
metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap)

// Convert generic Map[String,AnyRef] in metadata to specific data structure in Flint index
metadata.kind match {
case SKIPPING_INDEX_TYPE =>
val strategies = metadata.indexedColumns.map { colInfo =>
val skippingKind = SkippingKind.withName(getString(colInfo, "kind"))
val columnName = getString(colInfo, "columnName")
val columnType = getString(colInfo, "columnType")

skippingKind match {
case PARTITION =>
PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
case VALUE_SET =>
ValueSetSkippingStrategy(columnName = columnName, columnType = columnType)
case MIN_MAX =>
MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
case other =>
throw new IllegalStateException(s"Unknown skipping strategy: $other")
}
}
FlintSparkSkippingIndex(metadata.source, strategies, indexOptions)
case COVERING_INDEX_TYPE =>
FlintSparkCoveringIndex(
metadata.name,
metadata.source,
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
indexOptions)
case MV_INDEX_TYPE =>
FlintSparkMaterializedView(
metadata.name,
metadata.source,
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
indexOptions)
}
}

private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
}
}
@@ -13,7 +13,7 @@ import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generat
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql._

/**
* Flint covering index in Spark.
@@ -54,9 +54,10 @@ case class FlintSparkCoveringIndex(
.build()
}

override def build(df: DataFrame): DataFrame = {
override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
val colNames = indexedColumns.keys.toSeq
df.select(colNames.head, colNames.tail: _*)
df.getOrElse(spark.read.table(tableName))
.select(colNames.head, colNames.tail: _*)
}
}
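
The Option[DataFrame] parameter lets the single build method serve both refresh paths; roughly, the call sites mirror batchRefresh above (names here are illustrative).

// Full refresh: no frame supplied, so build reads the source table itself.
index.build(spark, None)

// foreachBatch incremental refresh: the micro-batch frame is passed in.
index.build(spark, Some(batchDF))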
