From 1f8dda9b11adb898621ead6e8f403a261bb4a6fb Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 25 Oct 2023 10:56:09 -0700 Subject: [PATCH] Add where clause support for covering index (#85) * Add where clause for create skipping and covering index Signed-off-by: Chen Dai * Fix optional where clause Signed-off-by: Chen Dai * Add filtering condition support for covering index Signed-off-by: Chen Dai --------- Signed-off-by: Chen Dai --- .../main/antlr4/FlintSparkSqlExtensions.g4 | 10 ++++++ .../src/main/antlr4/SparkSqlBase.g4 | 1 + .../flint/spark/FlintSparkIndexFactory.scala | 10 ++++++ .../covering/FlintSparkCoveringIndex.scala | 36 ++++++++++++++++--- .../spark/sql/FlintSparkSqlAstBuilder.scala | 18 ++++++++++ .../FlintSparkCoveringIndexAstBuilder.scala | 6 +++- ...FlintSparkMaterializedViewAstBuilder.scala | 13 ++----- .../FlintSparkSkippingIndexAstBuilder.scala | 8 ++++- .../spark/sql/FlintSparkSqlParserSuite.scala | 35 ++++++++++++++++++ .../FlintSparkCoveringIndexITSuite.scala | 5 ++- .../FlintSparkCoveringIndexSqlITSuite.scala | 16 +++++++++ 11 files changed, 140 insertions(+), 18 deletions(-) create mode 100644 flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 index f48c276e4..e44944fcf 100644 --- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 +++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 @@ -31,6 +31,7 @@ createSkippingIndexStatement : CREATE SKIPPING INDEX (IF NOT EXISTS)? ON tableName LEFT_PAREN indexColTypeList RIGHT_PAREN + whereClause? (WITH LEFT_PAREN propertyList RIGHT_PAREN)? ; @@ -58,6 +59,7 @@ createCoveringIndexStatement : CREATE INDEX (IF NOT EXISTS)? indexName ON tableName LEFT_PAREN indexColumns=multipartIdentifierPropertyList RIGHT_PAREN + whereClause? (WITH LEFT_PAREN propertyList RIGHT_PAREN)? ; @@ -115,6 +117,14 @@ materializedViewQuery : .+? ; +whereClause + : WHERE filterCondition + ; + +filterCondition + : .+? + ; + indexColTypeList : indexColType (COMMA indexColType)* ; diff --git a/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 b/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 index 533d851ba..597a1e585 100644 --- a/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 +++ b/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4 @@ -174,6 +174,7 @@ SHOW: 'SHOW'; TRUE: 'TRUE'; VIEW: 'VIEW'; VIEWS: 'VIEWS'; +WHERE: 'WHERE'; WITH: 'WITH'; diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala index cda11405c..6d680ae39 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala @@ -65,6 +65,7 @@ object FlintSparkIndexFactory { metadata.indexedColumns.map { colInfo => getString(colInfo, "columnName") -> getString(colInfo, "columnType") }.toMap, + getOptString(metadata.properties, "filterCondition"), indexOptions) case MV_INDEX_TYPE => FlintSparkMaterializedView( @@ -80,4 +81,13 @@ object FlintSparkIndexFactory { private def getString(map: java.util.Map[String, AnyRef], key: String): String = { map.get(key).asInstanceOf[String] } + + private def getOptString(map: java.util.Map[String, AnyRef], key: String): Option[String] = { + val value = map.get(key) + if (value == null) { + None + } else { + Some(value.asInstanceOf[String]) + } + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala index e9c2b5be5..91272309f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala @@ -29,6 +29,7 @@ case class FlintSparkCoveringIndex( indexName: String, tableName: String, indexedColumns: Map[String, String], + filterCondition: Option[String] = None, override val options: FlintSparkIndexOptions = empty) extends FlintSparkIndex { @@ -46,17 +47,25 @@ case class FlintSparkCoveringIndex( } val schemaJson = generateSchemaJSON(indexedColumns) - metadataBuilder(this) + val builder = metadataBuilder(this) .name(indexName) .source(tableName) .indexedColumns(indexColumnMaps) .schema(schemaJson) - .build() + + // Add optional index properties + filterCondition.map(builder.addProperty("filterCondition", _)) + builder.build() } override def build(spark: SparkSession, df: Option[DataFrame]): DataFrame = { val colNames = indexedColumns.keys.toSeq - df.getOrElse(spark.read.table(tableName)) + val job = df.getOrElse(spark.read.table(tableName)) + + // Add optional filtering condition + filterCondition + .map(job.where) + .getOrElse(job) .select(colNames.head, colNames.tail: _*) } } @@ -95,6 +104,7 @@ object FlintSparkCoveringIndex { class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) { private var indexName: String = "" private var indexedColumns: Map[String, String] = Map() + private var filterCondition: Option[String] = None /** * Set covering index name. @@ -137,7 +147,25 @@ object FlintSparkCoveringIndex { this } + /** + * Add filtering condition. + * + * @param condition + * filter condition + * @return + * index builder + */ + def filterBy(condition: String): Builder = { + filterCondition = Some(condition) + this + } + override protected def buildIndex(): FlintSparkIndex = - new FlintSparkCoveringIndex(indexName, tableName, indexedColumns, indexOptions) + new FlintSparkCoveringIndex( + indexName, + tableName, + indexedColumns, + filterCondition, + indexOptions) } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala index a56d99f14..606cb88eb 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala @@ -5,6 +5,7 @@ package org.opensearch.flint.spark.sql +import org.antlr.v4.runtime.ParserRuleContext import org.antlr.v4.runtime.tree.{ParseTree, RuleNode} import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.sql.covering.FlintSparkCoveringIndexAstBuilder @@ -12,6 +13,7 @@ import org.opensearch.flint.spark.sql.mv.FlintSparkMaterializedViewAstBuilder import org.opensearch.flint.spark.sql.skipping.FlintSparkSkippingIndexAstBuilder import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.flint.qualifyTableName /** @@ -49,4 +51,20 @@ object FlintSparkSqlAstBuilder { def getFullTableName(flint: FlintSpark, tableNameCtx: RuleNode): String = { qualifyTableName(flint.spark, tableNameCtx.getText) } + + /** + * Get original SQL text from the origin. + * + * @param ctx + * rule context to get SQL text associated with + * @return + * SQL text + */ + def getSqlText(ctx: ParserRuleContext): String = { + // Origin must be preserved at the beginning of parsing + val sqlText = CurrentOrigin.get.sqlText.get + val startIndex = ctx.getStart.getStartIndex + val stopIndex = ctx.getStop.getStopIndex + sqlText.substring(startIndex, stopIndex + 1) + } } diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala index c0bb47830..83a816a58 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/covering/FlintSparkCoveringIndexAstBuilder.scala @@ -10,7 +10,7 @@ import org.opensearch.flint.spark.FlintSpark import org.opensearch.flint.spark.FlintSpark.RefreshMode import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} -import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.getFullTableName +import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText} import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ import org.apache.spark.sql.Row @@ -40,6 +40,10 @@ trait FlintSparkCoveringIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A indexBuilder.addIndexColumns(colName) } + if (ctx.whereClause() != null) { + indexBuilder.filterBy(getSqlText(ctx.whereClause().filterCondition())) + } + val ignoreIfExists = ctx.EXISTS() != null val indexOptions = visitPropertyList(ctx.propertyList()) indexBuilder diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala index 16af7984c..266a10c9f 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala @@ -10,13 +10,12 @@ import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndex} import org.opensearch.flint.spark.FlintSpark.RefreshMode import org.opensearch.flint.spark.mv.FlintSparkMaterializedView import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} -import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.getFullTableName +import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText} import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.Command -import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.types.StringType /** @@ -29,7 +28,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito ctx: CreateMaterializedViewStatementContext): Command = { FlintSparkSqlCommand() { flint => val mvName = getFullTableName(flint, ctx.mvName) - val query = getMvQuery(ctx.query) + val query = getSqlText(ctx.query) val mvBuilder = flint .materializedView() @@ -103,14 +102,6 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito } } - private def getMvQuery(ctx: MaterializedViewQueryContext): String = { - // Assume origin must be preserved at the beginning of parsing - val sqlText = CurrentOrigin.get.sqlText.get - val startIndex = ctx.getStart.getStartIndex - val stopIndex = ctx.getStop.getStopIndex - sqlText.substring(startIndex, stopIndex + 1) - } - private def getFlintIndexName(flint: FlintSpark, mvNameCtx: RuleNode): String = { val fullMvName = getFullTableName(flint, mvNameCtx) FlintSparkMaterializedView.getFlintIndexName(fullMvName) diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala index dc8132a25..2b0bb6c48 100644 --- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala +++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/skipping/FlintSparkSkippingIndexAstBuilder.scala @@ -12,7 +12,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET} import org.opensearch.flint.spark.sql.{FlintSparkSqlCommand, FlintSparkSqlExtensionsVisitor, SparkSqlAstBuilder} -import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.getFullTableName +import org.opensearch.flint.spark.sql.FlintSparkSqlAstBuilder.{getFullTableName, getSqlText} import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._ import org.apache.spark.sql.Row @@ -29,6 +29,12 @@ trait FlintSparkSkippingIndexAstBuilder extends FlintSparkSqlExtensionsVisitor[A override def visitCreateSkippingIndexStatement( ctx: CreateSkippingIndexStatementContext): Command = FlintSparkSqlCommand() { flint => + // TODO: support filtering condition + if (ctx.whereClause() != null) { + throw new UnsupportedOperationException( + s"Filtering condition is not supported: ${getSqlText(ctx.whereClause())}") + } + // Create skipping index val indexBuilder = flint .skippingIndex() diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala new file mode 100644 index 000000000..87ea34582 --- /dev/null +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParserSuite.scala @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.sql + +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.FlintSuite + +class FlintSparkSqlParserSuite extends FlintSuite with Matchers { + + test("create skipping index with filtering condition") { + the[UnsupportedOperationException] thrownBy { + sql(""" + | CREATE SKIPPING INDEX ON alb_logs + | (client_ip VALUE_SET) + | WHERE status != 200 + | WITH (auto_refresh = true) + |""".stripMargin) + } should have message "Filtering condition is not supported: WHERE status != 200" + } + + ignore("create covering index with filtering condition") { + the[UnsupportedOperationException] thrownBy { + sql(""" + | CREATE INDEX test ON alb_logs + | (elb, client_ip) + | WHERE status != 404 + | WITH (auto_refresh = true) + |""".stripMargin) + } should have message "Filtering condition is not supported: WHERE status != 404" + } +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala index a4b0069dd..c79069b9b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexITSuite.scala @@ -40,6 +40,7 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { .name(testIndex) .onTable(testTable) .addIndexColumns("name", "age") + .filterBy("age > 30") .create() val index = flint.describeIndex(testFlintIndex) @@ -60,7 +61,9 @@ class FlintSparkCoveringIndexITSuite extends FlintSparkSuite { | }], | "source": "spark_catalog.default.ci_test", | "options": { "auto_refresh": "false" }, - | "properties": {} + | "properties": { + | "filterCondition": "age > 30" + | } | }, | "properties": { | "name": { diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala index 71b768239..b3e2ef063 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkCoveringIndexSqlITSuite.scala @@ -59,6 +59,22 @@ class FlintSparkCoveringIndexSqlITSuite extends FlintSparkSuite { indexData.count() shouldBe 2 } + test("create covering index with filtering condition") { + sql(s""" + | CREATE INDEX $testIndex ON $testTable + | (name, age) + | WHERE address = 'Portland' + | WITH (auto_refresh = true) + |""".stripMargin) + + // Wait for streaming job complete current micro batch + val job = spark.streams.active.find(_.name == testFlintIndex) + awaitStreamingComplete(job.get.id.toString) + + val indexData = flint.queryIndex(testFlintIndex) + indexData.count() shouldBe 1 + } + test("create covering index with streaming job options") { withTempDir { checkpointDir => sql(s"""