diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala index cfcefe7cb..c345fd5bc 100644 --- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala +++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala @@ -51,6 +51,7 @@ class FlintSparkPPLCorrelationITSuite | ( | name STRING, | occupation STRING, + | country STRING, | salary INT | ) | USING CSV @@ -71,16 +72,22 @@ class FlintSparkPPLCorrelationITSuite | VALUES ('Jake', 70, 'California', 'USA'), | ('Hello', 30, 'New York', 'USA'), | ('John', 25, 'Ontario', 'Canada'), + | ('Jim', 27, 'B.C', 'Canada'), + | ('Peter', 57, 'B.C', 'Canada'), + | ('Rick', 70, 'B.C', 'Canada'), + | ('David', 40, 'Washington', 'USA'), | ('Jane', 20, 'Quebec', 'Canada') | """.stripMargin) // Insert data into the new table sql(s""" | INSERT INTO $testTable2 | PARTITION (year=2023, month=4) - | VALUES ('Jake', 'Engineer', 100000), - | ('Hello', 'Artist', 70000), - | ('John', 'Doctor', 120000), - | ('Jane', 'Scientist', 90000) + | VALUES ('Jake', 'Engineer', 'England' , 100000), + | ('Hello', 'Artist', 'USA', 70000), + | ('John', 'Doctor', 'Canada', 120000), + | ('David', 'Doctor', 'USA', 120000), + | ('David', 'Unemployed', 'Canada', 0), + | ('Jane', 'Scientist', 'Canada', 90000) | """.stripMargin) } @@ -113,10 +120,12 @@ class FlintSparkPPLCorrelationITSuite val results: Array[Row] = frame.collect() // Define the expected results val expectedResults: Array[Row] = Array( - Row("Jake", 70, "California", "USA", 2023, 4, "Jake", "Engineer", 100000, 2023, 4), - Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", 70000, 2023, 4), - Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", 120000, 2023, 4), - Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", 90000, 2023, 4)) + Row("Jake", 70, "California", "USA", 2023, 4, "Jake", "Engineer", "England", 100000, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4), + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Unemployed", "Canada", 0, 2023, 4), + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4)) implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) // Compare the results @@ -153,4 +162,55 @@ class FlintSparkPPLCorrelationITSuite assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test("create ppl correlation query with two tables correlating on a two fields test") { + val frame = sql(s""" + | source = $testTable1, $testTable2| where year = 2023 AND month = 4 | correlate exact fields(name, country) scope(month, 1W) + | mapping($testTable1.name = $testTable2.name, $testTable1.country = $testTable2.country) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4), + Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4), + Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4), + Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + // Compare the results + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Define unresolved relations + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + + // Define filter expressions + val filter1Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + val filter2Expr = And( + EqualTo(UnresolvedAttribute("year"), Literal(2023)), + EqualTo(UnresolvedAttribute("month"), Literal(4))) + // Define subquery aliases + val plan1 = Filter(filter1Expr, table1) + val plan2 = Filter(filter2Expr, table2) + + // Define join condition + val joinCondition = + And( + EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name")), + EqualTo(UnresolvedAttribute(s"$testTable1.country"), UnresolvedAttribute(s"$testTable2.country")) + ) + + // Create Join plan + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + // Add the projection + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } }