Add covering index in Flint Spark API #22

Merged 12 commits on Sep 18, 2023
@@ -12,10 +12,11 @@ import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSpark._
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.{FlintSparkSkippingIndex, FlintSparkSkippingStrategy}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.COVERING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
@@ -25,12 +26,10 @@ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.SaveMode._
import org.apache.spark.sql.catalog.Column
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN}
import org.apache.spark.sql.streaming.OutputMode.Append
import org.apache.spark.sql.streaming.StreamingQuery

/**
* Flint Spark integration API entrypoint.
@@ -42,8 +41,7 @@ class FlintSpark(val spark: SparkSession) {
FlintSparkConf(
Map(
DOC_ID_COLUMN_NAME.optionKey -> ID_COLUMN,
IGNORE_DOC_ID_COLUMN.optionKey -> "true"
).asJava)
IGNORE_DOC_ID_COLUMN.optionKey -> "true").asJava)

/** Flint client for low-level index operation */
private val flintClient: FlintClient = FlintClientBuilder.build(flintSparkConf.flintOptions())
@@ -57,8 +55,18 @@ class FlintSpark(val spark: SparkSession) {
* @return
* index builder
*/
def skippingIndex(): IndexBuilder = {
new IndexBuilder(this)
def skippingIndex(): FlintSparkSkippingIndex.Builder = {
new FlintSparkSkippingIndex.Builder(this)
}
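
// A minimal usage sketch of the skipping index fluent API (not part of this diff),
// assuming FlintSparkSkippingIndex.Builder keeps the onTable/addPartitions/
// addValueSet/addMinMax methods of the removed IndexBuilder further down; the table
// and column names below are hypothetical:
val flint = new FlintSpark(spark) // spark: an existing SparkSession
flint
  .skippingIndex()
  .onTable("default.alb_logs")
  .addPartitions("year", "month")
  .addValueSet("elb_status_code")
  .addMinMax("request_processing_time")
  .create()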

/**
* Create index builder for creating covering index with fluent API.
*
* @return
* index builder
*/
def coveringIndex(): FlintSparkCoveringIndex.Builder = {
new FlintSparkCoveringIndex.Builder(this)
}
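
// A minimal usage sketch of the new covering index builder (not part of this diff).
// The fluent methods name()/onTable()/addIndexColumns() are assumptions about what
// FlintSparkCoveringIndex.Builder exposes, and all names below are hypothetical;
// flint is a FlintSpark instance as in the sketch above:
flint
  .coveringIndex()
  .name("elb_and_requestUri")
  .onTable("default.alb_logs")
  .addIndexColumns("elb", "requestUri")
  .create()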

/**
@@ -199,6 +207,7 @@ class FlintSpark(val spark: SparkSession) {
*/
private def deserialize(metadata: FlintMetadata): FlintSparkIndex = {
val meta = parse(metadata.getContent) \ "_meta"
val indexName = (meta \ "name").extract[String]
val tableName = (meta \ "source").extract[String]
val indexType = (meta \ "kind").extract[String]
val indexedColumns = (meta \ "indexedColumns").asInstanceOf[JArray]
@@ -222,6 +231,13 @@
}
}
new FlintSparkSkippingIndex(tableName, strategies)
case COVERING_INDEX_TYPE =>
new FlintSparkCoveringIndex(
indexName,
tableName,
indexedColumns.arr.map { obj =>
((obj \ "columnName").extract[String], (obj \ "columnType").extract[String])
}.toMap)
}
}
}
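
// A self-contained sketch (not part of this diff) of the _meta document shape that
// the new covering index branch above parses; the field values are hypothetical and
// "covering" is assumed to stand in for the value behind COVERING_INDEX_TYPE:
import org.json4s._
import org.json4s.native.JsonMethods.parse

implicit val formats: Formats = DefaultFormats

val metaJson =
  """{ "_meta": {
    |    "name": "elb_and_requestUri",
    |    "source": "default.alb_logs",
    |    "kind": "covering",
    |    "indexedColumns": [
    |      { "columnName": "elb", "columnType": "string" },
    |      { "columnName": "requestUri", "columnType": "string" }
    |    ]
    |} }""".stripMargin

val meta = parse(metaJson) \ "_meta"
val indexName = (meta \ "name").extract[String]   // "elb_and_requestUri"
val tableName = (meta \ "source").extract[String] // "default.alb_logs"
val indexedColumns = (meta \ "indexedColumns").asInstanceOf[JArray].arr
  .map(obj => ((obj \ "columnName").extract[String], (obj \ "columnType").extract[String]))
  .toMap // Map("elb" -> "string", "requestUri" -> "string")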
@@ -236,102 +252,4 @@ object FlintSpark {
type RefreshMode = Value
val FULL, INCREMENTAL = Value
}

/**
* Helper class for index class construct. For now only skipping index supported.
*/
class IndexBuilder(flint: FlintSpark) {
var tableName: String = ""
var indexedColumns: Seq[FlintSparkSkippingStrategy] = Seq()

lazy val allColumns: Map[String, Column] = {
flint.spark.catalog
.listColumns(tableName)
.collect()
.map(col => (col.name, col))
.toMap
}

/**
* Configure which source table the index is based on.
*
* @param tableName
* full table name
* @return
* index builder
*/
def onTable(tableName: String): IndexBuilder = {
this.tableName = tableName
this
}

/**
* Add partition skipping indexed columns.
*
* @param colNames
* indexed column names
* @return
* index builder
*/
def addPartitions(colNames: String*): IndexBuilder = {
require(tableName.nonEmpty, "table name cannot be empty")

colNames
.map(findColumn)
.map(col => PartitionSkippingStrategy(columnName = col.name, columnType = col.dataType))
.foreach(addIndexedColumn)
this
}

/**
* Add value set skipping indexed column.
*
* @param colName
* indexed column name
* @return
* index builder
*/
def addValueSet(colName: String): IndexBuilder = {
require(tableName.nonEmpty, "table name cannot be empty")

val col = findColumn(colName)
addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType))
this
}

/**
* Add min max skipping indexed column.
*
* @param colName
* indexed column name
* @return
* index builder
*/
def addMinMax(colName: String): IndexBuilder = {
val col = findColumn(colName)
indexedColumns =
indexedColumns :+ MinMaxSkippingStrategy(columnName = col.name, columnType = col.dataType)
this
}

/**
* Create index.
*/
def create(): Unit = {
flint.createIndex(new FlintSparkSkippingIndex(tableName, indexedColumns))
}

private def findColumn(colName: String): Column =
allColumns.getOrElse(
colName,
throw new IllegalArgumentException(s"Column $colName does not exist"))

private def addIndexedColumn(indexedCol: FlintSparkSkippingStrategy): Unit = {
require(
indexedColumns.forall(_.columnName != indexedCol.columnName),
s"${indexedCol.columnName} is already indexed")

indexedColumns = indexedColumns :+ indexedCol
}
}
}
@@ -49,4 +49,15 @@ object FlintSparkIndex {
* ID column name.
*/
val ID_COLUMN: String = "__id__"

/**
* Common prefix of Flint index name which is "flint_database_table_"
*
* @param fullTableName
* source full table name
* @return
* Flint index name
*/
def flintIndexNamePrefix(fullTableName: String): String =
s"flint_${fullTableName.replace(".", "_")}_"
}
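
// A quick check of the naming convention above (the table name is hypothetical):
// flintIndexNamePrefix replaces "." with "_" and prepends "flint_", giving a prefix
// that concrete index classes can extend with their own suffix:
val prefix = FlintSparkIndex.flintIndexNamePrefix("default.alb_logs")
assert(prefix == "flint_default_alb_logs_")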
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import org.apache.spark.sql.catalog.Column

/**
* Flint Spark index builder base class.
*
* @param flint
* Flint Spark API entrypoint
*/
abstract class FlintSparkIndexBuilder(flint: FlintSpark) {

/** Source table name */
protected var tableName: String = ""

/** All columns of the given source table */
lazy protected val allColumns: Map[String, Column] = {
require(tableName.nonEmpty, "Source table name is not provided")

flint.spark.catalog
.listColumns(tableName)
.collect()
.map(col => (col.name, col))
.toMap
}

/**
* Create Flint index.
*/
def create(): Unit = flint.createIndex(buildIndex())

/**
* Build method for concrete builder class to implement
*/
protected def buildIndex(): FlintSparkIndex

protected def findColumn(colName: String): Column =
allColumns.getOrElse(
colName,
throw new IllegalArgumentException(s"Column $colName does not exist"))
}
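
// A minimal sketch (not part of this diff) of how a concrete builder can extend
// FlintSparkIndexBuilder. The builder and its fluent methods are hypothetical; only
// the FlintSparkCoveringIndex constructor shape is taken from the diff above:
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex

class ExampleCoveringIndexBuilder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) {
  private var indexName: String = ""
  private var indexedColumns: Map[String, String] = Map.empty

  def name(name: String): ExampleCoveringIndexBuilder = {
    indexName = name
    this
  }

  def onTable(name: String): ExampleCoveringIndexBuilder = {
    tableName = name // inherited from FlintSparkIndexBuilder
    this
  }

  def addIndexColumns(colNames: String*): ExampleCoveringIndexBuilder = {
    colNames.foreach { colName =>
      val col = findColumn(colName) // resolves against allColumns, fails fast if missing
      indexedColumns += (col.name -> col.dataType)
    }
    this
  }

  override protected def buildIndex(): FlintSparkIndex =
    new FlintSparkCoveringIndex(indexName, tableName, indexedColumns)
}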