Commit 03d5057

fix UT

Signed-off-by: Peng Huo <[email protected]>
penghuo committed May 11, 2024
1 parent 331e995

Showing 2 changed files with 60 additions and 107 deletions.
ApplyFlintSparkCoveringIndexSuite.scala

@@ -207,8 +207,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
 Matcher { (plan: LogicalPlan) =>
   val result = plan.exists {
     case LogicalRelation(_, _, Some(table), _) =>
-      // Table name in logical relation doesn't have catalog name
-      table.qualifiedName == expectedTableName.split('.').drop(1).mkString(".")
+      // Since Spark 3.4, table name in logical relation has catalog name
+      table.qualifiedName == expectedTableName
     case _ => false
   }
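
Note on this change: Spark 3.4 changed how table identifiers surface in a LogicalRelation, and the assertion tracks that. A minimal sketch of the before/after comparison, assuming a three-part name of the form catalog.database.table (the concrete table name below is hypothetical):

  val expectedTableName = "spark_catalog.default.test_table" // hypothetical

  // Before Spark 3.4: the relation's CatalogTable.qualifiedName omitted the
  // catalog, so the test stripped the leading part before comparing.
  val pre34Expected = expectedTableName.split('.').drop(1).mkString(".")
  // pre34Expected == "default.test_table"

  // Spark 3.4 and later: qualifiedName carries the catalog as well, so the
  // full three-part name compares directly against expectedTableName.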

ApplyFlintSparkSkippingIndexSuite.scala

@@ -7,112 +7,107 @@ package org.opensearch.flint.spark.skipping

 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
+import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}
 import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry
 import org.opensearch.flint.core.metadata.log.FlintMetadataLogEntry.IndexState.{DELETED, IndexState, REFRESHING}
-import org.opensearch.flint.spark.FlintSpark
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}
+import org.opensearch.flint.spark.{FlintSpark, FlintSparkIndexOptions}
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.SkippingKind
 import org.scalatest.matchers.{Matcher, MatchResult}
 import org.scalatest.matchers.should.Matchers
 import org.scalatestplus.mockito.MockitoSugar.mock

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.FlintSuite
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, ExprId, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, SubqueryAlias}
-import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

-class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
+class ApplyFlintSparkSkippingIndexSuite extends FlintSuite with Matchers {

   /** Test table and index */
   private val testTable = "spark_catalog.default.apply_skipping_index_test"
   private val testIndex = getSkippingIndexName(testTable)
-  private val testSchema = StructType(
-    Seq(
-      StructField("name", StringType, nullable = false),
-      StructField("age", IntegerType, nullable = false),
-      StructField("address", StringType, nullable = false)))
-
-  /** Resolved column reference used in filtering condition */
-  private val nameCol =
-    AttributeReference("name", StringType, nullable = false)(exprId = ExprId(1))
-  private val ageCol =
-    AttributeReference("age", IntegerType, nullable = false)(exprId = ExprId(2))
-  private val addressCol =
-    AttributeReference("address", StringType, nullable = false)(exprId = ExprId(3))
-
+  // Mock FlintClient to avoid looking for real OpenSearch cluster
+  private val clientBuilder = mockStatic(classOf[FlintClientBuilder])
+  private val client = mock[FlintClient](RETURNS_DEEP_STUBS)
+
+  /** Mock FlintSpark which is required by the rule */
+  private val flint = mock[FlintSpark]
+
+  override protected def beforeAll(): Unit = {
+    super.beforeAll()
+    sql(s"CREATE TABLE $testTable (name STRING, age INT, address STRING) USING JSON")
+
+    // Mock static create method in FlintClientBuilder used by Flint data source
+    clientBuilder
+      .when(() => FlintClientBuilder.build(any(classOf[FlintOptions])))
+      .thenReturn(client)
+    when(flint.spark).thenReturn(spark)
+  }
+
+  override protected def afterAll(): Unit = {
+    sql(s"DROP TABLE $testTable")
+    clientBuilder.close()
+    super.afterAll()
+  }

test("should not rewrite query if no skipping index") {
assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(EqualTo(nameCol, Literal("hello")))
.withQuery(s"SELECT * FROM $testTable WHERE name = 'hello'")
.withNoSkippingIndex()
.shouldNotRewrite()
}

test("should not rewrite query if filter condition is disjunction") {
assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(Or(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30))))
.withSkippingIndex(testIndex, REFRESHING, "name", "age")
.withQuery(s"SELECT * FROM $testTable WHERE name = 'hello' or age = 30")
.withSkippingIndex(REFRESHING, "name", "age")
.shouldNotRewrite()
}

test("should not rewrite query if filter condition contains disjunction") {
assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(
And(
Or(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30))),
EqualTo(ageCol, Literal(30))))
.withSkippingIndex(testIndex, REFRESHING, "name", "age")
.withQuery(
s"SELECT * FROM $testTable WHERE (name = 'hello' or age = 30) and address = 'Seattle'")
.withSkippingIndex(REFRESHING, "name", "age")
.shouldNotRewrite()
}

test("should rewrite query with skipping index") {
assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(EqualTo(nameCol, Literal("hello")))
.withSkippingIndex(testIndex, REFRESHING, "name")
.withQuery(s"SELECT * FROM $testTable WHERE name = 'hello'")
.withSkippingIndex(REFRESHING, "name")
.shouldPushDownAfterRewrite(col("name") === "hello")
}

test("should not rewrite query with deleted skipping index") {
assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(EqualTo(nameCol, Literal("hello")))
.withSkippingIndex(testIndex, DELETED, "name")
.withQuery(s"SELECT * FROM $testTable WHERE name = 'hello'")
.withSkippingIndex(DELETED, "name")
.shouldNotRewrite()
}

test("should only push down filter condition with indexed column") {
assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(And(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30))))
.withSkippingIndex(testIndex, REFRESHING, "name")
.withQuery(s"SELECT * FROM $testTable WHERE name = 'hello' and age = 30")
.withSkippingIndex(REFRESHING, "name")
.shouldPushDownAfterRewrite(col("name") === "hello")
}

test("should push down all filter conditions with indexed column") {
assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(And(EqualTo(nameCol, Literal("hello")), EqualTo(ageCol, Literal(30))))
.withSkippingIndex(testIndex, REFRESHING, "name", "age")
.withQuery(s"SELECT * FROM $testTable WHERE name = 'hello' and age = 30")
.withSkippingIndex(REFRESHING, "name", "age")
.shouldPushDownAfterRewrite(col("name") === "hello" && col("age") === 30)

assertFlintQueryRewriter()
.withSourceTable(testTable, testSchema)
.withFilter(
And(
EqualTo(nameCol, Literal("hello")),
And(EqualTo(ageCol, Literal(30)), EqualTo(addressCol, Literal("Seattle")))))
.withSkippingIndex(testIndex, REFRESHING, "name", "age", "address")
.withQuery(
s"SELECT * FROM $testTable WHERE name = 'hello' and (age = 30 and address = 'Seattle')")
.withSkippingIndex(REFRESHING, "name", "age", "address")
.shouldPushDownAfterRewrite(
col("name") === "hello" && col("age") === 30 && col("address") === "Seattle")
}
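
Note: after this change the tests above feed the rule a plan produced by Spark's own analyzer and optimizer rather than a hand-assembled one. A minimal sketch of what the new withQuery helper (defined in the next hunk) captures, assuming FlintSuite supplies the SparkSession behind the sql(...) helper:

  // Roughly what withQuery stores for later assertions:
  val plan = sql(s"SELECT * FROM $testTable WHERE name = 'hello'")
    .queryExecution
    .optimizedPlan
  // The optimized plan contains a Filter over a catalog-resolved
  // LogicalRelation, the same shape the rule is expected to match in production.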
@@ -122,46 +117,21 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
   }

   private class AssertionHelper {
-    private val flint = {
-      val mockFlint = mock[FlintSpark](RETURNS_DEEP_STUBS)
-      when(mockFlint.spark.sessionState.catalogManager.currentCatalog.name())
-        .thenReturn("spark_catalog")
-      mockFlint
-    }
     private val rule = new ApplyFlintSparkSkippingIndex(flint)
-    private var relation: LogicalRelation = _
     private var plan: LogicalPlan = _

-    def withSourceTable(fullname: String, schema: StructType): AssertionHelper = {
-      val table = CatalogTable(
-        identifier = TableIdentifier(fullname.split('.')(1), Some(fullname.split('.')(0))),
-        tableType = CatalogTableType.EXTERNAL,
-        storage = CatalogStorageFormat.empty,
-        schema = null)
-      relation = LogicalRelation(mockBaseRelation(schema), table)
-      this
-    }
-
-    def withFilter(condition: Expression): AssertionHelper = {
-      val filter = Filter(condition, relation)
-      val project = Project(Seq(), filter)
-      plan = SubqueryAlias("alb_logs", project)
+    def withQuery(query: String): AssertionHelper = {
+      this.plan = sql(query).queryExecution.optimizedPlan
       this
     }

-    def withSkippingIndex(
-        indexName: String,
-        indexState: IndexState,
-        indexCols: String*): AssertionHelper = {
-      val skippingIndex = mock[FlintSparkSkippingIndex]
-      when(skippingIndex.kind).thenReturn(SKIPPING_INDEX_TYPE)
-      when(skippingIndex.name()).thenReturn(indexName)
-      when(skippingIndex.indexedColumns).thenReturn(indexCols.map(FakeSkippingStrategy))
-
-      // Mock index log entry with the given state
-      val logEntry = mock[FlintMetadataLogEntry]
-      when(logEntry.state).thenReturn(indexState)
-      when(skippingIndex.latestLogEntry).thenReturn(Some(logEntry))
+    def withSkippingIndex(indexState: IndexState, indexCols: String*): AssertionHelper = {
+      val skippingIndex = new FlintSparkSkippingIndex(
+        tableName = testTable,
+        indexedColumns = indexCols.map(FakeSkippingStrategy),
+        options = FlintSparkIndexOptions.empty,
+        latestLogEntry = Some(
+          new FlintMetadataLogEntry("id", 0, 0, 0, indexState, "spark_catalog", "")))

       when(flint.describeIndex(any())).thenReturn(Some(skippingIndex))
       this
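
Note: withSkippingIndex now constructs a real FlintSparkSkippingIndex carrying a real FlintMetadataLogEntry, so index-state handling is exercised end to end instead of through stubbed getters. A hypothetical sketch of the eligibility check this supports, assuming (as the DELETED test implies) that the rule inspects the latest log entry's state:

  // Hypothetical shape of the check inside ApplyFlintSparkSkippingIndex:
  val eligible = skippingIndex.latestLogEntry.exists(_.state != DELETED)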
@@ -181,23 +151,6 @@ class ApplyFlintSparkSkippingIndexSuite extends SparkFunSuite with Matchers {
     }
   }

-  private def mockBaseRelation(schema: StructType): BaseRelation = {
-    val fileIndex = mock[FileIndex]
-    val baseRelation: HadoopFsRelation = mock[HadoopFsRelation]
-    when(baseRelation.location).thenReturn(fileIndex)
-    when(baseRelation.schema).thenReturn(schema)
-
-    // Mock baseRelation.copy(location = FlintFileIndex)
-    doAnswer((invocation: InvocationOnMock) => {
-      val location = invocation.getArgument[FileIndex](0)
-      val relationCopy: HadoopFsRelation = mock[HadoopFsRelation]
-      when(relationCopy.location).thenReturn(location)
-      relationCopy
-    }).when(baseRelation).copy(any(), any(), any(), any(), any(), any())(any())
-
-    baseRelation
-  }
-

   private def pushDownFilterToIndexScan(expect: Column): Matcher[LogicalPlan] = {
     Matcher { (plan: LogicalPlan) =>
       val useFlintSparkSkippingFileIndex = plan.exists {
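Note: the suite-level clientBuilder paired with clientBuilder.close() in afterAll follows Mockito's static-mock lifecycle: a MockedStatic stays registered until closed and would otherwise leak into later suites. A standalone sketch of the same pattern, reusing the FlintClient types from this diff:

  import org.mockito.ArgumentMatchers.any
  import org.mockito.Mockito.{mock, mockStatic, RETURNS_DEEP_STUBS}
  import org.opensearch.flint.core.{FlintClient, FlintClientBuilder, FlintOptions}

  val client = mock(classOf[FlintClient], RETURNS_DEEP_STUBS)
  val mocked = mockStatic(classOf[FlintClientBuilder])
  try {
    // Every FlintClientBuilder.build(...) call now returns the mock client.
    mocked
      .when(() => FlintClientBuilder.build(any(classOf[FlintOptions])))
      .thenReturn(client)
    // ... exercise code that builds a FlintClient ...
  } finally {
    mocked.close() // required: unregisters the static interception
  }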
