Add where clause support for covering index #85

Merged
@@ -31,6 +31,7 @@ createSkippingIndexStatement
: CREATE SKIPPING INDEX (IF NOT EXISTS)?
ON tableName
LEFT_PAREN indexColTypeList RIGHT_PAREN
whereClause?
(WITH LEFT_PAREN propertyList RIGHT_PAREN)?
;

@@ -58,6 +59,7 @@ createCoveringIndexStatement
: CREATE INDEX (IF NOT EXISTS)? indexName
ON tableName
LEFT_PAREN indexColumns=multipartIdentifierPropertyList RIGHT_PAREN
whereClause?
(WITH LEFT_PAREN propertyList RIGHT_PAREN)?
;

@@ -115,6 +117,14 @@ materializedViewQuery
: .+?
;

whereClause
: WHERE filterCondition
;

filterCondition
: .+?
;

indexColTypeList
: indexColType (COMMA indexColType)*
;
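
Taken together, the new whereClause and filterCondition rules make an optional WHERE filter part of both the skipping and covering index statements. A rough sketch of the covering index statement the grammar now parses (table and column names are illustrative, mirroring the test cases later in this diff):

sql("""
  | CREATE INDEX test ON alb_logs
  | (elb, client_ip)
  | WHERE status != 404
  | WITH (auto_refresh = true)
  |""".stripMargin)

The skipping index statement accepts the same clause at the grammar level, but the AST builder further down still rejects it as unsupported.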
1 change: 1 addition & 0 deletions flint-spark-integration/src/main/antlr4/SparkSqlBase.g4
@@ -174,6 +174,7 @@ SHOW: 'SHOW';
TRUE: 'TRUE';
VIEW: 'VIEW';
VIEWS: 'VIEWS';
WHERE: 'WHERE';
WITH: 'WITH';


@@ -65,6 +65,7 @@ object FlintSparkIndexFactory {
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
getOptString(metadata.properties, "filterCondition"),
indexOptions)
case MV_INDEX_TYPE =>
FlintSparkMaterializedView(
@@ -80,4 +81,13 @@
private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
}

private def getOptString(map: java.util.Map[String, AnyRef], key: String): Option[String] = {
val value = map.get(key)
if (value == null) {
None
} else {
Some(value.asInstanceOf[String])
}
}
}
@@ -29,6 +29,7 @@ case class FlintSparkCoveringIndex(
indexName: String,
tableName: String,
indexedColumns: Map[String, String],
filterCondition: Option[String] = None,
override val options: FlintSparkIndexOptions = empty)
extends FlintSparkIndex {

@@ -46,17 +47,25 @@
}
val schemaJson = generateSchemaJSON(indexedColumns)

- metadataBuilder(this)
+ val builder = metadataBuilder(this)
    .name(indexName)
    .source(tableName)
    .indexedColumns(indexColumnMaps)
    .schema(schemaJson)
-   .build()
+
+ // Add optional index properties
+ filterCondition.map(builder.addProperty("filterCondition", _))
+ builder.build()
}

override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = {
val colNames = indexedColumns.keys.toSeq
- df.getOrElse(spark.read.table(tableName))
+ val job = df.getOrElse(spark.read.table(tableName))
+
+ // Add optional filtering condition
+ filterCondition
+   .map(job.where)
+   .getOrElse(job)
.select(colNames.head, colNames.tail: _*)
}
}
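
As a concrete illustration (values borrowed from the integration test further down, so this is a sketch rather than part of the change): with indexed columns name and age and filterCondition = Some("age > 30"), the refresh job produced by build is roughly

spark.read.table("spark_catalog.default.ci_test")
  .where("age > 30")
  .select("name", "age")

When filterCondition is None, the where step is skipped and the job stays a plain projection of the indexed columns.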
@@ -95,6 +104,7 @@ object FlintSparkCoveringIndex {
class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) {
private var indexName: String = ""
private var indexedColumns: Map[String, String] = Map()
private var filterCondition: Option[String] = None

/**
* Set covering index name.
@@ -137,7 +147,25 @@
this
}

/**
* Add filtering condition.
*
* @param condition
* filter condition
* @return
* index builder
*/
def filterBy(condition: String): Builder = {
filterCondition = Some(condition)
this
}

override protected def buildIndex(): FlintSparkIndex =
- new FlintSparkCoveringIndex(indexName, tableName, indexedColumns, indexOptions)
+ new FlintSparkCoveringIndex(
+   indexName,
+   tableName,
+   indexedColumns,
+   filterCondition,
+   indexOptions)
}
}
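
A usage sketch of the extended builder DSL (the coveringIndex() entry point is assumed here by analogy with the skippingIndex() and materializedView() builders visible elsewhere in this diff; the index and table names are illustrative):

flint
  .coveringIndex()
  .name("ci_test_index")
  .onTable("spark_catalog.default.ci_test")
  .addIndexColumns("name", "age")
  .filterBy("age > 30")
  .create()

Omitting filterBy leaves filterCondition as None, which preserves the previous unfiltered behavior.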
@@ -5,13 +5,15 @@

package org.opensearch.flint.spark.sql

import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode}
import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.sql.covering.FlintSparkCoveringIndexAstBuilder
import org.opensearch.flint.spark.sql.mv.FlintSparkMaterializedViewAstBuilder
import org.opensearch.flint.spark.sql.skipping.FlintSparkSkippingIndexAstBuilder

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.flint.qualifyTableName

/**
@@ -49,4 +51,20 @@ object FlintSparkSqlAstBuilder {
def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = {
qualifyTableName(flint.spark, tableNameCtx.getText)
}

/**
* Get original SQL text from the origin.
*
* @param ctx
* rule context to get SQL text associated with
* @return
* SQL text
*/
def getSqlText(ctx: ParserRuleContext): String = {
// Origin must be preserved at the beginning of parsing
val sqlText = CurrentOrigin.get.sqlText.get
val startIndex = ctx.getStart.getStartIndex
val stopIndex = ctx.getStop.getStopIndex
sqlText.substring(startIndex, stopIndex + 1)
}
}
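
getSqlText only works if the full statement text was stashed in the parse origin before the visitor runs, as the inline comment notes. A minimal sketch of that precondition, assuming the parser wraps tree visitation with CurrentOrigin.withOrigin and that the Spark version in use carries an optional sqlText field on Origin (which the CurrentOrigin.get.sqlText call above implies):

import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}

// Hypothetical call site in the parser: preserve the original SQL text
// so that CurrentOrigin.get.sqlText is defined inside getSqlText.
def visitWithOrigin[T](sqlText: String)(visit: => T): T =
  CurrentOrigin.withOrigin(Origin(sqlText = Some(sqlText)))(visit)

The ANTLR start and stop token indexes are offsets into that same string, which is why the substring reproduces the clause exactly as the user wrote it.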
@@ -10,7 +10,7 @@ import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
- import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.getFullTableName
+ import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._

import org.apache.spark.sql.Row
@@ -40,6 +40,10 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
indexBuilder.addIndexColumns(colName)
}

if (ctx.whereClause() != null) {
indexBuilder.filterBy(getSqlText(ctx.whereClause().filterCondition()))
}

val ignoreIfExists = ctx.EXISTS() != null
val indexOptions = visitPropertyList(ctx.propertyList())
indexBuilder
@@ -10,13 +10,12 @@ import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex}
import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
- import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.getFullTableName
+ import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.Command
- import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.types.StringType

/**
@@ -29,7 +28,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
ctx: CreateMaterializedViewStatementContext): Command = {
FlintSparkSqlCommand() { flint =>
val mvName = getFullTableName(flint, ctx.mvName)
- val query = getMvQuery(ctx.query)
+ val query = getSqlText(ctx.query)

val mvBuilder = flint
.materializedView()
@@ -103,14 +102,6 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
}
}

- private def getMvQuery(ctx: MaterializedViewQueryContext): String = {
-   // Assume origin must be preserved at the beginning of parsing
-   val sqlText = CurrentOrigin.get.sqlText.get
-   val startIndex = ctx.getStart.getStartIndex
-   val stopIndex = ctx.getStop.getStopIndex
-   sqlText.substring(startIndex, stopIndex + 1)
- }

private def getFlintIndexName(flint: FlintSpark, mvNameCtx: RuleNode): String = {
val fullMvName = getFullTableName(flint, mvNameCtx)
FlintSparkMaterializedView.getFlintIndexName(fullMvName)
@@ -12,7 +12,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder}
- import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.getFullTableName
+ import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._

import org.apache.spark.sql.Row
@@ -29,6 +29,12 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A
override def visitCreateSkippingIndexStatement(
ctx: CreateSkippingIndexStatementContext): Command =
FlintSparkSqlCommand() { flint =>
// TODO: support filtering condition
if (ctx.whereClause() != null) {
throw new UnsupportedOperationException(
s"Filtering condition is not supported: ${getSqlText(ctx.whereClause())}")
}

// Create skipping index
val indexBuilder = flint
.skippingIndex()
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.sql

import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite

class FlintSparkSqlParserSuite extends FlintSuite with Matchers {

test("create skipping index with filtering condition") {
the[UnsupportedOperationException] thrownBy {
sql("""
| CREATE SKIPPING INDEX ON alb_logs
| (client_ip VALUE_SET)
| WHERE status != 200
| WITH (auto_refresh = true)
|""".stripMargin)
} should have message "Filtering condition is not supported: WHERE status != 200"
}

ignore("create covering index with filtering condition") {
the[UnsupportedOperationException] thrownBy {
sql("""
| CREATE INDEX test ON alb_logs
| (elb, client_ip)
| WHERE status != 404
| WITH (auto_refresh = true)
|""".stripMargin)
} should have message "Filtering condition is not supported: WHERE status != 404"
}
}
@@ -40,6 +40,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite {
.name(testIndex)
.onTable(testTable)
.addIndexColumns("name", "age")
.filterBy("age > 30")
.create()

val index = flint.describeIndex(testFlintIndex)
@@ -60,7 +61,9 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite {
| }],
| "source": "spark_catalog.default.ci_test",
| "options": { "auto_refresh": "false" },
| "properties": {}
| "properties": {
| "filterCondition": "age > 30"
| }
| },
| "properties": {
| "name": {
@@ -59,6 +59,22 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite {
indexData.count() shouldBe 2
}

test("create covering index with filtering condition") {
sql(s"""
| CREATE INDEX $testIndex ON $testTable
| (name, age)
| WHERE address = 'Portland'
| WITH (auto_refresh = true)
|""".stripMargin)

// Wait for streaming job complete current micro batch
val job = spark.streams.active.find(_.name == testFlintIndex)
awaitStreamingComplete(job.get.id.toString)

val indexData = flint.queryIndex(testFlintIndex)
indexData.count() shouldBe 1
}

test("create covering index with streaming job options") {
withTempDir { checkpointDir =>
sql(s"""