From 4ee2fbf88da20a32b2ae35ddbec82f08b64f0cbb Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 10 Oct 2023 18:28:05 -0700 Subject: [PATCH] update correlation command add test parts Signed-off-by: YANGDB --- .../flint/spark/FlintSparkPPLITSuite.scala | 480 ++++++++---------- .../ppl/FlintSparkPPLCorrelationITSuite.scala | 156 ++++++ .../ppl/FlintSparkPPLFiltersITSuite.scala | 1 - .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 21 +- .../sql/ast/AbstractNodeVisitor.java | 5 + .../sql/ast/expression/FieldsMapping.java | 6 +- .../opensearch/sql/ast/tree/Correlation.java | 1 + .../sql/ppl/CatalystPlanContext.java | 26 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 31 +- .../sql/ppl/parser/AstExpressionBuilder.java | 10 + .../sql/ppl/utils/JoinSpecTransformer.java | 80 ++- ...ggregationQueriesTranslatorTestSuite.scala | 11 +- ...orrelationQueriesTranslatorTestSuite.scala | 24 +- 14 files changed, 544 insertions(+), 309 deletions(-) create mode 100644 integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala index b297b30c7..d632cecf7 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala @@ -5,14 +5,14 @@ package org.opensearch.flint.spark +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Limit, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.{QueryTest, Row} class FlintSparkPPLITSuite - extends QueryTest + extends QueryTest with LogicalPlanTestUtils with FlintPPLSuite with StreamTest { @@ -25,8 +25,7 @@ class FlintSparkPPLITSuite // Create test table // Update table creation - sql( - s""" + sql(s""" | CREATE TABLE $testTable | ( | name STRING, @@ -46,8 +45,7 @@ class FlintSparkPPLITSuite |""".stripMargin) // Update data insertion - sql( - s""" + sql(s""" | INSERT INTO $testTable | PARTITION (year=2023, month=4) | VALUES ('Jake', 70, 'California', 'USA'), @@ -67,8 +65,7 @@ class FlintSparkPPLITSuite } test("create ppl simple query test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable | """.stripMargin) @@ -79,8 +76,7 @@ class FlintSparkPPLITSuite Row("Jake", 70, "California", "USA", 2023, 4), Row("Hello", 30, "New York", "USA", 2023, 4), Row("John", 25, "Ontario", "Canada", 2023, 4), - Row("Jane", 20, "Quebec", "Canada", 2023, 4) - ) + Row("Jane", 20, "Quebec", "Canada", 2023, 4)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -89,14 +85,14 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))) + val expectedPlan: LogicalPlan = + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test"))) // Compare the two plans assert(expectedPlan === logicalPlan) } test("create ppl simple query with head (limit) 3 test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| head 2 | """.stripMargin) @@ -107,14 +103,15 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Limit(Literal(2), Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) + val expectedPlan: LogicalPlan = Limit( + Literal(2), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } test("create ppl simple query with head (limit) and sorted test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| sort name | head 2 | """.stripMargin) @@ -126,27 +123,25 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Limit(Literal(2), Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) + val expectedPlan: LogicalPlan = Limit( + Literal(2), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("default", "flint_ppl_test")))) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(compareByString(sortedPlan) === compareByString(logicalPlan)) } test("create ppl simple query two with fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jake", 70), - Row("Hello", 30), - Row("John", 25), - Row("Jane", 20) - ) + val expectedResults: Array[Row] = + Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -154,43 +149,40 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + val expectedPlan: LogicalPlan = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), UnresolvedRelation(Seq("default", "flint_ppl_test"))) // Compare the two plans assert(expectedPlan === logicalPlan) } test("create ppl simple sorted query two with fields result test sorted") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| sort age | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jane", 20), - Row("John", 25), - Row("Hello", 30), - Row("Jake", 70), - ) + val expectedResults: Array[Row] = + Array(Row("Jane", 20), Row("John", 25), Row("Hello", 30), Row("Jake", 70)) assert(results === expectedResults) // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val expectedPlan: LogicalPlan = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + val expectedPlan: LogicalPlan = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), UnresolvedRelation(Seq("default", "flint_ppl_test"))) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } test("create ppl simple query two with fields and head (limit) test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| fields name, age | head 1 | """.stripMargin) @@ -200,7 +192,9 @@ class FlintSparkPPLITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical - val project = Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), UnresolvedRelation(Seq("default", "flint_ppl_test"))) + val project = Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), + UnresolvedRelation(Seq("default", "flint_ppl_test"))) // Define the expected logical plan val expectedPlan: LogicalPlan = Limit(Literal(1), Project(Seq(UnresolvedStar(None)), project)) // Compare the two plans @@ -208,23 +202,19 @@ class FlintSparkPPLITSuite } test("create ppl simple age literal equal filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable age=25 | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("John", 25), - ) + val expectedResults: Array[Row] = Array(Row("John", 25)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan @@ -237,30 +227,28 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal greater than filter AND country not equal filter query with two fields result test") { + val frame = sql(s""" | source = $testTable age>10 and country != 'USA' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("John", 25), - Row("Jane", 20), - ) + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = And(Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), GreaterThan(UnresolvedAttribute("age"), Literal(10))) + val filterExpr = And( + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + GreaterThan(UnresolvedAttribute("age"), Literal(10))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) @@ -268,19 +256,16 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal greater than filter AND country not equal filter query with two fields sorted result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal greater than filter AND country not equal filter query with two fields sorted result test") { + val frame = sql(s""" | source = $testTable age>10 and country != 'USA' | sort - age | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("John", 25), - Row("Jane", 20), - ) + val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20)) // Compare the results assert(results === expectedResults) @@ -288,41 +273,41 @@ class FlintSparkPPLITSuite val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = And(Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), GreaterThan(UnresolvedAttribute("age"), Literal(10))) + val filterExpr = And( + Not(EqualTo(UnresolvedAttribute("country"), Literal("USA"))), + GreaterThan(UnresolvedAttribute("age"), Literal(10))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } - test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal equal than filter OR country not equal filter query with two fields result test") { + val frame = sql(s""" | source = $testTable age<=20 OR country = 'USA' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jane", 20), - Row("Jake", 70), - Row("Hello", 30), - ) + val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = Or(EqualTo(UnresolvedAttribute("country"), Literal("USA")), LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) @@ -330,9 +315,9 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { - val frame = sql( - s""" + test( + "create ppl simple age literal equal than filter OR country not equal filter query with two fields result and head (limit) test") { + val frame = sql(s""" | source = $testTable age<=20 OR country = 'USA' | fields name, age | head 1 | """.stripMargin) @@ -340,12 +325,13 @@ class FlintSparkPPLITSuite val results: Array[Row] = frame.collect() assert(results.length == 1) - // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val filterExpr = Or(EqualTo(UnresolvedAttribute("country"), Literal("USA")), LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + LessThanOrEqual(UnresolvedAttribute("age"), Literal(20))) val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val projectPlan = Project(Seq(UnresolvedStar(None)), Project(projectList, filterPlan)) @@ -355,18 +341,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age literal greater than filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable age>25 | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jake", 70), - Row("Hello", 30) - ) + val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -384,20 +366,16 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal smaller than equals filter query with two fields result test") { - val frame = sql( - s""" + test( + "create ppl simple age literal smaller than equals filter query with two fields result test") { + val frame = sql(s""" | source = $testTable age<=65 | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Hello", 30), - Row("John", 25), - Row("Jane", 20) - ) + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -415,20 +393,16 @@ class FlintSparkPPLITSuite assert(expectedPlan === logicalPlan) } - test("create ppl simple age literal smaller than equals filter query with two fields result with sort test") { - val frame = sql( - s""" + test( + "create ppl simple age literal smaller than equals filter query with two fields result with sort test") { + val frame = sql(s""" | source = $testTable age<=65 | sort name | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Hello", 30), - Row("Jane", 20), - Row("John", 25), - ) + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("Jane", 20), Row("John", 25)) // Compare the results assert(results === expectedResults) @@ -440,23 +414,21 @@ class FlintSparkPPLITSuite val filterPlan = Filter(filterExpr, table) val projectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) val expectedPlan = Project(projectList, filterPlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } test("create ppl simple name literal equal filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable name='Jake' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Jake", 70) - ) + val expectedResults: Array[Row] = Array(Row("Jake", 70)) // Compare the results // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -475,19 +447,14 @@ class FlintSparkPPLITSuite } test("create ppl simple name literal not equal filter query with two fields result test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable name!='Jake' | fields name, age | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row("Hello", 30), - Row("John", 25), - Row("Jane", 20) - ) + val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20)) // Compare the results // Compare the results @@ -507,17 +474,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg query test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(36.25), - ) + val expectedResults: Array[Row] = Array(Row(36.25)) // Compare the results // Compare the results @@ -529,7 +493,8 @@ class FlintSparkPPLITSuite // Define the expected logical plan val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) val aggregatePlan = Project(aggregateExpressions, table) // Compare the two plans @@ -537,17 +502,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg query with filter test") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| where age < 50 | stats avg(age) | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(25), - ) + val expectedResults: Array[Row] = Array(Row(25)) // Compare the results // Compare the results @@ -561,7 +523,8 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val filterExpr = LessThan(ageField, Literal(50)) val filterPlan = Filter(filterExpr, table) - val aggregateExpressions = Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) + val aggregateExpressions = + Seq(Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")()) val aggregatePlan = Project(aggregateExpressions, filterPlan) // Compare the two plans @@ -569,18 +532,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(22.5, "Canada"), - Row(50.0, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -595,10 +554,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -606,8 +567,7 @@ class FlintSparkPPLITSuite } test("create ppl simple age avg group by country head (limit) query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) by country | head 1 | """.stripMargin) @@ -623,10 +583,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val projectPlan = Project(Seq(UnresolvedStar(None)), aggregatePlan) val expectedPlan = Limit(Literal(1), projectPlan) @@ -635,18 +597,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age max group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats max(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(70, "USA"), - Row(25, "Canada"), - ) + val expectedResults: Array[Row] = Array(Row(70, "USA"), Row(25, "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) @@ -661,10 +619,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MAX"), Seq(ageField), isDistinct = false), "max(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -672,18 +632,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age min group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats min(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(30, "USA"), - Row(20, "Canada"), - ) + val expectedResults: Array[Row] = Array(Row(30, "USA"), Row(20, "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) @@ -698,10 +654,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("MIN"), Seq(ageField), isDistinct = false), "min(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -709,18 +667,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age sum group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats sum(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(100L, "USA"), - Row(45L, "Canada"), - ) + val expectedResults: Array[Row] = Array(Row(100L, "USA"), Row(45L, "Canada")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](0)) @@ -735,10 +689,12 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -746,18 +702,14 @@ class FlintSparkPPLITSuite } test("create ppl simple age sum group by country order by age query test with sort ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats sum(age) by country | sort country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(45L, "Canada"), - Row(100L, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(45L, "Canada"), Row(100L, "USA")) // Compare the results assert(results === expectedResults) @@ -771,36 +723,34 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("SUM"), Seq(ageField), isDistinct = false), "sum(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), global = true, expectedPlan) // Compare the two plans assert(compareByString(sortedPlan) === compareByString(logicalPlan)) } test("create ppl simple age count group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats count(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(2L, "Canada"), - Row(2L, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(2L, "Canada"), Row(2L, "USA")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) assert( results.sorted.sameElements(expectedResults.sorted), - s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}" - ) + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical @@ -811,32 +761,29 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() val productAlias = Alias(countryField, "country")() - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans assert( compareByString(expectedPlan) === compareByString(logicalPlan), - s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}" - ) + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") } test("create ppl simple age avg group by country with state filter query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| where state != 'Quebec' | stats avg(age) by country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(25.0, "Canada"), - Row(50.0, "USA"), - ) + val expectedResults: Array[Row] = Array(Row(25.0, "Canada"), Row(50.0, "USA")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -852,12 +799,14 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() val productAlias = Alias(countryField, "country")() val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) val filterPlan = Filter(filterExpr, table) - val aggregatePlan = Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) val expectedPlan = Project(star, aggregatePlan) // Compare the two plans @@ -865,28 +814,21 @@ class FlintSparkPPLITSuite } /** - * +--------+-------+-----------+ - * |age_span| count_age| - * +--------+-------+-----------+ - * | 20| 2 | - * | 30| 1 | - * | 70| 1 | - * +--------+-------+-----------+ + * | age_span | count_age | + * |:---------|----------:| + * | 20 | 2 | + * | 30 | 1 | + * | 70 | 1 | */ test("create ppl simple count age by span of interval of 10 years query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats count(age) by span(age, 10) as age_span | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(1, 70L), - Row(1, 30L), - Row(2, 20L), - ) + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) @@ -900,8 +842,11 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val expectedPlan = Project(star, aggregatePlan) @@ -910,19 +855,14 @@ class FlintSparkPPLITSuite } ignore("create ppl simple count age by span of interval of 10 years query order by age test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats count(age) by span(age, 10) as age_span | sort age_span | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(1, 70L), - Row(1, 30L), - Row(2, 20L), - ) + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) // Compare the results assert(results === expectedResults) @@ -935,38 +875,37 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val expectedPlan = Project(star, aggregatePlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("span (age,10,NONE)"), Ascending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("span (age,10,NONE)"), Ascending)), + global = true, + expectedPlan) // Compare the two plans assert(sortedPlan === logicalPlan) } /** - * +--------+-------+-----------+ - * |age_span| average_age| - * +--------+-------+-----------+ - * | 20| 22.5 | - * | 30| 30 | - * | 70| 70 | - * +--------+-------+-----------+ + * | age_span | average_age | + * |:---------|------------:| + * | 20 | 22.5 | + * | 30 | 30 | + * | 70 | 70 | */ test("create ppl simple avg age by span of interval of 10 years query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(70D, 70L), - Row(30D, 30L), - Row(22.5D, 20L), - ) + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(22.5d, 20L)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) @@ -980,8 +919,11 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val expectedPlan = Project(star, aggregatePlan) @@ -989,9 +931,9 @@ class FlintSparkPPLITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - test("create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { - val frame = sql( - s""" + test( + "create ppl simple avg age by span of interval of 10 years with head (limit) query test ") { + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span | head 2 | """.stripMargin) @@ -1007,8 +949,11 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val projectPlan = Project(star, aggregatePlan) val expectedPlan = Limit(Literal(2), projectPlan) @@ -1018,28 +963,21 @@ class FlintSparkPPLITSuite } /** - * +--------+-------+-----------+ - * |age_span|country|average_age| - * +--------+-------+-----------+ - * | 20| Canada| 22.5| - * | 30| USA| 30| - * | 70| USA| 70| - * +--------+-------+-----------+ + * | age_span | country | average_age | + * |:---------|:--------|:------------| + * | 20 | Canada | 22.5 | + * | 30 | USA | 30 | + * | 70 | USA | 70 | */ ignore("create ppl average age by span of interval of 10 years group by country query test ") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = Array( - Row(1, 70L), - Row(1, 30L), - Row(2, 20L), - ) + val expectedResults: Array[Row] = Array(Row(1, 70L), Row(1, 30L), Row(2, 20L)) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) @@ -1054,8 +992,11 @@ class FlintSparkPPLITSuite val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) val groupByAttributes = Seq(Alias(countryField, "country")()) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val expectedPlan = Project(star, aggregatePlan) @@ -1063,9 +1004,9 @@ class FlintSparkPPLITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - ignore("create ppl average age by span of interval of 10 years group by country head (limit) 2 query test ") { - val frame = sql( - s""" + ignore( + "create ppl average age by span of interval of 10 years group by country head (limit) 2 query test ") { + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | head 2 | """.stripMargin) @@ -1080,8 +1021,11 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val projectPlan = Project(star, aggregatePlan) val expectedPlan = Limit(Literal(1), projectPlan) @@ -1089,10 +1033,10 @@ class FlintSparkPPLITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } - - ignore("create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") { - val frame = sql( - s""" + + ignore( + "create ppl average age by span of interval of 10 years group by country head (limit) 2 query and sort by test ") { + val frame = sql(s""" | source = $testTable| stats avg(age) by span(age, 10) as age_span, country | head 2 | sort age_span | """.stripMargin) @@ -1107,12 +1051,16 @@ class FlintSparkPPLITSuite val ageField = UnresolvedAttribute("age") val table = UnresolvedRelation(Seq("default", "flint_ppl_test")) - val aggregateExpressions = Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() - val span = Alias(Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), "span (age,10,NONE)")() + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(ageField), isDistinct = false), "count(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "span (age,10,NONE)")() val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), table) val projectPlan = Project(star, aggregatePlan) val expectedPlan = Limit(Literal(1), projectPlan) - val sortedPlan: LogicalPlan = Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Descending)), global = true, expectedPlan) // Compare the two plans assert(compareByString(sortedPlan) === compareByString(logicalPlan)) } diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala new file mode 100644 index 000000000..cfcefe7cb --- /dev/null +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{And, Ascending, EqualTo, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.JoinHint.NONE +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLCorrelationITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable1 = "spark_catalog.default.flint_ppl_test1" + private val testTable2 = "spark_catalog.default.flint_ppl_test2" + + override def beforeAll(): Unit = { + super.beforeAll() + // Create test tables + sql(s""" + | CREATE TABLE $testTable1 + | ( + | name STRING, + | age INT, + | state STRING, + | country STRING + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + sql(s""" + | CREATE TABLE $testTable2 + | ( + | name STRING, + | occupation STRING, + | salary INT + | ) + | USING CSV + | OPTIONS ( + | header 'false', + | delimiter '\t' + | ) + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + // Update data insertion + sql(s""" + | INSERT INTO $testTable1 + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 70, 'California', 'USA'), + | ('Hello', 30, 'New York', 'USA'), + | ('John', 25, 'Ontario', 'Canada'), + | ('Jane', 20, 'Quebec', 'Canada') + | """.stripMargin) + // Insert data into the new table + sql(s""" + | INSERT INTO $testTable2 + | PARTITION (year=2023, month=4) + | VALUES ('Jake', 'Engineer', 100000), + | ('Hello', 'Artist', 70000), + | ('John', 'Doctor', 120000), + | ('Jane', 'Scientist', 90000) + | """.stripMargin) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("create ppl correlation query with two tables correlating on a single field test") { + val joinQuery = + s""" + | SELECT a.name, a.age, a.state, a.country, b.occupation, b.salary + | FROM $testTable1 AS a + | JOIN $testTable2 AS b + | ON a.name = b.name + | WHERE a.year = 2023 AND a.month = 4 AND b.year = 2023 AND b.month = 4 + |""".stripMargin + + val result = spark.sql(joinQuery) + result.show() + + val frame = sql(s""" + | source = $testTable1, $testTable2| where year = 2023 AND month = 4 | correlate exact fields(name) scope(month, 1W) mapping($testTable1.name = $testTable2.name) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("Jake", 70, "California", "USA", 2023, 4, "Jake", "Engineer", 100000, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", 70000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", 120000, 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", 90000, 2023, 4)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define filter expressions + val filter1Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + val filter2Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + // Define subquery aliases + val plan1 = Filter(filter1Expr, table1) + val plan2 = Filter(filter2Expr, table2) + + // Define join condition + val joinCondition = + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")) + + // Create Join plan + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + // Add the projection + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + +} diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala index fb46ce4de..b2aebf03b 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFiltersITSuite.scala @@ -158,7 +158,6 @@ class FlintSparkPPLFiltersITSuite // Define the expected results val expectedResults: Array[Row] = Array(Row("Jane", 20), Row("Jake", 70), Row("Hello", 30)) // Compare the results - // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 78c687e65..b1c988b28 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -38,6 +38,7 @@ ML: 'ML'; //CORRELATION KEYWORDS CORRELATE: 'CORRELATE'; +SELF: 'SELF'; EXACT: 'EXACT'; APPROXIMATE: 'APPROXIMATE'; SCOPE: 'SCOPE'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index b8a0f5fe5..0223dab8d 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -62,32 +62,33 @@ describeCommand ; showDataSourcesCommand - : SHOW DATASOURCES - ; + : SHOW DATASOURCES + ; whereCommand - : WHERE logicalExpression - ; + : WHERE logicalExpression + ; correlateCommand - : CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS scopeClause (mappingList)? - ; + : CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS scopeClause mappingList + ; correlationType - : EXACT + : SELF + | EXACT | APPROXIMATE ; scopeClause - : SCOPE LT_PRTHS fieldExpression COMMA value = literalValue (unit = timespanUnit)? RT_PRTHS - ; + : SCOPE LT_PRTHS fieldExpression COMMA value = literalValue (unit = timespanUnit)? RT_PRTHS + ; mappingList : MAPPING LT_PRTHS ( mappingClause (COMMA mappingClause)* ) RT_PRTHS ; mappingClause - : qualifiedName EQUAL qualifiedName + : left = qualifiedName comparisonOperator right = qualifiedName # mappingCompareExpr ; fieldsCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 49fd4bda6..e3d0c6a2b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -16,6 +16,7 @@ import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; @@ -99,6 +100,10 @@ public T visitCorrelation(Correlation node, C context) { return visitChildren(node, context); } + public T visitCorrelationMapping(FieldsMapping node, C context) { + return visitChildren(node, context); + } + public T visitProject(Project node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java index d3157f7f8..37d31b822 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/FieldsMapping.java @@ -6,15 +6,17 @@ public class FieldsMapping extends UnresolvedExpression { - private final List fieldsMappingList; public FieldsMapping(List fieldsMappingList) { this.fieldsMappingList = fieldsMappingList; } + public List getChild() { + return fieldsMappingList; + } @Override public R accept(AbstractNodeVisitor nodeVisitor, C context) { - return nodeVisitor.visit(this, context); + return nodeVisitor.visitCorrelationMapping(this, context); } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java index e67427ce2..6cc2b66ff 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Correlation.java @@ -58,6 +58,7 @@ public FieldsMapping getMappingListContext() { } public enum CorrelationType { + self, exact, approximate } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index 7e21ac9a9..4145f5628 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -6,7 +6,6 @@ package org.opensearch.sql.ppl; import org.apache.spark.sql.catalyst.expressions.Expression; -import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.catalyst.plans.logical.Union; import scala.collection.Seq; @@ -37,7 +36,11 @@ public class CatalystPlanContext { * Grouping NamedExpression contextual parameters **/ private final Stack groupingParseExpressions = new Stack<>(); - + + public Stack getPlanBranches() { + return planBranches; + } + public LogicalPlan getPlan() { if (this.planBranches.size() == 1) { return planBranches.peek(); @@ -58,9 +61,10 @@ public Stack getGroupingParseExpressions() { * append context with evolving plan * * @param plan + * @return */ - public void with(LogicalPlan plan) { - this.planBranches.push(plan); + public LogicalPlan with(LogicalPlan plan) { + return this.planBranches.push(plan); } public LogicalPlan plan(Function transformFunction) { @@ -69,12 +73,22 @@ public LogicalPlan plan(Function transformFunction) { } /** + * retain all logical plans branches + * @return + */ + public Seq retainAllPlans(Function transformFunction) { + Seq plans = seq(getPlanBranches().stream().map(transformFunction).collect(Collectors.toList())); + getPlanBranches().retainAll(emptyList()); + return plans; + } + /** + * * retain all expressions and clear expression stack * @return */ public Seq retainAllNamedParseExpressions(Function transformFunction) { Seq aggregateExpressions = seq(getNamedParseExpressions().stream() - .map(transformFunction::apply).collect(Collectors.toList())); + .map(transformFunction).collect(Collectors.toList())); getNamedParseExpressions().retainAll(emptyList()); return aggregateExpressions; } @@ -85,7 +99,7 @@ public Seq retainAllNamedParseExpressions(Function transfo */ public Seq retainAllGroupingNamedParseExpressions(Function transformFunction) { Seq aggregateExpressions = seq(getGroupingParseExpressions().stream() - .map(transformFunction::apply).collect(Collectors.toList())); + .map(transformFunction).collect(Collectors.toList())); getGroupingParseExpressions().retainAll(emptyList()); return aggregateExpressions; } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 7e5960db0..8b0998720 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -26,12 +26,14 @@ import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.WindowFunction; @@ -97,10 +99,10 @@ public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { @Override public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { - node.getTableName().forEach(t -> { + node.getTableName().forEach(t -> // Resolving the qualifiedName which is composed of a datasource.schema.table - context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false)); - }); + context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false)) + ); return context.getPlan(); } @@ -114,15 +116,15 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { @Override public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context); Seq fields = context.retainAllNamedParseExpressions(e -> e); expressionAnalyzer.visitSpan(node.getScope(), context); Expression scope = context.getNamedParseExpressions().pop(); - node.getMappingListContext().accept(this, context); + expressionAnalyzer.visitCorrelationMapping(node.getMappingListContext(), context); Seq mapping = context.retainAllNamedParseExpressions(e -> e); - return context.plan(p -> join(node.getCorrelationType(), fields, scope, mapping, p)); + return join(node.getCorrelationType(), fields, scope, mapping, context); } - @Override public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) { @@ -317,11 +319,28 @@ public Expression visitCompare(Compare node, CatalystPlanContext context) { return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); } + @Override + public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + @Override public Expression visitField(Field node, CatalystPlanContext context) { return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getField().toString()))); } + @Override + public Expression visitCorrelation(Correlation node, CatalystPlanContext context) { + return super.visitCorrelation(node, context); + } + + @Override + public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { + return node.getChild().stream().map(expression -> + visitCompare((Compare) expression, context) + ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + @Override public Expression visitAllFields(AllFields node, CatalystPlanContext context) { // Case of aggregation step - no start projection can be added diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index e7d723afd..3344cd7c2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -62,6 +62,16 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor fields, Expression valueExpression, Seq mapping, LogicalPlan p) { - //create a join statement - return p; + /** + * @param correlationType the correlation type which can be exact (inner join) or approximate (outer join) + * @param fields - fields (columns) that needed to be joined by + * @param scope - this is a time base expression that timeframes the join to a specific period : (Time-field-name, value, unit) + * @param mapping - in case fields in different relations have different name, that can be aliased with the following names + * @param context - parent context including the plan to evolve to join with + * @return + */ + static LogicalPlan join(Correlation.CorrelationType correlationType, Seq fields, Expression scope, Seq mapping, CatalystPlanContext context) { + //create a join statement - which will replace all the different plans with a single plan which contains the joined plans + switch (correlationType) { + case self: + //expecting exactly one source relation + if (context.getPlanBranches().size() != 1) + throw new IllegalStateException("Correlation command with `inner` type must have exactly on source table "); + break; + case exact: + //expecting at least two source relations + if (context.getPlanBranches().size() < 2) + throw new IllegalStateException("Correlation command with `exact` type must at least two source tables "); + break; + case approximate: + if (context.getPlanBranches().size() < 2) + throw new IllegalStateException("Correlation command with `approximate` type must at least two source tables "); + //expecting at least two source relations + break; + } + + // Define join condition + Expression joinCondition = buildJoinCondition(seqAsJavaListConverter(fields).asJava(), seqAsJavaListConverter(mapping).asJava(), correlationType); + // extract the plans from the context + List logicalPlans = seqAsJavaListConverter(context.retainAllPlans(p -> p)).asJava(); + // Define join step instead on the multiple query branches + return context.with(logicalPlans.stream().reduce((left, right) + -> new Join(left, right, getType(correlationType), Option.apply(joinCondition), JoinHint.NONE())).get()); + } + + static Expression buildJoinCondition(List fields, List mapping, Correlation.CorrelationType correlationType) { + switch (correlationType) { + case self: + //expecting exactly one source relation - mapping will be used to set the inner join counterpart + break; + case exact: + //expecting at least two source relations + return mapping.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + case approximate: + return mapping.stream().reduce(org.apache.spark.sql.catalyst.expressions.Or::new).orElse(null); + } + return mapping.stream().reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + + static JoinType getType(Correlation.CorrelationType correlationType) { + switch (correlationType) { + case self: + case exact: + return Inner$.MODULE$; + case approximate: + return FullOuter$.MODULE$; + } + return Inner$.MODULE$; } } \ No newline at end of file diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index e61615ad2..87f7e5b28 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -9,6 +9,7 @@ import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow} @@ -332,7 +333,8 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logPlan)) } - test("create ppl query count only error (status >= 400) status amount by day window and group by status test") { + test( + "create ppl query count only error (status >= 400) status amount by day window and group by status test") { val context = new CatalystPlanContext val logPlan = planTrnasformer.visit( plan( @@ -358,12 +360,11 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite "status_count_by_day")() val aggregateExpressions = - Alias( - UnresolvedFunction(Seq("SUM"), Seq(statusField), isDistinct = false), - "sum(status)")() + Alias(UnresolvedFunction(Seq("SUM"), Seq(statusField), isDistinct = false), "sum(status)")() val aggregatePlan = Aggregate( Seq(statusAlias, windowExpression), - Seq(aggregateExpressions, statusAlias, windowExpression), filterPlan) + Seq(aggregateExpressions, statusAlias, windowExpression), + filterPlan) val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) val expectedPlan = Project(star, planWithLimit) // Compare the two plans diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala index 888329d31..fa6581ecf 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala @@ -5,15 +5,16 @@ package org.opensearch.flint.spark.ppl -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Literal, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.plans.logical._ import org.junit.Assert.assertEquals import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import org.scalatest.matchers.should.Matchers +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ + class PPLLogicalPlanCorrelationQueriesTranslatorTestSuite extends SparkFunSuite with LogicalPlanTestUtils @@ -22,7 +23,6 @@ class PPLLogicalPlanCorrelationQueriesTranslatorTestSuite private val planTrnasformer = new CatalystQueryPlanVisitor() private val pplParser = new PPLSyntaxParser() - test("Search multiple tables with correlation - translated into join call with fields") { val context = new CatalystPlanContext val query = "source = table1, table2 | correlate exact fields(ip, port) scope(@timestamp, 1d)" @@ -42,9 +42,11 @@ class PPLLogicalPlanCorrelationQueriesTranslatorTestSuite assertEquals(expectedPlan, logPlan) } - test("Search multiple tables with correlation with filters - translated into join call with fields") { + test( + "Search multiple tables with correlation with filters - translated into join call with fields") { val context = new CatalystPlanContext - val query = "source = table1, table2 | where @timestamp=`2018-07-02T22:23:00` AND ip=`10.0.0.1` AND cloud.provider=`aws` | correlate exact fields(ip, port) scope(@timestamp, 1d)" + val query = + "source = table1, table2 | where @timestamp=`2018-07-02T22:23:00` AND ip=`10.0.0.1` AND cloud.provider=`aws` | correlate exact fields(ip, port) scope(@timestamp, 1d)" val logPlan = planTrnasformer.visit(plan(pplParser, query, false), context) val table1 = UnresolvedRelation(Seq("table1")) @@ -61,10 +63,12 @@ class PPLLogicalPlanCorrelationQueriesTranslatorTestSuite assertEquals(expectedPlan, logPlan) } - test("Search multiple tables with correlation - translated into join call with different fields mapping ") { + test( + "Search multiple tables with correlation - translated into join call with different fields mapping ") { val context = new CatalystPlanContext - val query = "source = table1, table2 | where @timestamp=`2018-07-02T22:23:00` AND ip=`10.0.0.1` AND cloud.provider=`aws` | correlate exact fields(ip, port) scope(@timestamp, 1d)" + - " mapping( alb_logs.ip = traces.source_ip, alb_logs.port = metrics.target_port )" + val query = + "source = table1, table2 | where @timestamp=`2018-07-02T22:23:00` AND ip=`10.0.0.1` AND cloud.provider=`aws` | correlate exact fields(ip, port) scope(@timestamp, 1d)" + + " mapping( alb_logs.ip = traces.source_ip, alb_logs.port = metrics.target_port )" val logPlan = planTrnasformer.visit(plan(pplParser, query, false), context) val table1 = UnresolvedRelation(Seq("table1"))