Add materialized view in Flint Spark API #71

Merged · 18 commits · Oct 18, 2023
@@ -5,13 +5,28 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog._

/**
* Flint utility methods that rely on access to private code in the Spark SQL package.
*/
package object flint {

/**
* Convert the given logical plan to a Spark data frame.
*
* @param spark
* Spark session
* @param logicalPlan
* logical plan
* @return
* data frame
*/
def logicalPlanToDataFrame(spark: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
Dataset.ofRows(spark, logicalPlan)
}
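For orientation, a minimal usage sketch of this helper, under the assumption of a hypothetical table `alb_logs` registered in the session catalog; the helper simply wraps Dataset.ofRows so code outside org.apache.spark.sql can turn a LogicalPlan into a DataFrame:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.flint.logicalPlanToDataFrame

val spark = SparkSession.builder().master("local[*]").appName("flint-demo").getOrCreate()

// Parse SQL into a LogicalPlan, then wrap it as a DataFrame via Dataset.ofRows.
val plan = spark.sessionState.sqlParser.parsePlan(
  "SELECT elb, COUNT(*) AS request_count FROM alb_logs GROUP BY elb") // hypothetical query
val df: DataFrame = logicalPlanToDataFrame(spark, plan)
```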

/**
* Qualify a given table name.
*
@@ -10,26 +10,19 @@ import scala.collection.JavaConverters._
import org.json4s.{Formats, NoTypeHints}
import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.FlintSparkIndex.{ID_COLUMN, StreamingRefresh}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.SaveMode._
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN}
import org.apache.spark.sql.streaming.OutputMode.Append
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger}

/**
* Flint Spark integration API entrypoint.
@@ -69,6 +62,16 @@ class FlintSpark(val spark: SparkSession) {
new FlintSparkCoveringIndex.Builder(this)
}

/**
* Create a materialized view builder for creating a materialized view with the fluent API.
*
* @return
* mv builder
*/
def materializedView(): FlintSparkMaterializedView.Builder = {
new FlintSparkMaterializedView.Builder(this)
}
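A hedged sketch of how this fluent builder might be used. The builder method names (`name`, `query`, `options`, `create`) and all table, view, and option names below are assumptions for illustration, not taken from this diff:

```scala
import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexOptions}

val spark = SparkSession.builder().master("local[*]").appName("flint-mv-demo").getOrCreate()
val flint = new FlintSpark(spark)

flint
  .materializedView()
  .name("spark_catalog.default.alb_logs_metrics")                            // hypothetical MV name
  .query("SELECT elb, COUNT(*) AS request_count FROM alb_logs GROUP BY elb") // hypothetical defining query
  .options(FlintSparkIndexOptions(Map("auto_refresh" -> "true")))            // hypothetical option key
  .create()
```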

/**
* Create the given index with metadata.
*
@@ -102,12 +105,13 @@
def refreshIndex(indexName: String, mode: RefreshMode): Option[String] = {
val index = describeIndex(indexName)
.getOrElse(throw new IllegalStateException(s"Index $indexName doesn't exist"))
val options = index.options
val tableName = index.metadata().source

// Write Flint index data to Flint data source (shared by both refresh modes for now)
def writeFlintIndex(df: DataFrame): Unit = {
// Batch refresh Flint index from the given source data frame
def batchRefresh(df: Option[DataFrame] = None): Unit = {
index
.build(df)
.build(spark, df)
.write
.format(FLINT_DATASOURCE)
.options(flintSparkConf.properties)
@@ -119,36 +123,37 @@
case FULL if isIncrementalRefreshing(indexName) =>
throw new IllegalStateException(
s"Index $indexName is incrementally refreshing and cannot be manually refreshed")

case FULL =>
writeFlintIndex(
spark.read
.table(tableName))
batchRefresh()
None

// Flint index has specialized logic and capability for incremental refresh
case INCREMENTAL if index.isInstanceOf[StreamingRefresh] =>
val job =
index
.asInstanceOf[StreamingRefresh]
.buildStream(spark)
.writeStream
.queryName(indexName)
.format(FLINT_DATASOURCE)
.options(flintSparkConf.properties)
.addIndexOptions(options)
.start(indexName)
Some(job.id.toString)

// Otherwise, fall back to foreachBatch + batch refresh
case INCREMENTAL =>
// TODO: Use Foreach sink for now. Need to move this to FlintSparkSkippingIndex
// once finalized. Otherwise, covering index/MV may have different logic.
val job = spark.readStream
.table(tableName)
.writeStream
.queryName(indexName)
.outputMode(Append())

index.options
.checkpointLocation()
.foreach(location => job.option("checkpointLocation", location))
index.options
.refreshInterval()
.foreach(interval => job.trigger(Trigger.ProcessingTime(interval)))

val jobId =
job
.foreachBatch { (batchDF: DataFrame, _: Long) =>
writeFlintIndex(batchDF)
}
.start()
.id
Some(jobId.toString)
.addIndexOptions(options)
.foreachBatch { (batchDF: DataFrame, _: Long) =>
batchRefresh(Some(batchDF))
}
.start()
Some(job.id.toString)
}
}
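For reference, a self-contained sketch of the foreach-batch fallback branch above in plain Structured Streaming terms. The table, index, checkpoint, and data source names are assumptions; the real method resolves them from the index metadata and Flint configuration:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().master("local[*]").appName("flint-incremental-demo").getOrCreate()

val job = spark.readStream
  .table("alb_logs")                                        // hypothetical source table
  .writeStream
  .queryName("flint_spark_catalog_default_alb_logs_index")  // hypothetical Flint index name
  .outputMode("append")
  .option("checkpointLocation", "s3://my-bucket/checkpoints/alb_logs_index") // hypothetical location
  .trigger(Trigger.ProcessingTime("10 seconds"))            // hypothetical refresh interval
  .foreachBatch { (batchDF: DataFrame, _: Long) =>
    // Hand each micro-batch to the same batch write path that full refresh uses
    // (batchRefresh(Some(batchDF)) in the method above).
    batchDF.write
      .format("flint")                                      // assumed short name behind FLINT_DATASOURCE
      .mode("append")
      .save("flint_spark_catalog_default_alb_logs_index")
  }
  .start()

println(s"Streaming job id: ${job.id}")
```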

@@ -161,7 +166,10 @@ class FlintSpark(val spark: SparkSession) {
* Flint index list
*/
def describeIndexes(indexNamePattern: String): Seq[FlintSparkIndex] = {
flintClient.getAllIndexMetadata(indexNamePattern).asScala.map(deserialize)
flintClient
.getAllIndexMetadata(indexNamePattern)
.asScala
.map(FlintSparkIndexFactory.create)
}

/**
@@ -175,7 +183,8 @@
def describeIndex(indexName: String): Option[FlintSparkIndex] = {
if (flintClient.exists(indexName)) {
val metadata = flintClient.getIndexMetadata(indexName)
Some(deserialize(metadata))
val index = FlintSparkIndexFactory.create(metadata)
Some(index)
} else {
Option.empty
}
@@ -221,42 +230,30 @@ class FlintSpark(val spark: SparkSession) {
}
}

private def deserialize(metadata: FlintMetadata): FlintSparkIndex = {
val indexOptions = FlintSparkIndexOptions(
metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap)
// Use a Scala implicit class to avoid breaking the method chaining of the Spark data frame fluent API
private implicit class FlintDataStreamWriter(val dataStream: DataStreamWriter[Row]) {

metadata.kind match {
case SKIPPING_INDEX_TYPE =>
val strategies = metadata.indexedColumns.map { colInfo =>
val skippingKind = SkippingKind.withName(getString(colInfo, "kind"))
val columnName = getString(colInfo, "columnName")
val columnType = getString(colInfo, "columnType")
def addIndexOptions(options: FlintSparkIndexOptions): DataStreamWriter[Row] = {
dataStream
.addCheckpointLocation(options.checkpointLocation())
.addRefreshInterval(options.refreshInterval())
}

skippingKind match {
case PARTITION =>
PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
case VALUE_SET =>
ValueSetSkippingStrategy(columnName = columnName, columnType = columnType)
case MIN_MAX =>
MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
case other =>
throw new IllegalStateException(s"Unknown skipping strategy: $other")
}
}
new FlintSparkSkippingIndex(metadata.source, strategies, indexOptions)
case COVERING_INDEX_TYPE =>
new FlintSparkCoveringIndex(
metadata.name,
metadata.source,
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
indexOptions)
def addCheckpointLocation(checkpointLocation: Option[String]): DataStreamWriter[Row] = {
if (checkpointLocation.isDefined) {
dataStream.option("checkpointLocation", checkpointLocation.get)
} else {
dataStream
}
}
}

private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
def addRefreshInterval(refreshInterval: Option[String]): DataStreamWriter[Row] = {
if (refreshInterval.isDefined) {
dataStream.trigger(Trigger.ProcessingTime(refreshInterval.get))
} else {
dataStream
}
}
}
}
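The comment above names the design choice: an implicit class keeps optional settings inside a single fluent chain instead of splitting it with if/else blocks. A generic sketch of the same pattern, with illustrative names only:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.{DataStreamWriter, Trigger}

object StreamWriterExtensions {

  // Conditional setters so a writeStream chain never has to break out of the fluent style.
  implicit class ConditionalSettings(writer: DataStreamWriter[Row]) {

    def optionIfDefined(key: String, value: Option[String]): DataStreamWriter[Row] =
      value.map(v => writer.option(key, v)).getOrElse(writer)

    def triggerIfDefined(interval: Option[String]): DataStreamWriter[Row] =
      interval.map(i => writer.trigger(Trigger.ProcessingTime(i))).getOrElse(writer)
  }
}
```

With such extensions in scope, a chain can read `writer.optionIfDefined("checkpointLocation", checkpointOpt).triggerIfDefined(intervalOpt)` in one unbroken expression, which is what addIndexOptions achieves above.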

@@ -9,7 +9,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.types.StructType

@@ -44,16 +44,36 @@ trait FlintSparkIndex {
* Build a data frame to represent the index data computation logic. Upper-level code decides how
* to use this, e.g. batch or streaming, full or incremental refresh.
*
* @param spark
* Spark session for implementation class to use as needed
* @param df
* data frame to append building logic
* data frame to append the building logic to. If None, the implementation class creates the
* source data frame on its own
* @return
* index building data frame
*/
def build(df: DataFrame): DataFrame
def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
Review comment (Collaborator):
In which case is df not empty? It seems skipping/covering/MV all pass an empty df?

Reply (Collaborator, author):
It's not empty for incremental refresh in foreach-batch style (incremental refresh on a skipping/covering index). In that case, df is the data frame for each micro-batch. See:
https://github.com/dai-chen/opensearch-spark/blob/add-mv-api-on-new-metadata/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSpark.scala#L153
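To make that contract concrete, a hedged sketch (hypothetical class, table, and column names) of how an implementation is expected to treat the optional parameter, mirroring the covering index change later in this diff:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

class ExampleIndex(tableName: String) {

  def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
    // Foreach-batch incremental refresh passes the micro-batch frame;
    // full refresh passes None, so read the source table directly.
    val source = df.getOrElse(spark.read.table(tableName))
    source.select("elb", "status_code") // hypothetical indexed columns
  }
}
```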

}

object FlintSparkIndex {

/**
* Interface indicating that a Flint index has a custom streaming refresh capability other than
* foreach-batch streaming.
*/
trait StreamingRefresh {

/**
* Build streaming refresh data frame.
*
* @param spark
* Spark session
* @return
* data frame represents streaming logic
*/
def buildStream(spark: SparkSession): DataFrame
}
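A hypothetical sketch of an index that opts into this capability (a real index would also extend FlintSparkIndex); the table and columns are assumptions:

```scala
import org.opensearch.flint.spark.FlintSparkIndex.StreamingRefresh

import org.apache.spark.sql.{DataFrame, SparkSession}

class ExampleStreamingIndex(tableName: String) extends StreamingRefresh {

  // Return a streaming DataFrame so refreshIndex can start it directly,
  // instead of going through the foreach-batch fallback.
  override def buildStream(spark: SparkSession): DataFrame =
    spark.readStream
      .table(tableName)
      .selectExpr("elb", "status_code") // hypothetical projection
}
```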

/**
* ID column name.
*/
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import scala.collection.JavaConverters.mapAsScalaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

/**
* Flint Spark index factory that encapsulates creation of concrete Flint index instances. This is
* for internal use rather than a user-facing API.
*/
object FlintSparkIndexFactory {

/**
* Creates Flint index from generic Flint metadata.
*
* @param metadata
* Flint metadata
* @return
* Flint index
*/
def create(metadata: FlintMetadata): FlintSparkIndex = {
val indexOptions = FlintSparkIndexOptions(
metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap)

// Convert the generic Map[String, AnyRef] in metadata to the specific data structures of each Flint index
metadata.kind match {
case SKIPPING_INDEX_TYPE =>
val strategies = metadata.indexedColumns.map { colInfo =>
val skippingKind = SkippingKind.withName(getString(colInfo, "kind"))
val columnName = getString(colInfo, "columnName")
val columnType = getString(colInfo, "columnType")

skippingKind match {
case PARTITION =>
PartitionSkippingStrategy(columnName = columnName, columnType = columnType)
case VALUE_SET =>
ValueSetSkippingStrategy(columnName = columnName, columnType = columnType)
case MIN_MAX =>
MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
case other =>
throw new IllegalStateException(s"Unknown skipping strategy: $other")
}
}
FlintSparkSkippingIndex(metadata.source, strategies, indexOptions)
case COVERING_INDEX_TYPE =>
FlintSparkCoveringIndex(
metadata.name,
metadata.source,
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
indexOptions)
case MV_INDEX_TYPE =>
FlintSparkMaterializedView(
metadata.name,
metadata.source,
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
indexOptions)
}
}

private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
}
}
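A small usage sketch mirroring FlintSpark.describeIndex above; the FlintClient instance and index name are assumed to come from the caller:

```scala
import org.opensearch.flint.core.FlintClient
import org.opensearch.flint.spark.{FlintSparkIndex, FlintSparkIndexFactory}

def loadIndex(flintClient: FlintClient, indexName: String): Option[FlintSparkIndex] =
  if (flintClient.exists(indexName)) {
    // The factory hides which concrete class (skipping/covering/MV) the metadata maps to.
    Some(FlintSparkIndexFactory.create(flintClient.getIndexMetadata(indexName)))
  } else {
    None
  }
```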
@@ -13,7 +13,7 @@ import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generat
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql._

/**
* Flint covering index in Spark.
@@ -54,9 +54,10 @@ case class FlintSparkCoveringIndex(
.build()
}

override def build(df: DataFrame): DataFrame = {
override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
val colNames = indexedColumns.keys.toSeq
df.select(colNames.head, colNames.tail: _*)
df.getOrElse(spark.read.table(tableName))
.select(colNames.head, colNames.tail: _*)
}
}
