diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index 488d7bc08..2e947b9df 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -142,11 +142,13 @@ object FlintSparkSkippingStrategy {
    * @param indexExpr
    *   index expression in a skipping indexed column
    */
-  case class IndexColumnExtractor(indexExpr: Expression) {
+  case class IndexColumnExtractor(indexExprStr: String, indexExpr: Expression) {
 
     def unapply(expr: Expression): Option[Column] = {
       if (expr.semanticEquals(indexExpr)) {
-        Some(new Column(expr.canonicalized))
+        val sessionState = SparkSession.active.sessionState
+        val unresolvedExpr = sessionState.sqlParser.parseExpression(indexExprStr)
+        Some(new Column(unresolvedExpr))
       } else {
         None
       }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
index a8b899f9c..740312d59 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategy.scala
@@ -48,7 +48,7 @@ case class ValueSetSkippingStrategy(
   override def doRewritePredicate(
       predicate: Expression,
       indexExpr: Expression): Option[Expression] = {
-    val extractor = IndexColumnExtractor(indexExpr)
+    val extractor = IndexColumnExtractor(columnName, indexExpr)
     /*
      * This is supposed to be rewritten to ARRAY_CONTAINS(columName, value).
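Note: the reworked extractor re-parses the original index expression text instead of reusing the canonicalized query expression, so the returned Column carries an unresolved, user-facing expression (for example the original column name). A minimal stand-alone sketch of that behavior, assuming only an active SparkSession and the Catalyst API already imported by this file; `IndexColumnExtractorSketch` is an illustrative name, not the Flint class itself:

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression

// Illustrative sketch only; mirrors the change to IndexColumnExtractor above.
case class IndexColumnExtractorSketch(indexExprStr: String, indexExpr: Expression) {

  def unapply(expr: Expression): Option[Column] = {
    if (expr.semanticEquals(indexExpr)) {
      // Parse the original text (e.g. the indexed column name) back into an
      // unresolved expression, rather than wrapping expr.canonicalized, whose
      // attributes have been normalized for semantic comparison.
      val parser = SparkSession.active.sessionState.sqlParser
      Some(new Column(parser.parseExpression(indexExprStr)))
    } else {
      None
    }
  }
}
```

Because the extractor now calls SparkSession.active at predicate-rewrite time, any unit test exercising the rewrite needs a running session, which is presumably why the suites below switch their base from SparkFunSuite to FlintSuite.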
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala
index b67a2b707..4eb173405 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndexSuite.scala
@@ -15,7 +15,7 @@ import org.scalatest.matchers.{Matcher, MatchResult}
 import org.scalatest.matchers.should.Matchers
 import org.scalatestplus.mockito.MockitoSugar.mock
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.FlintSuite
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -27,7 +27,7 @@ import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
-class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
+class ApplyFlintSparkSkippingIndexSuite extends FlintSuite with Matchers {
 
   /** Test table and index */
   private val testTable = "spark_catalog.default.apply_skipping_index_test"
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategySuite.scala
index 08a4c55d0..a1becbdae 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategySuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategySuite.scala
@@ -5,14 +5,18 @@
 
 package org.opensearch.flint.spark.skipping
 
+import org.apache.spark.FlintSuite
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.Expression
 
-trait FlintSparkSkippingStrategySuite {
+trait FlintSparkSkippingStrategySuite extends FlintSuite {
 
   /** Subclass initializes strategy class to test */
   val strategy: FlintSparkSkippingStrategy
 
+  /** Resolved index expression */
+  val indexExpr: Expression
+
   /*
    * Add a assertion helpful that provides more readable assertion by
    * infix function: expr shouldRewriteTo col, expr shouldNotRewrite ()
@@ -21,7 +25,6 @@ trait FlintSparkSkippingStrategySuite {
 
     def shouldRewriteTo(right: Column): Unit = {
       val queryExpr = left
-      val indexExpr = left.children.head // Ensure left side matches
       val actual = strategy.doRewritePredicate(queryExpr, indexExpr)
       assert(actual.isDefined, s"Expected: ${right.expr}. Actual is None")
       assert(actual.get == right.expr, s"Expected: ${right.expr}. Actual: ${actual.get}")
@@ -29,9 +32,8 @@ trait FlintSparkSkippingStrategySuite {
 
     def shouldNotRewrite(): Unit = {
       val queryExpr = left
-      val indexExpr = left.children.head
       val actual = strategy.doRewritePredicate(queryExpr, indexExpr)
-      assert(actual.isEmpty, s"Expected is None. Actual is ${actual.get}")
+      assert(actual.isEmpty)
     }
   }
 }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategySuite.scala
index 7d2b3a92a..4ade56dd7 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategySuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/minmax/MinMaxSkippingStrategySuite.scala
@@ -8,50 +8,46 @@ package org.opensearch.flint.spark.skipping.minmax
 import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite}
 import org.scalatest.matchers.should.Matchers
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.IntegerType
 
-class MinMaxSkippingStrategySuite
-    extends SparkFunSuite
-    with FlintSparkSkippingStrategySuite
-    with Matchers {
+class MinMaxSkippingStrategySuite extends FlintSparkSkippingStrategySuite with Matchers {
 
   override val strategy: FlintSparkSkippingStrategy =
     MinMaxSkippingStrategy(columnName = "age", columnType = "integer")
 
-  private val age = AttributeReference("age", IntegerType, nullable = false)()
+  override val indexExpr: Expression = AttributeReference("age", IntegerType, nullable = false)()
   private val minAge = col("MinMax_age_0")
   private val maxAge = col("MinMax_age_1")
 
   test("should rewrite EqualTo(, )") {
-    EqualTo(age, Literal(30)) shouldRewriteTo (minAge <= 30 && maxAge >= 30)
+    EqualTo(indexExpr, Literal(30)) shouldRewriteTo (minAge <= 30 && maxAge >= 30)
   }
 
   test("should rewrite LessThan(, )") {
-    LessThan(age, Literal(30)) shouldRewriteTo (minAge < 30)
+    LessThan(indexExpr, Literal(30)) shouldRewriteTo (minAge < 30)
   }
 
   test("should rewrite LessThanOrEqual(, )") {
-    LessThanOrEqual(age, Literal(30)) shouldRewriteTo (minAge <= 30)
+    LessThanOrEqual(indexExpr, Literal(30)) shouldRewriteTo (minAge <= 30)
   }
 
   test("should rewrite GreaterThan(, )") {
-    GreaterThan(age, Literal(30)) shouldRewriteTo (maxAge > 30)
+    GreaterThan(indexExpr, Literal(30)) shouldRewriteTo (maxAge > 30)
   }
 
   test("should rewrite GreaterThanOrEqual(, )") {
-    GreaterThanOrEqual(age, Literal(30)) shouldRewriteTo (maxAge >= 30)
+    GreaterThanOrEqual(indexExpr, Literal(30)) shouldRewriteTo (maxAge >= 30)
   }
 
   test("should rewrite In(, ") {
-    val predicate = In(age, Seq(Literal(23), Literal(30), Literal(27)))
+    val predicate = In(indexExpr, Seq(Literal(23), Literal(30), Literal(27)))
 
     predicate shouldRewriteTo (maxAge >= 23 && minAge <= 30)
   }
 
   test("should not rewrite inapplicable predicate") {
-    EqualTo(age, Abs(Literal(30))) shouldNotRewrite ()
+    EqualTo(indexExpr, Abs(Literal(30))) shouldNotRewrite ()
   }
 }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategySuite.scala
index a866c591f..26dd25f19 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategySuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/partition/PartitionSkippingStrategySuite.scala
@@ -8,23 +8,19 @@ package org.opensearch.flint.spark.skipping.partition
 import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite}
 import org.scalatest.matchers.should.Matchers
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.IntegerType
 
-class PartitionSkippingStrategySuite
-    extends SparkFunSuite
-    with FlintSparkSkippingStrategySuite
-    with Matchers {
+class PartitionSkippingStrategySuite extends FlintSparkSkippingStrategySuite with Matchers {
 
   override val strategy: FlintSparkSkippingStrategy =
     PartitionSkippingStrategy(columnName = "year", columnType = "int")
 
-  private val year = AttributeReference("year", IntegerType, nullable = false)()
+  override val indexExpr: Expression = AttributeReference("year", IntegerType, nullable = false)()
 
   test("should rewrite EqualTo(, )") {
-    EqualTo(year, Literal(2023)) shouldRewriteTo (col("year") === 2023)
+    EqualTo(indexExpr, Literal(2023)) shouldRewriteTo (col("year") === 2023)
   }
 
   test("should not rewrite predicate with other column)") {
@@ -35,6 +31,6 @@ class PartitionSkippingStrategySuite
   }
 
   test("should not rewrite inapplicable predicate") {
-    EqualTo(year, Abs(Literal(2023))) shouldNotRewrite ()
+    EqualTo(indexExpr, Abs(Literal(2023))) shouldNotRewrite ()
   }
 }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala
index 11213d011..4c632bf34 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/valueset/ValueSetSkippingStrategySuite.scala
@@ -10,16 +10,16 @@ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DE
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StringType
 
-class ValueSetSkippingStrategySuite extends SparkFunSuite with FlintSparkSkippingStrategySuite {
+class ValueSetSkippingStrategySuite extends FlintSparkSkippingStrategySuite {
 
   override val strategy: FlintSparkSkippingStrategy =
     ValueSetSkippingStrategy(columnName = "name", columnType = "string")
 
-  private val name = AttributeReference("name", StringType, nullable = false)()
+  override val indexExpr: Expression = AttributeReference("name", StringType, nullable = false)()
 
   test("should return parameters with default value") {
     strategy.parameters shouldBe Map(
@@ -48,7 +48,7 @@ class ValueSetSkippingStrategySuite extends SparkFunSuite with FlintSparkSkippin
   }
 
   test("should rewrite EqualTo(, )") {
-    EqualTo(name, Literal("hello")) shouldRewriteTo
+    EqualTo(indexExpr, Literal("hello")) shouldRewriteTo
       (isnull(col("name")) || col("name") === "hello")
   }
 
@@ -60,6 +60,6 @@ class ValueSetSkippingStrategySuite extends SparkFunSuite with FlintSparkSkippin
   }
 
   test("should not rewrite inapplicable predicate") {
-    EqualTo(name, Abs(Literal("hello"))) shouldNotRewrite ()
+    EqualTo(indexExpr, Abs(Literal("hello"))) shouldNotRewrite ()
  }
 }