From 1ae9fdc6378e071c3437a5bb52a2ef6a248215ff Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 26 Sep 2023 13:17:23 -0700 Subject: [PATCH] update ppl tests & IT tests Signed-off-by: YANGDB --- ...ntSparkPPLAggregationWithSpanITSuite.scala | 291 ++++++ .../FlintSparkPPLAggregationsITSuite.scala | 423 +++++++++ .../spark/FlintSparkPPLFiltersITSuite.scala | 459 ++++++++++ .../flint/spark/FlintSparkPPLITSuite.scala | 849 +----------------- 4 files changed, 1185 insertions(+), 837 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationWithSpanITSuite.scala create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationsITSuite.scala create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLFiltersITSuite.scala diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationWithSpanITSuite.scala new file mode 100644 index 000000000..d0851c4f0 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -0,0 +1,291 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, Floor, Literal, Multiply, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{QueryTest, Row} + +class FlintSparkPPLAggregationWithSpanITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_test" + + override def beforeAll(): 
Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql(s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql(s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + /** + * | age_span | count_age | + * |:---------|----------:| + * | 20 | 2 | + * | 30 | 1 | + * | 70 | 1 | + */ + test("create ppl simple count age by span of interval of 10 years query test ") { + val frame = sql(s""" + | source = $testTable| stats count(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), 
Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | average_age | + * |:---------|------------:| + * | 20 | 22.5 | + * | 30 | 30 | + * | 70 | 70 | + */ + test("create ppl simple avg age by span of interval of 10 years query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = 
frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val projectPlan = Project(star, aggregatePlan) + val expectedPlan = Limit(Literal(2), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | country | average_age | + * |:---------|:--------|:------------| + * | 20 | Canada | 22.5 | + * | 30 | USA | 30 | + * | 70 | USA | 70 | + */ + test("create ppl average age by span of interval of 10 years group by country query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span, country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) + + // Compare on the unique avg(age) column; sorting by country alone ties on USA + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias =
Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl average age by span of interval of 10 years group by country head (limit) 3 query test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | head 3 + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) + + // Compare on the unique avg(age) column; sorting by country alone ties on USA + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val projectPlan = Project(star, aggregatePlan) + val expectedPlan =
Limit(Literal(3), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") { + val frame = sql(s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | sort - age_span | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val projectPlan = Project(star, aggregatePlan) + val expectedPlan = Limit(Literal(2), projectPlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_span"), Descending)), + global = true, + expectedPlan) + // Compare the two plans + assert(compareByString(sortedPlan) === compareByString(logicalPlan)) + } +} + diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationsITSuite.scala new file mode 100644 index 000000000..181fc8ee1 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLAggregationsITSuite.scala @@ -0,0 +1,423 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{QueryTest, Row} + +class FlintSparkPPLAggregationsITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql( + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql( + s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl simple age avg query test") { + val frame = sql( + s""" + | source = $testTable| stats avg(age) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(36.25)) + + // Compare the results + // Compare the results + 
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age avg query with filter test") { + val frame = sql( + s""" + | source = $testTable| where age < 50 | stats avg(age) + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(25.0)) + + // Compare the results + // avg(age) is a Double, so the expected row and the ordering key are both Double + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = LessThan(ageField, Literal(50)) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan)
=== compareByString(logicalPlan)) + } + + test("create ppl simple age avg group by country query test ") { + val frame = sql( + s""" + | source = $testTable| stats avg(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val countryAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, countryAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age avg group by country head (limit) query test ") { + val frame = sql( + s""" + | source = $testTable| stats avg(age) by country | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val 
groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val projectPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) + val expectedPlan = Limit(Literal(1), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age max group by country query test ") { + val frame = sql( + s""" + | source = $testTable| stats max(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70, "USA"), Row(25, "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age min group by country query test ") { + val frame = 
sql( + s""" + | source = $testTable| stats min(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(30, "USA"), Row(20, "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age sum group by country query test ") { + val frame = sql( + s""" + | source = $testTable| stats sum(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(100L, "USA"), Row(45L, "Canada")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = 
UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age sum group by country order by age query test with sort ") { + val frame = sql( + s""" + | source = $testTable| stats sum(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(45L, "Canada"), Row(100L, "USA")) + + // Compare the results + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, expectedPlan) + // Compare the two plans + 
assert(compareByString(sortedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age count group by country query test ") { + val frame = sql( + s""" + | source = $testTable| stats count(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl simple age avg group by country with state filter query test ") { + val frame = sql( + s""" + | source = $testTable| where state != 'Quebec' | stats avg(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(25.0, 
"Canada"), Row(50.0, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val productAlias = Alias(countryField, "country")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } +} + diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLFiltersITSuite.scala new file mode 100644 index 000000000..37aef3e71 --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLFiltersITSuite.scala @@ -0,0 +1,459 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, 
SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{QueryTest, Row} + +class FlintSparkPPLFiltersITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + // Update table creation + sql( + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql( + s""" + | INSERT INTO $testTable + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl simple age literal equal filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable age=25 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("John", 25)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val 
filterExpr = EqualTo(UnresolvedAttribute("age"), Literal(25)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable age>10 and country != 'USA' | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = And( + GreaterThan(UnresolvedAttribute("age"), Literal(10)), Not(EqualTo(UnresolvedAttribute("country"), Literal("USA")))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal greater than filter AND country not equal filter query with two fields sorted result test") { + val frame = sql( + s""" + | source = $testTable age>10 and country != 'USA' | sort - age | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: 
Array[Row] = Array(Row("John", 25), Row("Jane", 20)) + // Compare the results + assert(results === expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = And( + GreaterThan(UnresolvedAttribute("age"), Literal(10)), + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA")))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) + // Compare the two plans + assert(compareByString(sortedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable age<=20 OR country = 'USA' | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = Or( + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20)), + EqualTo(UnresolvedAttribute("country"), Literal("USA"))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = 
Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { + val frame = sql( + s""" + | source = $testTable age<=20 OR country = 'USA' | fields name, age | head 1 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 1) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = Or( + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20)), + EqualTo(UnresolvedAttribute("country"), Literal("USA"))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val projectPlan = Project(Seq(UnresolvedStar(None)), Project(projectList, filterPlan)) + val expectedPlan = Limit(Literal(1), projectPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal greater than filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable age>25 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + 
val filterExpr = GreaterThan(UnresolvedAttribute("age"), Literal(25)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal smaller than equals filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable age<=65 | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple age literal smaller than equals filter query with two fields result with sort test") { + val frame = sql( + s""" + | source = $testTable age<=65 | sort name | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("Jane", 20), Row("John", 25)) + // Compare the results + assert(results === 
expectedResults) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65)) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) + // Compare the two plans + assert(compareByString(sortedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple name literal equal filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable name='Jake' | fields name, age + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Jake", 70)) + // Compare the results + // Sort both sides by name (column 0, a String) so the comparison is order-independent + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = EqualTo(UnresolvedAttribute("name"), Literal("Jake")) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple name literal not equal filter query with two fields result test") { + val frame = sql( + s""" + | source = $testTable name!='Jake' | fields name, age + | 
""".stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + + // Compare the results + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val filterExpr = Not(EqualTo(UnresolvedAttribute("name"), Literal("Jake"))) + val filterPlan = Filter(filterExpr, table) + val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) + val expectedPlan = Project(projectList, filterPlan) + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl simple avg age by span of interval of 10 years query test ") { + val frame = sql( + s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + 
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { + val frame = sql( + s""" + | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + assert(results.length == 2) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) + val projectPlan = Project(star, aggregatePlan) + val expectedPlan = Limit(Literal(2), projectPlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | country | average_age | + * |:---------|:--------|:------------| + * | 20 | Canada | 22.5 | + * | 30 | USA | 30 | + * | 70 | USA | 70 | + */ + test("create ppl average age by span of interval of 10 years group by country query test ") { + val dataFrame = spark.sql( + "SELECT FLOOR(age / 10) * 10 AS age_span, country, AVG(age) AS average_age FROM default.flint_ppl_test GROUP BY FLOOR(age / 10) * 10, country ") + dataFrame.collect(); + dataFrame.show() + + val frame = sql( + s""" + | source = 
$testTable| stats avg(age) by span(age, 10) as age_span, country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) + + // Compare the results ordered by age_span (column 2); country alone ties on the two USA rows + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](2)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala index d2448a3f4..f9c9752a7 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala @@ -5,11 +5,11 @@ package org.opensearch.flint.spark -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{QueryTest, Row} class FlintSparkPPLITSuite extends QueryTest @@ -200,125 +200,10 @@ class FlintSparkPPLITSuite // Compare the two plans assert(expectedPlan === logicalPlan) } - - test("create ppl simple age literal equal filter query with two fields result test") { - val frame = sql(s""" - | source = $testTable age=25 | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("John", 25)) - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = EqualTo(UnresolvedAttribute("age"), Literal(25)) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test( - "create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") { - val frame = sql(s""" - | source = $testTable age>10 and country != 'USA' | fields name, age - | """.stripMargin) - - 
// Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = And( - Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), - GreaterThan(UnresolvedAttribute("age"), Literal(10))) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test( - "create ppl simple age literal greater than filter AND country not equal filter query with two fields sorted result test") { - val frame = sql(s""" - | source = $testTable age>10 and country != 'USA' | sort - age | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) - // Compare the results - assert(results === expectedResults) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = And( - Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), - GreaterThan(UnresolvedAttribute("age"), Literal(10))) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, 
filterPlan) - - val sortedPlan: LogicalPlan = - Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) - // Compare the two plans - assert(sortedPlan === logicalPlan) - } - - test( - "create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") { - val frame = sql(s""" - | source = $testTable age<=20 OR country = 'USA' | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30)) - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = Or( - EqualTo(UnresolvedAttribute("country"), Literal("USA")), - LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test( - "create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { + + test("create ppl simple query two with fields and head (limit) with sorting test") { val frame = sql(s""" - | source = $testTable age<=20 OR country = 'USA' | fields name, age | head 1 + | source = $testTable| fields name, age | head 1 | sort age | """.stripMargin) // Retrieve the results @@ -327,725 +212,15 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = 
frame.queryExecution.logical + val project = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("default", "flint_ppl_test"))) // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = Or( - EqualTo(UnresolvedAttribute("country"), Literal("USA")), - LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val projectPlan = Project(Seq(UnresolvedStar(None)), Project(projectList, filterPlan)) - val expectedPlan = Limit(Literal(1), projectPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test("create ppl simple age literal greater than filter query with two fields result test") { - val frame = sql(s""" - | source = $testTable age>25 | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30)) - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = GreaterThan(UnresolvedAttribute("age"), Literal(25)) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test( - "create ppl simple age literal smaller than equals filter query with two fields result test") { - val frame = sql(s""" - | source = 
$testTable age<=65 | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65)) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test( - "create ppl simple age literal smaller than equals filter query with two fields result with sort test") { - val frame = sql(s""" - | source = $testTable age<=65 | sort name | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("Jane", 20), Row("John", 25)) - // Compare the results - assert(results === expectedResults) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = LessThanOrEqual(UnresolvedAttribute("age"), Literal(65)) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) + val expectedPlan: LogicalPlan = Limit(Literal(1), 
Project(Seq(UnresolvedStar(None)), project)) val sortedPlan: LogicalPlan = - Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } - - test("create ppl simple name literal equal filter query with two fields result test") { - val frame = sql(s""" - | source = $testTable name='Jake' | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("Jake", 70)) - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = EqualTo(UnresolvedAttribute("name"), Literal("Jake")) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test("create ppl simple name literal not equal filter query with two fields result test") { - val frame = sql(s""" - | source = $testTable name!='Jake' | fields name, age - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) - - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - 
- // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = Not(EqualTo(UnresolvedAttribute("name"), Literal("Jake"))) - val filterPlan = Filter(filterExpr, table) - val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val expectedPlan = Project(projectList, filterPlan) - // Compare the two plans - assert(expectedPlan === logicalPlan) - } - - test("create ppl simple age avg query test") { - val frame = sql(s""" - | source = $testTable| stats avg(age) - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(36.25)) - - // Compare the results - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = - Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) - val aggregatePlan = Aggregate(Seq(), aggregateExpressions, table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age avg query with filter test") { - val frame = sql(s""" - | source = $testTable| where age < 50 | stats avg(age) - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(25)) - - // Compare the results 
- // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = LessThan(ageField, Literal(50)) - val filterPlan = Filter(filterExpr, table) - val aggregateExpressions = - Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) - val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age avg group by country query test ") { - val frame = sql(s""" - | source = $testTable| stats avg(age) by country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val countryAlias = Alias(countryField, 
"country")() - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, countryAlias), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age avg group by country head (limit) query test ") { - val frame = sql(s""" - | source = $testTable| stats avg(age) by country | head 1 - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - assert(results.length == 1) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val productAlias = Alias(countryField, "country")() - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) - val projectPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) - val expectedPlan = Limit(Literal(1), projectPlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age max group by country query test ") { - val frame = sql(s""" - | source = $testTable| stats max(age) by country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(70, "USA"), Row(25, "Canada")) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: 
LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")() - val productAlias = Alias(countryField, "country")() - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age min group by country query test ") { - val frame = sql(s""" - | source = $testTable| stats min(age) by country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(30, "USA"), Row(20, "Canada")) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")() - val productAlias = Alias(countryField, "country")() - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) - val expectedPlan = 
Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age sum group by country query test ") { - val frame = sql(s""" - | source = $testTable| stats sum(age) by country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(100L, "USA"), Row(45L, "Canada")) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() - val productAlias = Alias(countryField, "country")() - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age sum group by country order by age query test with sort ") { - val frame = sql(s""" - | source = $testTable| stats sum(age) by country | sort country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(45L, "Canada"), Row(100L, "USA")) - - // Compare the results - assert(results === expectedResults) - - // Retrieve the logical plan - val logicalPlan: 
LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() - val productAlias = Alias(countryField, "country")() - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) - val expectedPlan = Project(star, aggregatePlan) - val sortedPlan: LogicalPlan = - Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, expectedPlan) - // Compare the two plans - assert(compareByString(sortedPlan) === compareByString(logicalPlan)) - } - - test("create ppl simple age count group by country query test ") { - val frame = sql(s""" - | source = $testTable| stats count(age) by country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) - assert( - results.sorted.sameElements(expectedResults.sorted), - s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("COUNT"), 
Seq(ageField), isDistinct = false), "count(age)")() - val productAlias = Alias(countryField, "country")() - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert( - compareByString(expectedPlan) === compareByString(logicalPlan), - s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") - } - - test("create ppl simple age avg group by country with state filter query test ") { - val frame = sql(s""" - | source = $testTable| where state != 'Quebec' | stats avg(age) by country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(25.0, "Canada"), Row(50.0, "USA")) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val stateField = UnresolvedAttribute("state") - val countryField = UnresolvedAttribute("country") - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val productAlias = Alias(countryField, "country")() - val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) - val filterPlan = Filter(filterExpr, table) - - val aggregatePlan = - Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === 
compareByString(logicalPlan)) - } - - /** - * | age_span | count_age | - * |:---------|----------:| - * | 20 | 2 | - * | 30 | 1 | - * | 70 | 1 | - */ - test("create ppl simple count age by span of interval of 10 years query test ") { - val frame = sql(s""" - | source = $testTable| stats count(age) by span(age, 10) as age_span - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - /** - * | age_span | average_age | - * |:---------|------------:| - * | 20 | 22.5 | - * | 30 | 30 | - * | 70 | 70 | - */ - test("create ppl simple avg age by span of interval of 10 years query test ") { - val frame = sql(s""" - | source = $testTable| stats avg(age) by span(age, 10) as age_span - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) - - // 
Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test( - "create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { - val frame = sql(s""" - | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - assert(results.length == 2) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) - val projectPlan = Project(star, aggregatePlan) - val expectedPlan = Limit(Literal(2), projectPlan) - - // Compare the two plans - 
assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - /** - * | age_span | country | average_age | - * |:---------|:--------|:------------| - * | 20 | Canada | 22.5 | - * | 30 | USA | 30 | - * | 70 | USA | 70 | - */ - test("create ppl average age by span of interval of 10 years group by country query test ") { - val dataFrame = spark.sql( - "SELECT FLOOR(age / 10) * 10 AS age_span, country, AVG(age) AS average_age FROM default.flint_ppl_test GROUP BY FLOOR(age / 10) * 10, country ") - dataFrame.collect(); - dataFrame.show() - - val frame = sql(s""" - | source = $testTable| stats avg(age) by span(age, 10) as age_span, country - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = - Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val countryField = UnresolvedAttribute("country") - val countryAlias = Alias(countryField, "country")() - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = - Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) - val expectedPlan = Project(star, aggregatePlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test( - "create ppl 
average age by span of interval of 10 years group by country head (limit) 2 query test ") { - val frame = sql(s""" - | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | head 3 - | """.stripMargin) - // Retrieve the results - val results: Array[Row] = frame.collect() - // Define the expected results - val expectedResults: Array[Row] = - Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) - - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) - assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan - val logicalPlan: LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val countryField = UnresolvedAttribute("country") - val countryAlias = Alias(countryField, "country")() - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = - Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) - val projectPlan = Project(star, aggregatePlan) - val expectedPlan = Limit(Literal(3), projectPlan) - - // Compare the two plans - assert(compareByString(expectedPlan) === compareByString(logicalPlan)) - } - - test( - "create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") { - val frame = sql(s""" - | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | sort - age_span | head 2 - | """.stripMargin) - - // Retrieve the results - val results: Array[Row] = frame.collect() - assert(results.length == 2) - - // Retrieve the logical plan - val logicalPlan: 
LogicalPlan = frame.queryExecution.logical - // Define the expected logical plan - val star = Seq(UnresolvedStar(None)) - val ageField = UnresolvedAttribute("age") - val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val countryField = UnresolvedAttribute("country") - val countryAlias = Alias(countryField, "country")() - - val aggregateExpressions = - Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias( - Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), - "age_span")() - val aggregatePlan = - Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) - val projectPlan = Project(star, aggregatePlan) - val expectedPlan = Limit(Literal(2), projectPlan) - val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("age_span"), Descending)), - global = true, - expectedPlan) - // Compare the two plans - assert(compareByString(sortedPlan) === compareByString(logicalPlan)) - } + }