Fix broken UT
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Jan 23, 2024
1 parent be83704 commit ea7bb13
Showing 7 changed files with 33 additions and 37 deletions.
FlintSparkSkippingStrategy.scala

@@ -142,11 +142,13 @@ object FlintSparkSkippingStrategy {
    * @param indexExpr
    *   index expression in a skipping indexed column
    */
-  case class IndexColumnExtractor(indexExpr: Expression) {
+  case class IndexColumnExtractor(indexExprStr: String, indexExpr: Expression) {
 
     def unapply(expr: Expression): Option[Column] = {
       if (expr.semanticEquals(indexExpr)) {
-        Some(new Column(expr.canonicalized))
+        val sessionState = SparkSession.active.sessionState
+        val unresolvedExpr = sessionState.sqlParser.parseExpression(indexExprStr)
+        Some(new Column(unresolvedExpr))
       } else {
         None
       }
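
The gist of the fix: IndexColumnExtractor previously wrapped the query plan's canonicalized expression, which cannot be re-resolved against the skipping index schema; it now re-parses the original index expression string into an unresolved expression. This is presumably also why the test suites below switch from SparkFunSuite to FlintSuite, which supplies the active SparkSession the extractor now requires. A minimal sketch of the idea, assuming an active session and an illustrative "age" column:

import org.apache.spark.sql.{Column, SparkSession}

// Sketch only, not part of the commit: parsing "age" yields an
// unresolved attribute expression, so the Column built from it can be
// resolved later against the skipping index relation, unlike a
// canonicalized expression captured from the query plan.
val parser = SparkSession.active.sessionState.sqlParser
val indexCol = new Column(parser.parseExpression("age"))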
ValueSetSkippingStrategy.scala

@@ -48,7 +48,7 @@ case class ValueSetSkippingStrategy(
   override def doRewritePredicate(
       predicate: Expression,
       indexExpr: Expression): Option[Expression] = {
-    val extractor = IndexColumnExtractor(indexExpr)
+    val extractor = IndexColumnExtractor(columnName, indexExpr)
 
     /*
      * This is supposed to be rewritten to ARRAY_CONTAINS(columnName, value).
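
For context on the comment above: the value-set rewrite does not emit ARRAY_CONTAINS today. Per the updated ValueSetSkippingStrategySuite further down, EqualTo on the indexed column becomes a null guard ORed with an equality check. A sketch of the expected output column, reusing the test's "name" column:

import org.apache.spark.sql.functions.{col, isnull}

// Expected rewrite of EqualTo(name, "hello") per the updated value-set
// suite: keep files whose stored value set is null (not skippable) or
// whose value set matches the literal. The null semantics here are an
// assumption based on the surrounding code, not stated in this commit.
val expected = isnull(col("name")) || col("name") === "hello"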
ApplyFlintSparkSkippingIndexSuite.scala

@@ -15,7 +15,7 @@ import org.scalatest.matchers.{Matcher, MatchResult}
 import org.scalatest.matchers.should.Matchers
 import org.scalatestplus.mockito.MockitoSugar.mock
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.FlintSuite
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -27,7 +27,7 @@ import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
-class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
+class ApplyFlintSparkSkippingIndexSuite extends FlintSuite with Matchers {
 
   /** Test table and index */
   private val testTable = "spark_catalog.default.apply_skipping_index_test"
FlintSparkSkippingStrategySuite.scala

@@ -5,14 +5,18 @@
 
 package org.opensearch.flint.spark.skipping
 
+import org.apache.spark.FlintSuite
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.Expression
 
-trait FlintSparkSkippingStrategySuite {
+trait FlintSparkSkippingStrategySuite extends FlintSuite {
 
   /** Subclass initializes strategy class to test */
   val strategy: FlintSparkSkippingStrategy
 
+  /** Resolved index expression */
+  val indexExpr: Expression
+
   /*
    * Add an assertion helper that provides more readable assertions via
    * infix functions: expr shouldRewriteTo col, expr shouldNotRewrite ()
@@ -21,17 +25,15 @@ trait FlintSparkSkippingStrategySuite {
 
     def shouldRewriteTo(right: Column): Unit = {
       val queryExpr = left
-      val indexExpr = left.children.head // Ensure left side matches
       val actual = strategy.doRewritePredicate(queryExpr, indexExpr)
       assert(actual.isDefined, s"Expected: ${right.expr}. Actual is None")
       assert(actual.get == right.expr, s"Expected: ${right.expr}. Actual: ${actual.get}")
     }
 
     def shouldNotRewrite(): Unit = {
       val queryExpr = left
-      val indexExpr = left.children.head
       val actual = strategy.doRewritePredicate(queryExpr, indexExpr)
-      assert(actual.isEmpty, s"Expected is None. Actual is ${actual.get}")
+      assert(actual.isEmpty)
     }
   }
 }
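
Since the helper no longer derives the index expression from left.children.head, every concrete suite must override indexExpr itself, as the three suites below do. A hypothetical minimal subclass, with illustrative names:

import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite}
import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.types.IntegerType

// Hypothetical example, not from this commit: a concrete suite supplies
// both the strategy under test and the resolved index expression that
// shouldRewriteTo/shouldNotRewrite pass to doRewritePredicate.
class AgeSkippingStrategySuite extends FlintSparkSkippingStrategySuite {

  override val strategy: FlintSparkSkippingStrategy =
    MinMaxSkippingStrategy(columnName = "age", columnType = "integer")

  override val indexExpr: Expression =
    AttributeReference("age", IntegerType, nullable = false)()
}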
MinMaxSkippingStrategySuite.scala

@@ -8,50 +8,46 @@ package org.opensearch.flint.spark.skipping.minmax
 import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite}
 import org.scalatest.matchers.should.Matchers
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Literal}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.IntegerType
 
-class MinMaxSkippingStrategySuite
-    extends SparkFunSuite
-    with FlintSparkSkippingStrategySuite
-    with Matchers {
+class MinMaxSkippingStrategySuite extends FlintSparkSkippingStrategySuite with Matchers {
 
   override val strategy: FlintSparkSkippingStrategy =
     MinMaxSkippingStrategy(columnName = "age", columnType = "integer")
 
-  private val age = AttributeReference("age", IntegerType, nullable = false)()
+  override val indexExpr: Expression = AttributeReference("age", IntegerType, nullable = false)()
   private val minAge = col("MinMax_age_0")
   private val maxAge = col("MinMax_age_1")
 
   test("should rewrite EqualTo(<indexCol>, <value>)") {
-    EqualTo(age, Literal(30)) shouldRewriteTo (minAge <= 30 && maxAge >= 30)
+    EqualTo(indexExpr, Literal(30)) shouldRewriteTo (minAge <= 30 && maxAge >= 30)
   }
 
   test("should rewrite LessThan(<indexCol>, <value>)") {
-    LessThan(age, Literal(30)) shouldRewriteTo (minAge < 30)
+    LessThan(indexExpr, Literal(30)) shouldRewriteTo (minAge < 30)
   }
 
   test("should rewrite LessThanOrEqual(<indexCol>, <value>)") {
-    LessThanOrEqual(age, Literal(30)) shouldRewriteTo (minAge <= 30)
+    LessThanOrEqual(indexExpr, Literal(30)) shouldRewriteTo (minAge <= 30)
   }
 
   test("should rewrite GreaterThan(<indexCol>, <value>)") {
-    GreaterThan(age, Literal(30)) shouldRewriteTo (maxAge > 30)
+    GreaterThan(indexExpr, Literal(30)) shouldRewriteTo (maxAge > 30)
   }
 
   test("should rewrite GreaterThanOrEqual(<indexCol>, <value>)") {
-    GreaterThanOrEqual(age, Literal(30)) shouldRewriteTo (maxAge >= 30)
+    GreaterThanOrEqual(indexExpr, Literal(30)) shouldRewriteTo (maxAge >= 30)
   }
 
   test("should rewrite In(<indexCol>, <value1, value2 ...>") {
-    val predicate = In(age, Seq(Literal(23), Literal(30), Literal(27)))
+    val predicate = In(indexExpr, Seq(Literal(23), Literal(30), Literal(27)))
 
     predicate shouldRewriteTo (maxAge >= 23 && minAge <= 30)
   }
 
   test("should not rewrite inapplicable predicate") {
-    EqualTo(age, Abs(Literal(30))) shouldNotRewrite ()
+    EqualTo(indexExpr, Abs(Literal(30))) shouldNotRewrite ()
   }
 }
PartitionSkippingStrategySuite.scala

@@ -8,23 +8,19 @@ package org.opensearch.flint.spark.skipping.partition
 import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite}
 import org.scalatest.matchers.should.Matchers
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.IntegerType
 
-class PartitionSkippingStrategySuite
-    extends SparkFunSuite
-    with FlintSparkSkippingStrategySuite
-    with Matchers {
+class PartitionSkippingStrategySuite extends FlintSparkSkippingStrategySuite with Matchers {
 
   override val strategy: FlintSparkSkippingStrategy =
     PartitionSkippingStrategy(columnName = "year", columnType = "int")
 
-  private val year = AttributeReference("year", IntegerType, nullable = false)()
+  override val indexExpr: Expression = AttributeReference("year", IntegerType, nullable = false)()
 
   test("should rewrite EqualTo(<indexCol>, <value>)") {
-    EqualTo(year, Literal(2023)) shouldRewriteTo (col("year") === 2023)
+    EqualTo(indexExpr, Literal(2023)) shouldRewriteTo (col("year") === 2023)
   }
 
   test("should not rewrite predicate with other column)") {
@@ -35,6 +31,6 @@ class PartitionSkippingStrategySuite
   }
 
   test("should not rewrite inapplicable predicate") {
-    EqualTo(year, Abs(Literal(2023))) shouldNotRewrite ()
+    EqualTo(indexExpr, Abs(Literal(2023))) shouldNotRewrite ()
   }
 }
ValueSetSkippingStrategySuite.scala

@@ -10,16 +10,16 @@ import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy.{DE
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StringType
 
-class ValueSetSkippingStrategySuite extends SparkFunSuite with FlintSparkSkippingStrategySuite {
+class ValueSetSkippingStrategySuite extends FlintSparkSkippingStrategySuite {
 
   override val strategy: FlintSparkSkippingStrategy =
     ValueSetSkippingStrategy(columnName = "name", columnType = "string")
 
-  private val name = AttributeReference("name", StringType, nullable = false)()
+  override val indexExpr: Expression = AttributeReference("name", StringType, nullable = false)()
 
   test("should return parameters with default value") {
     strategy.parameters shouldBe Map(
@@ -48,7 +48,7 @@ class ValueSetSkippingStrategySuite extends SparkFunSuite with FlintSparkSkippin
   }
 
   test("should rewrite EqualTo(<indexCol>, <value>)") {
-    EqualTo(name, Literal("hello")) shouldRewriteTo
+    EqualTo(indexExpr, Literal("hello")) shouldRewriteTo
       (isnull(col("name")) || col("name") === "hello")
   }
 
@@ -60,6 +60,6 @@ class ValueSetSkippingStrategySuite extends SparkFunSuite with FlintSparkSkippin
   }
 
   test("should not rewrite inapplicable predicate") {
-    EqualTo(name, Abs(Literal("hello"))) shouldNotRewrite ()
+    EqualTo(indexExpr, Abs(Literal("hello"))) shouldNotRewrite ()
   }
 }
