diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexUtils.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexUtils.scala
new file mode 100644
index 000000000..6b19aa092
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, Or}
+import org.apache.spark.sql.functions.expr
+
+/**
+ * Flint Spark index utility methods.
+ */
+object FlintSparkIndexUtils {
+
+  /**
+   * Is the given Spark predicate string a conjunction?
+   *
+   * @param condition
+   *   predicate condition string
+   * @return
+   *   true if the condition contains no disjunction (OR), otherwise false
+   */
+  def isConjunction(condition: String): Boolean = {
+    isConjunction(expr(condition).expr)
+  }
+
+  /**
+   * Is the given Spark predicate a conjunction?
+   *
+   * @param condition
+   *   predicate condition
+   * @return
+   *   true if the condition contains no disjunction (OR), otherwise false
+   */
+  def isConjunction(condition: Expression): Boolean = {
+    condition.collectFirst { case Or(_, _) =>
+      true
+    }.isEmpty
+  }
+}
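Reviewer note (illustrative, not part of the patch): the helper treats a
predicate as a conjunction only when no Or node appears anywhere in the parsed
expression tree, so a disjunction nested under an outer AND is rejected as
well. A minimal sketch of the expected behavior, assuming functions.expr can
fall back to a standalone SQL parser when no SparkSession is active:

    import org.opensearch.flint.spark.FlintSparkIndexUtils.isConjunction

    isConjunction("a = 1")                        // true: single predicate, no OR
    isConjunction("a = 1 AND b = 2 AND c > 0")    // true: AND-only expression tree
    isConjunction("a = 1 OR b = 2")               // false: top-level OR
    isConjunction("a = 1 AND (b = 2 OR c = 3)")   // false: OR nested under AND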
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
index 91272309f..845af17f2 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndex.scala
@@ -11,6 +11,7 @@ import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark._
 import org.opensearch.flint.spark.FlintSparkIndex.{flintIndexNamePrefix, generateSchemaJSON, metadataBuilder}
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
+import org.opensearch.flint.spark.FlintSparkIndexUtils.isConjunction
 import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex.{getFlintIndexName, COVERING_INDEX_TYPE}
 
 import org.apache.spark.sql._
@@ -34,6 +35,7 @@ case class FlintSparkCoveringIndex(
     extends FlintSparkIndex {
 
   require(indexedColumns.nonEmpty, "indexed columns must not be empty")
+  require(filterCondition.forall(isConjunction), "filtering condition must be a conjunction")
 
   override val kind: String = COVERING_INDEX_TYPE
 
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
index 11f8ad304..b29307ed7 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
@@ -6,10 +6,11 @@
 package org.opensearch.flint.spark.skipping
 
 import org.opensearch.flint.spark.FlintSpark
+import org.opensearch.flint.spark.FlintSparkIndexUtils.isConjunction
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}
 
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or, Predicate}
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
@@ -33,7 +34,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
             _,
             Some(table),
             false))
-        if hasNoDisjunction(condition) && !location.isInstanceOf[FlintSparkSkippingFileIndex] =>
+        if isConjunction(condition) && !location.isInstanceOf[FlintSparkSkippingFileIndex] =>
       val index = flint.describeIndex(getIndexName(table))
       if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
         val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
@@ -67,12 +68,6 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
     getSkippingIndexName(qualifiedTableName)
   }
 
-  private def hasNoDisjunction(condition: Expression): Boolean = {
-    condition.collectFirst { case Or(_, _) =>
-      true
-    }.isEmpty
-  }
-
   private def rewriteToIndexFilter(
       index: FlintSparkSkippingIndex,
       condition: Expression): Option[Expression] = {
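Reviewer note (illustrative, not part of the patch): the conjunction-only
restriction is what keeps file skipping sound. For a = 1 OR b = 2, a file with
no a = 1 rows may still contain b = 2 rows, so skipping it based on an index
over a alone would drop matching data; AND-connected predicates, by contrast,
can each be checked against the index independently. A hedged sketch of that
decomposition (Spark ships a similar helper in its PredicateHelper trait):

    import org.apache.spark.sql.catalyst.expressions.{And, Expression}

    // Split an AND-only condition into its individual predicates; each one
    // can then be rewritten against the skipping index on its own.
    def splitConjunctivePredicates(condition: Expression): Seq[Expression] =
      condition match {
        case And(left, right) =>
          splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right)
        case other => Seq(other)
      }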
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index 3a555db01..0956f9273 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -11,6 +11,7 @@ import org.opensearch.flint.core.metadata.FlintMetadata
 import org.opensearch.flint.spark._
 import org.opensearch.flint.spark.FlintSparkIndex._
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
+import org.opensearch.flint.spark.FlintSparkIndexUtils.isConjunction
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
 import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
 import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
@@ -18,7 +19,7 @@ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression
-import org.apache.spark.sql.functions.{col, input_file_name, sha1}
+import org.apache.spark.sql.functions.{col, expr, input_file_name, sha1}
 
 /**
  * Flint skipping index in Spark.
@@ -36,6 +37,7 @@ case class FlintSparkSkippingIndex(
     extends FlintSparkIndex {
 
   require(indexedColumns.nonEmpty, "indexed columns must not be empty")
+  require(filterCondition.forall(isConjunction), "filtering condition must be a conjunction")
 
   /** Skipping index type */
   override val kind: String = SKIPPING_INDEX_TYPE
@@ -85,7 +87,11 @@ case class FlintSparkSkippingIndex(
 
     // Add optional filtering condition
     if (filterCondition.isDefined) {
-      job = job.where(filterCondition.get)
+      if (isConjunction(expr(filterCondition.get).expr)) { // TODO: do the same for covering and add UT/IT
+        job = job.where(filterCondition.get)
+      } else {
+        throw new IllegalStateException("Filtering condition is not a conjunction")
+      }
     }
 
     job
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
index 8c144b46b..819052b0c 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/covering/FlintSparkCoveringIndexSuite.scala
@@ -24,6 +24,24 @@ class FlintSparkCoveringIndexSuite extends FlintSuite {
     }
   }
 
+  test("should succeed if filtering condition is a conjunction") {
+    new FlintSparkCoveringIndex(
+      "ci",
+      "test",
+      Map("name" -> "string"),
+      Some("test_field1 = 1 AND test_field2 = 2"))
+  }
+
+  test("should fail if filtering condition is not a conjunction") {
+    assertThrows[IllegalArgumentException] {
+      new FlintSparkCoveringIndex(
+        "ci",
+        "test",
+        Map("name" -> "string"),
+        Some("test_field1 = 1 OR test_field2 = 2"))
+    }
+  }
+
   test("should fail if no indexed column given") {
     assertThrows[IllegalArgumentException] {
       new FlintSparkCoveringIndex("ci", "default.test", Map.empty)
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
index 37e9e4395..c87d91999 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndexSuite.scala
@@ -49,6 +49,22 @@ class FlintSparkSkippingIndexSuite extends FlintSuite {
         "columnType" -> "integer").asJava)
   }
 
+  test("should succeed if filtering condition is a conjunction") {
+    new FlintSparkSkippingIndex(
+      testTable,
+      Seq(mock[FlintSparkSkippingStrategy]),
+      Some("test_field1 = 1 AND test_field2 = 2"))
+  }
+
+  test("should fail if filtering condition is not a conjunction") {
+    assertThrows[IllegalArgumentException] {
+      new FlintSparkSkippingIndex(
+        testTable,
+        Seq(mock[FlintSparkSkippingStrategy]),
+        Some("test_field1 = 1 OR test_field2 = 2"))
+    }
+  }
+
   test("can build index building job with unique ID column") {
     val indexCol = mock[FlintSparkSkippingStrategy]
     when(indexCol.outputSchema()).thenReturn(Map("name" -> "string"))
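Reviewer note (illustrative, not part of the patch): with the constructor
require in place, an invalid filter fails fast with IllegalArgumentException,
while the IllegalStateException branch in the index build job acts as a
defensive re-check. A minimal sketch mirroring the new unit tests, assuming an
active SparkSession for predicate parsing:

    import scala.util.Try
    import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex

    // Conjunctive filter: constructed successfully.
    val accepted = Try(
      FlintSparkCoveringIndex(
        "ci", "test", Map("name" -> "string"),
        Some("test_field1 = 1 AND test_field2 = 2")))  // Success(...)

    // Disjunctive filter: rejected at construction time.
    val rejected = Try(
      FlintSparkCoveringIndex(
        "ci", "test", Map("name" -> "string"),
        Some("test_field1 = 1 OR test_field2 = 2")))   // Failure(IllegalArgumentException)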