
Refactor UT assertions
Signed-off-by: Chen Dai <[email protected]>
dai-chen committed Nov 26, 2024
1 parent 7fb194c commit 796b45c
Showing 4 changed files with 72 additions and 21 deletions.
@@ -14,6 +14,7 @@ import org.opensearch.flint.core.metadata.FlintJsonHelper._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions.{col, concat_ws, expr, sha1, to_json}
import org.apache.spark.sql.types.{MapType, StructType}
@@ -150,9 +151,9 @@ object FlintSparkIndex extends Logging {
val allOutputCols = df.schema.fields.map { field =>
field.dataType match {
case _: StructType | _: MapType =>
to_json(col(field.name))
to_json(col(quoteIfNeeded(field.name)))
case _ =>
col(field.name)
col(quoteIfNeeded(field.name))
}
}

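Note on the quoteIfNeeded change above: the generated ID expression builds on the DataFrame's output columns by name, and aggregation aliases such as test.name and test.address (exercised in the new quoted-alias test below) contain dots that col() would otherwise parse as nested-field references. A minimal sketch of the behavior, illustrative only and assuming a SparkSession named spark:

// Illustrative sketch: quoteIfNeeded backtick-quotes an identifier when it
// contains characters (such as '.') that Spark's column parser treats specially.
import org.apache.spark.sql.catalyst.util.quoteIfNeeded
import org.apache.spark.sql.functions.col

val df = spark.range(2).withColumnRenamed("id", "test.id")

// col("test.id") would be parsed as a struct-field access and fail to resolve;
// quoteIfNeeded("test.id") returns "`test.id`", so the whole name is treated
// as a single top-level column.
df.select(col(quoteIfNeeded("test.id"))).show()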
@@ -6,9 +6,12 @@
package org.apache.spark

import org.opensearch.flint.spark.FlintSparkExtensions
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN

import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.{Alias, CodegenObjectFactoryMode, Expression}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf}
import org.apache.spark.sql.flint.config.FlintSparkConf.{EXTERNAL_SCHEDULER_ENABLED, HYBRID_SCAN_ENABLED, METADATA_CACHE_WRITE}
import org.apache.spark.sql.internal.SQLConf
@@ -68,4 +71,15 @@ trait FlintSuite extends SharedSparkSession {
setFlintSparkConf(METADATA_CACHE_WRITE, "false")
}
}

protected implicit class DataFrameExtensions(val df: DataFrame) {

def idColumn(): Option[Expression] = {
df.queryExecution.logical.collectFirst { case Project(projectList, _) =>
projectList.collectFirst { case Alias(child, ID_COLUMN) =>
child
}
}.flatten
}
}
}
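The DataFrameExtensions.idColumn() helper above (extracted from FlintSparkMaterializedViewSuite so it can be shared across test suites) walks the DataFrame's unanalyzed logical plan, finds the topmost Project, and returns the expression behind the alias named ID_COLUMN. A minimal standalone sketch of the same pattern, using a hypothetical "__id__" alias in place of ID_COLUMN and assuming a SparkSession named spark:

// Illustrative sketch: extract the expression behind an aliased output column
// by pattern-matching on the logical plan, mirroring idColumn() above.
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.functions.expr

// "__id__" is a stand-in for ID_COLUMN here.
val df = spark.range(2).select(expr("id + 10").as("__id__"))

val idExpr: Option[Expression] =
  df.queryExecution.logical.collectFirst { case Project(projectList, _) =>
    projectList.collectFirst { case Alias(child, "__id__") => child }
  }.flatten
// idExpr holds Some(Add(UnresolvedAttribute("id"), Literal(10))), the same
// shape the new assertions in FlintSparkIndexSuite check against.

This is what lets the tests below assert on the exact expression tree behind the ID column instead of only checking that the column exists.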
@@ -10,7 +10,11 @@ import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Add, ConcatWs, Literal, Sha1, StructsToJson}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {

@@ -19,6 +23,7 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
val options = new FlintSparkIndexOptions(Map("id_expression" -> "id + 10"))

val resultDf = addIdColumn(df, options)
resultDf.idColumn() shouldBe Some(Add(UnresolvedAttribute("id"), Literal(10)))
checkAnswer(resultDf.select(ID_COLUMN), Seq(Row(11), Row(12)))
}

@@ -47,7 +52,37 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
val options = FlintSparkIndexOptions.empty

val resultDf = addIdColumn(df, options)
resultDf.columns should contain(ID_COLUMN)
resultDf.idColumn() shouldBe Some(
Sha1(
ConcatWs(
Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("name")),
UnresolvedAttribute(Seq("count"))))))
resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
}

test("should add ID column for aggregated query with quoted alias") {
val df = spark
.createDataFrame(
sparkContext.parallelize(
Seq(
Row(1, "Alice", Row("WA", "Seattle")),
Row(2, "Bob", Row("OR", "Portland")),
Row(3, "Alice", Row("WA", "Seattle")))),
StructType.fromDDL("id INT, name STRING, address STRUCT<state: STRING, city: String>"))
.toDF("id", "name", "address")
.groupBy(col("name").as("test.name"), col("address").as("test.address"))
.count()
val options = FlintSparkIndexOptions.empty

val resultDf = addIdColumn(df, options)
resultDf.idColumn() shouldBe Some(
Sha1(ConcatWs(Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("test.name")),
new StructsToJson(UnresolvedAttribute(Seq("test.address"))),
UnresolvedAttribute(Seq("count"))))))
resultDf.select(ID_COLUMN).distinct().count() shouldBe 2
}

@@ -92,7 +127,20 @@ class FlintSparkIndexSuite extends QueryTest with FlintSuite with Matchers {
val options = FlintSparkIndexOptions.empty

val resultDf = addIdColumn(aggregatedDf, options)
resultDf.columns should contain(ID_COLUMN)
resultDf.idColumn() shouldBe Some(
Sha1(ConcatWs(Seq(
Literal(UTF8String.fromString("\0"), StringType),
UnresolvedAttribute(Seq("boolean_col")),
UnresolvedAttribute(Seq("string_col")),
UnresolvedAttribute(Seq("long_col")),
UnresolvedAttribute(Seq("int_col")),
UnresolvedAttribute(Seq("double_col")),
UnresolvedAttribute(Seq("float_col")),
UnresolvedAttribute(Seq("timestamp_col")),
UnresolvedAttribute(Seq("date_col")),
new StructsToJson(UnresolvedAttribute(Seq("struct_col"))),
UnresolvedAttribute(Seq("subfield2")),
UnresolvedAttribute(Seq("count"))))))
resultDf.select(ID_COLUMN).distinct().count() shouldBe 1
}
}
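The assertions above pin down the shape of the generated ID for aggregated queries: a SHA-1 over the output columns joined with a "\0" separator, with struct columns serialized through to_json first (the StructsToJson nodes). A minimal DataFrame-level sketch of that shape, illustrative only and assuming a SparkSession named spark:

// Illustrative sketch: the generated ID has the form
// sha1(concat_ws(<NUL separator>, <output columns>)), matching the
// Sha1(ConcatWs(Seq(Literal("\0"), ...))) expressions asserted above.
import org.apache.spark.sql.functions.{col, concat_ws, sha1}

val aggDf = spark
  .createDataFrame(Seq((1, "Alice"), (2, "Bob"), (3, "Alice")))
  .toDF("id", "name")
  .groupBy("name")
  .count()

// The separator is the NUL character, written "\0" in the expression
// assertions; "__id__" stands in for ID_COLUMN.
val withId = aggDf.withColumn("__id__", sha1(concat_ws("\u0000", col("name"), col("count"))))

// Two distinct group keys ("Alice", "Bob") yield two distinct IDs, matching
// the distinct-count checks in the tests above.
withId.select("__id__").distinct().count()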
@@ -7,10 +7,9 @@ package org.opensearch.flint.spark.mv

import scala.collection.JavaConverters.{mapAsJavaMapConverter, mapAsScalaMapConverter}

import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.FlintSparkIndexOptions
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, DataFrameIdColumnExtractor, StreamingDslLogicalPlan}
import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan}
import org.scalatest.matchers.should.Matchers._
import org.scalatestplus.mockito.MockitoSugar.mock

@@ -19,8 +18,8 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.dsl.expressions.{intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper}
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ConcatWs, Expression, Literal, Sha1}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.expressions.{Attribute, ConcatWs, Literal, Sha1}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -385,15 +384,4 @@ object FlintSparkMaterializedViewSuite {
logicalPlan)
}
}

implicit class DataFrameIdColumnExtractor(val df: DataFrame) {

def idColumn(): Option[Expression] = {
df.queryExecution.logical.collectFirst { case Project(projectList, _) =>
projectList.collectFirst { case Alias(child, ID_COLUMN) =>
child
}
}.flatten
}
}
}
