diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala index 1121a2e5b..54e6f7339 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala @@ -5,14 +5,14 @@ package org.opensearch.flint.spark +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.{QueryTest, Row} class FlintSparkPPLITSuite - extends QueryTest + extends QueryTest with LogicalPlanTestUtils with FlintPPLSuite with StreamTest { @@ -25,8 +25,7 @@ class FlintSparkPPLITSuite // Create test table // Update table creation - sql( - s""" + sql(s""" | CREATE TABLE $testTable | ( | name STRING, @@ -46,8 +45,7 @@ class FlintSparkPPLITSuite |""".stripMargin) // Update data insertion - sql( - s""" + sql(s""" | INSERT INTO $testTable | PARTITION (year=2023, month=4) | VALUES ('Jake', 70, 'California', 'USA'), @@ -67,8 +65,7 @@ class FlintSparkPPLITSuite } test("create ppl simple query test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable | """.stripMargin) @@ -79,8 +76,7 @@ class FlintSparkPPLITSuite Row("Jake", 70, "California", "USA", 2023, 4), Row("Hello", 30, "New York", "USA", 2023, 4), Row("John", 25, "Ontario", "Canada", 2023, 4), - Row("Jane", 20, "Quebec", "Canada", 2023, 4) - ) + Row("Jane", 20, "Quebec", "Canada", 2023, 4)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -89,14 +85,14 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))) // Compare the two plans assert(expectedPlan === logicalPlan) } test("create ppl simple query with head (limit) 3 test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| head 2 | """.stripMargin) @@ -107,15 +103,16 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Limit(Literal(2), Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) + val expectedPlan: LogicalPlan = Limit( + Literal(2), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } test("create ppl simple query with head (limit) and sorted test") { - val frame = sql( - s""" - | source = $testTable| sort name | head 2 + val frame = sql(s""" + | source = $testTable| sort name | head 2 | """.stripMargin) // Retrieve the results @@ -126,27 +123,25 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Limit(Literal(2), Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) + val expectedPlan: LogicalPlan = Limit( + Literal(2), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(compareByString(sortedPlan) === compareByString(logicalPlan)) } test("create ppl simple query two with fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jake", 70), - Row("Hello", 30), - Row("John", 25), - Row("Jane", 20) - ) + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -154,43 +149,40 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + val expectedPlan: LogicalPlan = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), UnresolvedRelation(Seq("default", "flint_ppl_test"))) // Compare the two plans assert(expectedPlan === logicalPlan) } test("create ppl simple sorted query two with fields result test sorted") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| sort age | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jane", 20), - Row("John", 25), - Row("Hello", 30), - Row("Jake", 70), - ) + val expectedResults: Array[Row] = + Array(Row("Jane", 20), Row("John", 25), Row("Hello", 30), Row("Jake", 70)) assert(results === expectedResults) // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + val expectedPlan: LogicalPlan = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), UnresolvedRelation(Seq("default", "flint_ppl_test"))) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } test("create ppl simple query two with fields and head (limit) test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| fields name, age | head 1 | """.stripMargin) @@ -200,7 +192,9 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - val project = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), UnresolvedRelation(Seq("default", "flint_ppl_test"))) + val project = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("default", "flint_ppl_test"))) // Define the expected logical plan val expectedPlan: LogicalPlan = Limit(Literal(1), Project(Seq(UnresolvedStar(None)), project)) // Compare the two plans @@ -208,23 +202,19 @@ class FlintSparkPPLITSuite } test("create ppl simple age literal equal filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable age=25 | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("John", 25), - ) + val expectedResults: Array[Row] = Array(Row("John", 25)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan @@ -237,30 +227,28 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") { + val frame = sql(s""" | source = $testTable age>10 and country != 'USA' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("John", 25), - Row("Jane", 20), - ) + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = And(Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), GreaterThan(UnresolvedAttribute("age"), Literal(10))) + val filterExpr = And( + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + GreaterThan(UnresolvedAttribute("age"), Literal(10))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) @@ -268,19 +256,16 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal greater than filter AND country not equal filter query with two fields sorted result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal greater than filter AND country not equal filter query with two fields sorted result test") { + val frame = sql(s""" | source = $testTable age>10 and country != 'USA' | sort - age | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("John", 25), - Row("Jane", 20), - ) + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) // Compare the results assert(results === expectedResults) @@ -288,41 +273,41 @@ class FlintSparkPPLITSuite val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = And(Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), GreaterThan(UnresolvedAttribute("age"), Literal(10))) + val filterExpr = And( + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + GreaterThan(UnresolvedAttribute("age"), Literal(10))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } - test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") { + val frame = sql(s""" | source = $testTable age<=20 OR country = 'USA' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jane", 20), - Row("Jake", 70), - Row("Hello", 30), - ) + val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = Or(EqualTo(UnresolvedAttribute("country"), Literal("USA")), LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) @@ -330,9 +315,9 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { - val frame = sql( - s""" + test( + "create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { + val frame = sql(s""" | source = $testTable age<=20 OR country = 'USA' | fields name, age | head 1 | """.stripMargin) @@ -340,12 +325,13 @@ class FlintSparkPPLITSuite val results: Array[Row] = frame.collect() assert(results.length == 1) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = Or(EqualTo(UnresolvedAttribute("country"), Literal("USA")), LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val projectPlan = Project(Seq(UnresolvedStar(None)), Project(projectList, filterPlan)) @@ -355,18 +341,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age literal greater than filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable age>25 | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jake", 70), - Row("Hello", 30) - ) + val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -384,20 +366,16 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal smaller than equals filter query with two fields result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal smaller than equals filter query with two fields result test") { + val frame = sql(s""" | source = $testTable age<=65 | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Hello", 30), - Row("John", 25), - Row("Jane", 20) - ) + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -415,20 +393,16 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal smaller than equals filter query with two fields result with sort test") { - val frame = sql( - s""" + test( + "create ppl simple age literal smaller than equals filter query with two fields result with sort test") { + val frame = sql(s""" | source = $testTable age<=65 | sort name | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Hello", 30), - Row("Jane", 20), - Row("John", 25), - ) + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("Jane", 20), Row("John", 25)) // Compare the results assert(results === expectedResults) @@ -440,23 +414,21 @@ class FlintSparkPPLITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } test("create ppl simple name literal equal filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable name='Jake' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jake", 70) - ) + val expectedResults: Array[Row] = Array(Row("Jake", 70)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -475,19 +447,14 @@ class FlintSparkPPLITSuite } test("create ppl simple name literal not equal filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable name!='Jake' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Hello", 30), - Row("John", 25), - Row("Jane", 20) - ) + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) // Compare the results // Compare the results @@ -507,17 +474,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg query test") { - val frame = sql( - s""" - | source = $testTable| stats avg(age) + val frame = sql(s""" + | source = $testTable| stats avg(age) | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(36.25), - ) + val expectedResults: Array[Row] = Array(Row(36.25)) // Compare the results // Compare the results @@ -530,7 +494,8 @@ class FlintSparkPPLITSuite val star = Seq(UnresolvedStar(None)) val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) val aggregatePlan = Aggregate(Seq(), aggregateExpressions, table) val expectedPlan = Project(star, aggregatePlan) @@ -539,17 +504,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg query with filter test") { - val frame = sql( - s""" - | source = $testTable| where age < 50 | stats avg(age) + val frame = sql(s""" + | source = $testTable| where age < 50 | stats avg(age) | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(25), - ) + val expectedResults: Array[Row] = Array(Row(25)) // Compare the results // Compare the results @@ -564,7 +526,8 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val filterExpr = LessThan(ageField, Literal(50)) val filterPlan = Filter(filterExpr, table) - val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) val expectedPlan = Project(star, aggregatePlan) @@ -573,18 +536,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(22.5, "Canada"), - Row(50.0, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -599,10 +558,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() val countryAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, countryAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, countryAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -610,8 +571,7 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg group by country head (limit) query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) by country | head 1 | """.stripMargin) @@ -627,10 +587,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val projectPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) val expectedPlan = Limit(Literal(1), projectPlan) @@ -639,18 +601,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age max group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats max(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(70, "USA"), - Row(25, "Canada"), - ) + val expectedResults: Array[Row] = Array(Row(70, "USA"), Row(25, "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) @@ -665,10 +623,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -676,18 +636,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age min group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats min(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(30, "USA"), - Row(20, "Canada"), - ) + val expectedResults: Array[Row] = Array(Row(30, "USA"), Row(20, "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) @@ -702,10 +658,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -713,18 +671,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age sum group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats sum(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(100L, "USA"), - Row(45L, "Canada"), - ) + val expectedResults: Array[Row] = Array(Row(100L, "USA"), Row(45L, "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) @@ -739,10 +693,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -750,18 +706,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age sum group by country order by age query test with sort ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats sum(age) by country | sort country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(45L, "Canada"), - Row(100L, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(45L, "Canada"), Row(100L, "USA")) // Compare the results assert(results === expectedResults) @@ -775,36 +727,34 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(compareByString(sortedPlan) === compareByString(logicalPlan)) } test("create ppl simple age count group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats count(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(2L, "Canada"), - Row(2L, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) assert( results.sorted.sameElements(expectedResults.sorted), - s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}" - ) + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical @@ -815,32 +765,29 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans assert( compareByString(expectedPlan) === compareByString(logicalPlan), - s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}" - ) + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") } test("create ppl simple age avg group by country with state filter query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| where state != 'Quebec' | stats avg(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(25.0, "Canada"), - Row(50.0, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(25.0, "Canada"), Row(50.0, "USA")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -856,12 +803,14 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() val productAlias = Alias(countryField, "country")() val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) val filterPlan = Filter(filterExpr, table) - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -869,28 +818,21 @@ class FlintSparkPPLITSuite } /** - * +--------+-------+-----------+ - * |age_span| count_age| - * +--------+-------+-----------+ - * | 20| 2 | - * | 30| 1 | - * | 70| 1 | - * +--------+-------+-----------+ + * | age_span | count_age | + * |:---------|----------:| + * | 20 | 2 | + * | 30 | 1 | + * | 70 | 1 | */ test("create ppl simple count age by span of interval of 10 years query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats count(age) by span(age, 10) as age_span | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(1, 70L), - Row(1, 30L), - Row(2, 20L), - ) + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) @@ -903,8 +845,11 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "age_span")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val expectedPlan = Project(star, aggregatePlan) @@ -913,19 +858,14 @@ class FlintSparkPPLITSuite } ignore("create ppl simple count age by span of interval of 10 years query order by age test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats count(age) by span(age, 10) as age_span | sort age_span | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(1, 70L), - Row(1, 30L), - Row(2, 20L), - ) + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) // Compare the results assert(results === expectedResults) @@ -937,38 +877,37 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val expectedPlan = Project(star, aggregatePlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("span (age,10,NONE)"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("span (age,10,NONE)"), Ascending)), + global = true, + expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } /** - * +--------+-------+-----------+ - * |age_span| average_age| - * +--------+-------+-----------+ - * | 20| 22.5 | - * | 30| 30 | - * | 70| 70 | - * +--------+-------+-----------+ + * | age_span | average_age | + * |:---------|------------:| + * | 20 | 22.5 | + * | 30 | 30 | + * | 70 | 70 | */ test("create ppl simple avg age by span of interval of 10 years query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(70D, 70L), - Row(30D, 30L), - Row(22.5D, 20L), - ) + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -981,8 +920,11 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "age_span")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val expectedPlan = Project(star, aggregatePlan) @@ -990,9 +932,9 @@ class FlintSparkPPLITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - test("create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { - val frame = sql( - s""" + test( + "create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 | """.stripMargin) @@ -1007,8 +949,11 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "age_span")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val projectPlan = Project(star, aggregatePlan) val expectedPlan = Limit(Literal(2), projectPlan) @@ -1018,32 +963,27 @@ class FlintSparkPPLITSuite } /** - * +--------+-------+-----------+ - * |age_span|country|average_age| - * +--------+-------+-----------+ - * | 20| Canada| 22.5| - * | 30| USA| 30| - * | 70| USA| 70| - * +--------+-------+-----------+ + * | age_span | country | average_age | + * |:---------|:--------|:------------| + * | 20 | Canada | 22.5 | + * | 30 | USA | 30 | + * | 70 | USA | 70 | */ test("create ppl average age by span of interval of 10 years group by country query test ") { - val dataFrame = spark.sql("SELECT FLOOR(age / 10) * 10 AS age_span, country, AVG(age) AS average_age FROM default.flint_ppl_test GROUP BY FLOOR(age / 10) * 10, country ") + val dataFrame = spark.sql( + "SELECT FLOOR(age / 10) * 10 AS age_span, country, AVG(age) AS average_age FROM default.flint_ppl_test GROUP BY FLOOR(age / 10) * 10, country ") dataFrame.collect(); dataFrame.show() - - val frame = sql( - s""" + + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(70.0D, "USA", 70L), - Row(30.0D, "USA", 30L), - Row(22.5D, "Canada", 20L), - ) + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) @@ -1058,28 +998,29 @@ class FlintSparkPPLITSuite val countryField = UnresolvedAttribute("country") val countryAlias = Alias(countryField, "country")() - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "age_span")() - val aggregatePlan = Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - test("create ppl average age by span of interval of 10 years group by country head (limit) 2 query test ") { - val frame = sql( - s""" + test( + "create ppl average age by span of interval of 10 years group by country head (limit) 2 query test ") { + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | head 3 | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(70.0D, "USA", 70L), - Row(30.0D, "USA", 30L), - Row(22.5D, "Canada", 20L), - ) + val expectedResults: Array[Row] = + Array(Row(70.0d, "USA", 70L), Row(30.0d, "USA", 30L), Row(22.5d, "Canada", 20L)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) @@ -1093,19 +1034,23 @@ class FlintSparkPPLITSuite val countryField = UnresolvedAttribute("country") val countryAlias = Alias(countryField, "country")() - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "age_span")() - val aggregatePlan = Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) val projectPlan = Project(star, aggregatePlan) val expectedPlan = Limit(Literal(3), projectPlan) // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - - test("create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") { - val frame = sql( - s""" + + test( + "create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") { + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | sort - age_span | head 2 | """.stripMargin) @@ -1113,7 +1058,7 @@ class FlintSparkPPLITSuite val results: Array[Row] = frame.collect() assert(results.length == 2) - // Retrieve the logical plan + // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val star = Seq(UnresolvedStar(None)) @@ -1122,12 +1067,19 @@ class FlintSparkPPLITSuite val countryField = UnresolvedAttribute("country") val countryAlias = Alias(countryField, "country")() - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "age_span")() - val aggregatePlan = Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), table) val projectPlan = Project(star, aggregatePlan) val expectedPlan = Limit(Literal(2), projectPlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("age_span"), Descending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("age_span"), Descending)), + global = true, + expectedPlan) // Compare the two plans assert(compareByString(sortedPlan) === compareByString(logicalPlan)) } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/LogicalPlanTestUtils.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/LogicalPlanTestUtils.scala index 2ea8446e8..d9c0a1b8c 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/LogicalPlanTestUtils.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/LogicalPlanTestUtils.scala @@ -12,9 +12,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj * general utility functions for ppl to spark transformation test */ trait LogicalPlanTestUtils { + /** - * utility function to compare two logical plans while ignoring the auto-generated expressionId associated with the alias - * which is used for projection or aggregation + * utility function to compare two logical plans while ignoring the auto-generated expressionId + * associated with the alias which is used for projection or aggregation * @param plan * @return */ @@ -23,18 +24,21 @@ trait LogicalPlanTestUtils { val rule: PartialFunction[LogicalPlan, LogicalPlan] = { case p: Project => val newProjections = p.projectList.map { - case alias: Alias => Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) + case alias: Alias => + Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) case other => other } p.copy(projectList = newProjections) case agg: Aggregate => val newGrouping = agg.groupingExpressions.map { - case alias: Alias => Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) + case alias: Alias => + Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) case other => other } val newAggregations = agg.aggregateExpressions.map { - case alias: Alias => Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) + case alias: Alias => + Alias(alias.child, alias.name)(exprId = ExprId(0), qualifier = alias.qualifier) case other => other } agg.copy(groupingExpressions = newGrouping, aggregateExpressions = newAggregations)