Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exactly-once guarantee for covering index and MV incremental refresh #143

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ User can provide the following options in `WITH` clause of create statement:
+ `watermark_delay`: a string as time expression for how late data can come and still be processed, e.g. 1 minute, 10 seconds. This is required by incremental refresh on materialized view if it has aggregation in the query.
+ `output_mode`: a mode string that describes how data will be written to streaming sink. If unspecified, default append mode will be applied.
+ `index_settings`: a JSON string as index settings for OpenSearch index that will be created. Please follow the format in OpenSearch documentation. If unspecified, default OpenSearch index settings will be applied.
+ `id_expression`: an expression string that generates an ID column to avoid duplicate data when incremental refresh job restart. This is mandatory for covering index or materialized view without aggregation if auto refresh enabled and checkpoint location provided.
+ `extra_options`: a JSON string as extra options that can be passed to Spark streaming source and sink API directly. Use qualified source table name (because there could be multiple) and "sink", e.g. '{"sink": "{key: val}", "table1": {key: val}}'

Note that the index option name is case-sensitive. Here is an example:
Expand All @@ -246,6 +247,7 @@ WITH (
watermark_delay = '1 Second',
output_mode = 'complete',
index_settings = '{"number_of_shards": 2, "number_of_replicas": 3}',
id_expression = 'uuid()',
extra_options = '{"spark_catalog.default.alb_logs": {"maxFilesPerTrigger": "1"}}'
)
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

/**
* Flint index interface in Spark.
*/
trait FlintSparkIndex {
trait FlintSparkIndex extends Logging {

/**
* Index type
Expand Down Expand Up @@ -55,7 +58,7 @@ trait FlintSparkIndex {
def build(spark: SparkSession, df: Option[DataFrame]): DataFrame
}

object FlintSparkIndex {
object FlintSparkIndex extends Logging {

/**
* Interface indicates a Flint index has custom streaming refresh capability other than foreach
Expand All @@ -79,6 +82,39 @@ object FlintSparkIndex {
*/
val ID_COLUMN: String = "__id__"

/**
* Generate an ID column in the precedence below: 1) Use ID expression provided in the index
* option; 2) SHA-1 based on all columns if aggregated; 3) Throw exception if auto refresh and
* checkpoint location provided 4) Otherwise, no ID column generated.
*
* @param df
* data frame to generate ID column for
* @param options
* Flint index options
* @return
* optional ID column expression
*/
def generateIdColumn(df: DataFrame, options: FlintSparkIndexOptions): Option[Column] = {
def isAggregated: Boolean = {
df.queryExecution.logical.exists(_.isInstanceOf[Aggregate])
}

val idColumn =
if (options.idExpression().isDefined) {
Some(expr(options.idExpression().get))
} else if (isAggregated) {
Some(sha1(concat_ws("\0", df.columns.map(col): _*)))
} else if (options.autoRefresh() && options.checkpointLocation().isDefined) {
throw new IllegalStateException(
"ID expression is required to avoid duplicate data when index refresh job restart")
} else {
None
}

logInfo(s"Generated ID column based on expression $idColumn")
idColumn
}

/**
* Common prefix of Flint index name which is "flint_database_table_"
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package org.opensearch.flint.spark
import org.json4s.{Formats, NoTypeHints}
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, WATERMARK_DELAY}
import org.opensearch.flint.spark.FlintSparkIndexOptions.OptionName.{AUTO_REFRESH, CHECKPOINT_LOCATION, EXTRA_OPTIONS, ID_EXPRESSION, INDEX_SETTINGS, OptionName, OUTPUT_MODE, REFRESH_INTERVAL, WATERMARK_DELAY}
import org.opensearch.flint.spark.FlintSparkIndexOptions.validateOptionNames

/**
Expand Down Expand Up @@ -70,6 +70,14 @@ case class FlintSparkIndexOptions(options: Map[String, String]) {
*/
def indexSettings(): Option[String] = getOptionValue(INDEX_SETTINGS)

/**
* An expression that generates unique value as source data row ID.
*
* @return
* ID expression
*/
def idExpression(): Option[String] = getOptionValue(ID_EXPRESSION)

/**
* Extra streaming source options that can be simply passed to DataStreamReader or
* Relation.options
Expand Down Expand Up @@ -136,6 +144,7 @@ object FlintSparkIndexOptions {
val OUTPUT_MODE: OptionName.Value = Value("output_mode")
val INDEX_SETTINGS: OptionName.Value = Value("index_settings")
val EXTRA_OPTIONS: OptionName.Value = Value("extra_options")
val ID_EXPRESSION: OptionName.Value = Value("id_expression")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter

import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark._
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder}
import org.opensearch.flint.spark.FlintSparkIndex._
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}

import org.apache.spark.sql._
import org.apache.spark.sql.execution.SimpleMode

/**
* Flint covering index in Spark.
Expand Down Expand Up @@ -59,14 +60,26 @@ case class FlintSparkCoveringIndex(
}

override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
val colNames = indexedColumns.keys.toSeq
val job = df.getOrElse(spark.read.table(tableName))
var colNames = indexedColumns.keys.toSeq
var job = df.getOrElse(spark.read.table(tableName))

// Add ID column
val idColumn = generateIdColumn(job, options)
if (idColumn.isDefined) {
colNames = colNames :+ ID_COLUMN
job = job.withColumn(ID_COLUMN, idColumn.get)
}

// Add optional filtering condition
filterCondition
.map(job.where)
.getOrElse(job)
.select(colNames.head, colNames.tail: _*)
if (filterCondition.isDefined) {
job = job.where(filterCondition.get)
}

// Add indexed columns
job = job.select(colNames.head, colNames.tail: _*)

logInfo(s"Building covering index by " + job.queryExecution.explainString(SimpleMode))
job
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import scala.collection.convert.ImplicitConversions.`map AsScala`

import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex, FlintSparkIndexBuilder, FlintSparkIndexOptions}
import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder, StreamingRefresh}
import org.opensearch.flint.spark.FlintSparkIndex._
import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
import org.opensearch.flint.spark.function.TumbleFunction
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.{getFlintIndexName, MV_INDEX_TYPE}
Expand Down Expand Up @@ -68,7 +68,7 @@ case class FlintSparkMaterializedView(
override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
require(df.isEmpty, "materialized view doesn't support reading from other data frame")

spark.sql(query)
addIdColumn(spark.sql(query))
}

override def buildStream(spark: SparkSession): DataFrame = {
Expand All @@ -86,7 +86,18 @@ case class FlintSparkMaterializedView(
case relation: UnresolvedRelation if !relation.isStreaming =>
relation.copy(isStreaming = true, options = optionsWithExtra(spark, relation))
}
logicalPlanToDataFrame(spark, streamingPlan)
val streamDf = logicalPlanToDataFrame(spark, streamingPlan)

addIdColumn(streamDf)
}

private def addIdColumn(df: DataFrame): DataFrame = {
val idColumn = generateIdColumn(df, options)
if (idColumn.isDefined) {
df.withColumn(ID_COLUMN, idColumn.get)
} else {
df
}
}

private def watermark(timeCol: Attribute, child: LogicalPlan) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
WATERMARK_DELAY.toString shouldBe "watermark_delay"
OUTPUT_MODE.toString shouldBe "output_mode"
INDEX_SETTINGS.toString shouldBe "index_settings"
ID_EXPRESSION.toString shouldBe "id_expression"
EXTRA_OPTIONS.toString shouldBe "extra_options"
}

Expand All @@ -31,6 +32,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
"watermark_delay" -> "30 Seconds",
"output_mode" -> "complete",
"index_settings" -> """{"number_of_shards": 3}""",
"id_expression" -> """sha1(col("timestamp"))""",
"extra_options" ->
""" {
| "alb_logs": {
Expand All @@ -48,6 +50,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
options.watermarkDelay() shouldBe Some("30 Seconds")
options.outputMode() shouldBe Some("complete")
options.indexSettings() shouldBe Some("""{"number_of_shards": 3}""")
options.idExpression() shouldBe Some("""sha1(col("timestamp"))""")
options.extraSourceOptions("alb_logs") shouldBe Map("opt1" -> "val1")
options.extraSinkOptions() shouldBe Map("opt2" -> "val2", "opt3" -> "val3")
}
Expand Down Expand Up @@ -75,6 +78,7 @@ class FlintSparkIndexOptionsSuite extends FlintSuite with Matchers {
options.watermarkDelay() shouldBe empty
options.outputMode() shouldBe empty
options.indexSettings() shouldBe empty
options.idExpression() shouldBe empty
options.extraSourceOptions("alb_logs") shouldBe empty
options.extraSinkOptions() shouldBe empty
options.optionsWithDefault should contain("auto_refresh" -> "false")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@

package org.opensearch.flint.spark.covering

import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.FlintSparkIndexOptions
import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.flint.config.FlintSparkConf.OPTIMIZER_RULE_ENABLED
import org.apache.spark.sql.functions._

class FlintSparkCoveringIndexSuite extends FlintSuite {
class FlintSparkCoveringIndexSuite extends FlintSuite with Matchers {

/** Test table name */
val testTable = "spark_catalog.default.ci_test"

test("get covering index name") {
val index =
new FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string"))
FlintSparkCoveringIndex("ci", "spark_catalog.default.test", Map("name" -> "string"))
index.name() shouldBe "flint_spark_catalog_default_test_ci_index"
}

Expand All @@ -32,7 +40,120 @@ class FlintSparkCoveringIndexSuite extends FlintSuite {

test("should fail if no indexed column given") {
assertThrows[IllegalArgumentException] {
new FlintSparkCoveringIndex("ci", "default.test", Map.empty)
FlintSparkCoveringIndex("ci", "default.test", Map.empty)
}
}

test("build batch with ID expression given in index options") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
val index =
FlintSparkCoveringIndex(
"name_idx",
testTable,
Map("name" -> "string"),
options = FlintSparkIndexOptions(Map("id_expression" -> "timestamp")))

assertDataFrameEquals(
index.build(spark, None),
spark
.table(testTable)
.withColumn(ID_COLUMN, expr("timestamp"))
.select(col("name"), col(ID_COLUMN)))
}
}

test("build batch without ID column") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
val index = FlintSparkCoveringIndex("name_idx", testTable, Map("name" -> "string"))

assertDataFrameEquals(
index.build(spark, None),
spark
.table(testTable)
.select(col("name")))
}
}

test("build stream with ID expression given in index options") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
val index = FlintSparkCoveringIndex(
"name_idx",
testTable,
Map("name" -> "string"),
options =
FlintSparkIndexOptions(Map("auto_refresh" -> "true", "id_expression" -> "name")))

assertDataFrameEquals(
index.build(spark, Some(spark.table(testTable))),
spark
.table(testTable)
.withColumn(ID_COLUMN, col("name"))
.select("name", ID_COLUMN))
}
}

test("build stream without ID column if no checkpoint location") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
val index = FlintSparkCoveringIndex(
"name_idx",
testTable,
Map("name" -> "string"),
options = FlintSparkIndexOptions(Map("auto_refresh" -> "true")))

assertDataFrameEquals(
index.build(spark, Some(spark.table(testTable))),
spark
.table(testTable)
.select(col("name")))
}
}

test("build stream fail if checkpoint location provided but no ID expression") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (name STRING, age INTEGER) USING JSON")
val index = FlintSparkCoveringIndex(
"name_idx",
testTable,
Map("name" -> "string"),
options = FlintSparkIndexOptions(
Map("auto_refresh" -> "true", "checkpoint_location" -> "s3://test/")))

assertThrows[IllegalStateException] {
index.build(spark, Some(spark.table(testTable)))
}
}
}

test("build with filtering condition") {
withTable(testTable) {
sql(s"CREATE TABLE $testTable (timestamp TIMESTAMP, name STRING) USING JSON")
val index = FlintSparkCoveringIndex(
"name_idx",
testTable,
Map("name" -> "string"),
Some("name = 'test'"))

// Avoid optimizer rule to check Flint index exists
spark.conf.set(OPTIMIZER_RULE_ENABLED.key, "false")
try {
assertDataFrameEquals(
index.build(spark, None),
spark
.table(testTable)
.where("name = 'test'")
.select(col("name")))
} finally {
spark.conf.set(OPTIMIZER_RULE_ENABLED.key, "true")
}
}
}

/* Assert unresolved logical plan in DataFrame equals without semantic analysis */
private def assertDataFrameEquals(df1: DataFrame, df2: DataFrame): Unit = {
comparePlans(df1.queryExecution.logical, df2.queryExecution.logical, checkAnalysis = false)
}
}
Loading
Loading