diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala index 6e8558c65..ae6b0e0a8 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/mv/FlintSparkMaterializedViewSuite.scala @@ -9,6 +9,7 @@ import scala.collection.JavaConverters.mapAsJavaMapConverter import org.opensearch.flint.spark.FlintSparkIndexOptions import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE +import org.opensearch.flint.spark.mv.FlintSparkMaterializedViewSuite.{streamingRelation, StreamingDslLogicalPlan} import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the} import org.scalatestplus.mockito.MockitoSugar.mock @@ -16,13 +17,19 @@ import org.apache.spark.FlintSuite import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.dsl.expressions.{count, DslAttr, DslExpression, StringToAttributeConversionHelper} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, EventTimeWatermark, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.dsl.expressions.{count, intToLiteral, stringToLiteral, DslAttr, DslExpression, StringToAttributeConversionHelper} +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String +/** + * This
unit test suite includes test cases for the build API that use a real SparkSession. This is + * because SparkSession.sessionState is a private val that is hard to mock, yet it is required in + * logicalPlanToDataFrame() -> DataRows.of(). + */ class FlintSparkMaterializedViewSuite extends FlintSuite { val testMvName = "spark_catalog.default.mv" @@ -94,12 +101,13 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) val actualPlan = mv.buildStream(spark).queryExecution.logical - assert( - actualPlan.sameSemantics(Aggregate( - Seq($"TUMBLE".function($"time", Literal("1 Minute"))), - Seq($"window.start" as "startTime", count(Literal(1)) as "count"), - watermark($"time", "0 Minute", streamingRelation(testTable))))) + actualPlan.sameSemantics( + streamingRelation(testTable) + .watermark($"time", "0 Minute") + .groupBy($"TUMBLE".function($"time", "1 Minute"))( + $"window.start" as "startTime", + count(1) as "count"))) } } @@ -120,63 +128,71 @@ class FlintSparkMaterializedViewSuite extends FlintSuite { val mv = FlintSparkMaterializedView(testMvName, testQuery, Map.empty) val actualPlan = mv.buildStream(spark).queryExecution.logical - assert( - actualPlan.sameSemantics(Aggregate( - Seq($"TUMBLE".function($"time", Literal("1 Minute"))), - Seq($"window.start" as "startTime", count(Literal(1)) as "count"), - watermark( - $"time", - "0 Minute", - Filter($"age" > Literal(30), streamingRelation(testTable)))))) + actualPlan.sameSemantics( + streamingRelation(testTable) + .where($"age" > 30) + .watermark($"time", "0 Minute") + .groupBy($"TUMBLE".function($"time", "1 Minute"))( + $"window.start" as "startTime", + count(1) as "count"))) } } - test("build stream should fail if there is aggregation without windowing function") { + test("build stream with non-aggregate query") { val testTable = "mv_build_test" withTable(testTable) { sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") val mv =
FlintSparkMaterializedView( testMvName, - s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", + s"SELECT name, age FROM $testTable WHERE age > 30", Map.empty) + val actualPlan = mv.buildStream(spark).queryExecution.logical - the[IllegalStateException] thrownBy - mv.buildStream(spark) + assert( + actualPlan.sameSemantics( + streamingRelation(testTable) + .where($"age" > 30) + .select($"name", $"age"))) } } - // TODO: should we add this restriction? - ignore("build stream should fail if there is no aggregation") { + test("build stream should fail if there is aggregation without windowing function") { val testTable = "mv_build_test" withTable(testTable) { sql(s"CREATE TABLE $testTable (time TIMESTAMP, name STRING, age INT) USING CSV") val mv = FlintSparkMaterializedView( testMvName, - s"SELECT COUNT(*) AS count FROM $testTable", + s"SELECT name, COUNT(*) AS count FROM $testTable GROUP BY name", Map.empty) the[IllegalStateException] thrownBy mv.buildStream(spark) } } +} - private def streamingRelation(tableName: String): UnresolvedRelation = { +/** + * Helper methods that extend LogicalPlan with extra DSL methods via a Scala implicit class. + */ +object FlintSparkMaterializedViewSuite { + + def streamingRelation(tableName: String): UnresolvedRelation = { UnresolvedRelation( TableIdentifier(tableName), CaseInsensitiveStringMap.empty(), isStreaming = true) } - private def watermark( - colName: Attribute, - interval: String, - child: LogicalPlan): EventTimeWatermark = { - EventTimeWatermark( - colName, - IntervalUtils.stringToInterval(UTF8String.fromString(interval)), - child) + implicit class StreamingDslLogicalPlan(val logicalPlan: LogicalPlan) { + + def watermark(colName: Attribute, interval: String): DslLogicalPlan = { + EventTimeWatermark( + colName, + IntervalUtils.stringToInterval(UTF8String.fromString(interval)), + logicalPlan) + } } }