Refactor UT with real Spark table
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Apr 23, 2024
1 parent 39549e1 commit 7823b0b
Showing 1 changed file with 50 additions and 54 deletions.
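
This commit swaps the hand-built mock relations (a mocked BaseRelation plus the withTable/withProject/withFilter builders) for a real Spark table created in beforeAll, so every test now runs a SQL query and asserts against its analyzed plan. A condensed before/after sketch drawn from the diff below (testTable and the assertion helper are defined in the suite itself):

// Before: the logical plan was assembled by hand around a mocked schema
assertFlintQueryRewriter
  .withTable(StructType.fromDDL("name STRING, age INT"))
  .withProject("name", "age")
  .assertIndexNotUsed()

// After: a real table backs the tests, so plans come from actual SQL analysis
sql(s"CREATE TABLE $testTable (name STRING, age INT) USING JSON")   // in beforeAll

assertFlintQueryRewriter
  .withQuery(s"SELECT name, age FROM $testTable")
  .assertIndexNotUsed()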
@@ -15,14 +15,9 @@ import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar.mock

import org.apache.spark.FlintSuite
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parseExpression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType

class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {

@@ -39,86 +34,86 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {

override protected def beforeAll(): Unit = {
super.beforeAll()
sql(s"CREATE TABLE $testTable (name STRING, age INT) USING JSON")

clientBuilder
.when(() => FlintClientBuilder.build(any(classOf[FlintOptions])))
.thenReturn(client)
when(flint.spark).thenReturn(spark)
}

override protected def afterAll(): Unit = {
sql(s"DROP TABLE $testTable")
clientBuilder.close()
super.afterAll()
}

test("should not apply if no index present") {
assertFlintQueryRewriter
.withTable(StructType.fromDDL("name STRING, age INT"))
.withQuery(s"SELECT name, age FROM $testTable")
.assertIndexNotUsed()
}

test("should not apply if some columns in project are not covered") {
assertFlintQueryRewriter
.withTable(StructType.fromDDL("name STRING, age INT, city STRING"))
.withProject("name", "age", "city") // city is not covered
.withIndex(
new FlintSparkCoveringIndex(
indexName = "partial",
tableName = testTable,
indexedColumns = Map("name" -> "string", "age" -> "int")))
.assertIndexNotUsed()
// Covering index doesn't cover column age
Seq(
s"SELECT name, age FROM $testTable",
s"SELECT name FROM $testTable WHERE age = 30",
s"SELECT COUNT(*) FROM $testTable GROUP BY age").foreach { query =>
test(s"should not apply if columns is not covered in $query") {
assertFlintQueryRewriter
.withQuery(query)
.withIndex(
new FlintSparkCoveringIndex(
indexName = "partial",
tableName = testTable,
indexedColumns = Map("name" -> "string")))
.assertIndexNotUsed()
}
}

test("should not apply if some columns in filter are not covered") {
// Covering index covers all columns
Seq(
s"SELECT * FROM $testTable",
s"SELECT name, age FROM $testTable",
s"SELECT name FROM $testTable WHERE age = 30",
s"SELECT COUNT(*) FROM $testTable GROUP BY age",
s"SELECT name, COUNT(*) FROM $testTable WHERE age > 30 GROUP BY name").foreach { query =>
test(s"should apply when all columns in $query") {
assertFlintQueryRewriter
.withQuery(query)
.withIndex(
new FlintSparkCoveringIndex(
indexName = "all",
tableName = testTable,
indexedColumns = Map("name" -> "string", "age" -> "int")))
.assertIndexUsed(getFlintIndexName("all", testTable))
}
}

test("should apply if all columns are covered by one of the covering indexes") {
assertFlintQueryRewriter
.withTable(StructType.fromDDL("name STRING, age INT, city STRING"))
.withFilter("city = 'Seattle'")
.withProject("name", "age")
.withQuery(s"SELECT name FROM $testTable")
.withIndex(
new FlintSparkCoveringIndex(
indexName = "all",
indexName = "age",
tableName = testTable,
indexedColumns = Map("name" -> "string", "age" -> "int")))
.assertIndexNotUsed()
}

test("should apply when all columns are covered") {
assertFlintQueryRewriter
.withTable(StructType.fromDDL("name STRING, age INT"))
.withProject("name", "age")
indexedColumns = Map("age" -> "int")))
.withIndex(
new FlintSparkCoveringIndex(
indexName = "all",
indexName = "name",
tableName = testTable,
indexedColumns = Map("name" -> "string", "age" -> "int")))
.assertIndexUsed(getFlintIndexName("all", testTable))
indexedColumns = Map("name" -> "string")))
.assertIndexUsed(getFlintIndexName("name", testTable))
}

private def assertFlintQueryRewriter: AssertionHelper = new AssertionHelper

class AssertionHelper {
private var schema: StructType = _
private var plan: LogicalPlan = _
private var indexes: Seq[FlintSparkCoveringIndex] = Seq()

def withTable(schema: StructType): AssertionHelper = {
val baseRelation = mock[BaseRelation]
val table = mock[CatalogTable]
when(baseRelation.schema).thenReturn(schema)
when(table.qualifiedName).thenReturn(testTable)

this.schema = schema
this.plan = LogicalRelation(baseRelation, table)
this
}

def withProject(colNames: String*): AssertionHelper = {
val output = colNames.map(name => AttributeReference(name, schema(name).dataType)())
this.plan = Project(output, plan)
this
}

def withFilter(predicate: String): AssertionHelper = {
this.plan = Filter(parseExpression(predicate), plan)
def withQuery(query: String): AssertionHelper = {
this.plan = sql(query).queryExecution.analyzed
this
}

@@ -149,7 +144,8 @@ class ApplyFlintSparkCoveringIndexSuite extends FlintSuite with Matchers {
Matcher { (plan: LogicalPlan) =>
val result = plan.exists {
case LogicalRelation(_, _, Some(table), _) =>
table.qualifiedName == testTable
// Table name in logical relation doesn't have catalog name
table.qualifiedName == testTable.split('.').drop(1).mkString(".")
case _ => false
}

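The updated matcher at the end compares the relation's table name against testTable with its catalog prefix dropped, because the catalog table inside LogicalRelation is qualified only by database and table. A hypothetical illustration, assuming testTable is a three-part name like spark_catalog.default.test (its actual value is defined earlier in the suite):

val testTable = "spark_catalog.default.test"        // hypothetical value for illustration
testTable.split('.').drop(1).mkString(".")          // yields "default.test"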