diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala
index 756ebf139..c5828179d 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLCorrelationITSuite.scala
@@ -142,6 +142,16 @@ class FlintSparkPPLCorrelationITSuite
     assert(
       thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ")
   }
+
+  test("create failing ppl correlation query with no scope - due to mismatch fields to mappings test") {
+    val thrown = intercept[IllegalStateException] {
+      val frame = sql(s"""
+           | source = $testTable1, $testTable2| correlate exact fields(name, country) mapping($testTable1.name = $testTable2.name)
+           | """.stripMargin)
+    }
+    assert(
+      thrown.getMessage === "Correlation command was called with `fields` attribute having different elements from the 'mapping' attributes ")
+  }
 
   test(
     "create failing ppl correlation query - due to mismatch correlation self type and source amount test") {
@@ -293,6 +303,60 @@ class FlintSparkPPLCorrelationITSuite
     // Compare the two plans
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
+
+  test(
+    "create ppl correlation approximate query with two tables correlating on a single field and not scope test") {
+    val frame = sql(s"""
+         | source = $testTable1, $testTable2| correlate approximate fields(name) mapping($testTable1.name = $testTable2.name)
+         | """.stripMargin)
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row("David", 40, "Washington", "USA", 2023, 4, "David", "Doctor", "USA", 120000, 2023, 4),
+      Row("David", 40, "Washington", "USA", 2023, 4, "David", "Unemployed", "Canada", 0, 2023, 4),
+      Row("Hello", 30, "New York", "USA", 2023, 4, "Hello", "Artist", "USA", 70000, 2023, 4),
+      Row(
+        "Jake",
+        70,
+        "California",
+        "USA",
+        2023,
+        4,
+        "Jake",
+        "Engineer",
+        "England",
+        100000,
+        2023,
+        4),
+      Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Jane", "Scientist", "Canada", 90000, 2023, 4),
+      Row("Jim", 27, "B.C", "Canada", 2023, 4, null, null, null, null, null, null),
+      Row("John", 25, "Ontario", "Canada", 2023, 4, "John", "Doctor", "Canada", 120000, 2023, 4),
+      Row("Peter", 57, "B.C", "Canada", 2023, 4, null, null, null, null, null, null),
+      Row("Rick", 70, "B.C", "Canada", 2023, 4, null, null, null, null, null, null))
+
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    // Compare the results
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Define unresolved relations
+    val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
+    val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
+    // Define join condition
+    val joinCondition =
+      EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name"))
+
+    // Create Join plan
+    val joinPlan = Join(table1, table2, FullOuter, Some(joinCondition), JoinHint.NONE)
+
+    // Add the projection
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan)
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
 
   test(
     "create ppl correlation query with with filters and two tables correlating on a two fields test") {
@@ -562,6 +626,64 @@ class FlintSparkPPLCorrelationITSuite
     // Compare the two plans
     assert(compareByString(expectedPlan) === compareByString(logicalPlan))
   }
+  test(
+    "create ppl correlation (exact) query with two tables correlating by name,country and group by avg salary by age span (10 years bucket) with country filter without scope test") {
+    val frame = sql(s"""
+         | source = $testTable1, $testTable2 | where country = 'USA' OR country = 'England' |
+         | correlate exact fields(name) mapping($testTable1.name = $testTable2.name) |
+         | stats avg(salary) by span(age, 10) as age_span, $testTable2.country
+         | """.stripMargin)
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] =
+      Array(Row(120000.0, "USA", 40), Row(100000.0, "England", 70), Row(70000.0, "USA", 30))
+
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
+    // Compare the results
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Define unresolved relations
+    val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1"))
+    val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2"))
+
+    // Define filter expressions
+    val filter1Expr = Or(
+      EqualTo(UnresolvedAttribute("country"), Literal("USA")),
+      EqualTo(UnresolvedAttribute("country"), Literal("England")))
+    val filter2Expr = Or(
+      EqualTo(UnresolvedAttribute("country"), Literal("USA")),
+      EqualTo(UnresolvedAttribute("country"), Literal("England")))
+    // Apply the country filter to each relation before joining
+    val plan1 = Filter(filter1Expr, table1)
+    val plan2 = Filter(filter2Expr, table2)
+
+    // Define join condition
+    val joinCondition =
+      EqualTo(UnresolvedAttribute(s"$testTable1.name"), UnresolvedAttribute(s"$testTable2.name"))
+
+    // Create Join plan
+    val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE)
+
+    val salaryField = UnresolvedAttribute("salary")
+    val countryField = UnresolvedAttribute(s"$testTable2.country")
+    val countryAlias = Alias(countryField, s"$testTable2.country")()
+    val star = Seq(UnresolvedStar(None))
+    val aggregateExpressions =
+      Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")()
+    val span = Alias(
+      Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
+      "age_span")()
+    val aggregatePlan =
+      Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan)
+    // Add the projection
+    val expectedPlan = Project(star, aggregatePlan)
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Compare the two plans
+    assert(compareByString(expectedPlan) === compareByString(logicalPlan))
+  }
 
   test(
     "create ppl correlation (approximate) query with two tables correlating by name,country and group by avg salary by age span (10 years bucket) test") {
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index 0223dab8d..4b4e64c1a 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -70,7 +70,7 @@ whereCommand
    ;
 
 correlateCommand
-   : CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS scopeClause mappingList
+   : CORRELATE correlationType FIELDS LT_PRTHS fieldList RT_PRTHS (scopeClause)? mappingList
    ;
 
 correlationType
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java
index 934c13d6b..3fbe53cd2 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/Scope.java
@@ -5,5 +5,4 @@ public class Scope extends Span {
     public Scope(UnresolvedExpression field, UnresolvedExpression value, SpanUnit unit) {
         super(field, value, unit);
     }
-
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 320e6617c..6d14db328 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -127,11 +127,14 @@ public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext contex
         context.reduce((left,right) -> {
             visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context);
             Seq fields = context.retainAllNamedParseExpressions(e -> e);
-            expressionAnalyzer.visitSpan(node.getScope(), context);
-            Expression scope = context.popNamedParseExpressions().get();
+            if (Objects.nonNull(node.getScope())) {
+                // scope (optional) - a time based expression that timeframes the join to a specific period : (Time-field-name, value, unit); it is visited and popped here but is no longer passed to join(...)
+                expressionAnalyzer.visitSpan(node.getScope(), context);
+                context.popNamedParseExpressions().get();
+            }
             expressionAnalyzer.visitCorrelationMapping(node.getMappingListContext(), context);
             Seq mapping = context.retainAllNamedParseExpressions(e -> e);
-            return join(node.getCorrelationType(), fields, scope, mapping, left, right);
+            return join(node.getCorrelationType(), fields, mapping, left, right);
         });
         return context.getPlan();
     }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index 2e2b4eae3..a810ea180 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -112,9 +112,9 @@ public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommand
                 ctx.fieldList().fieldExpression().stream()
                         .map(this::internalVisitExpression)
                         .collect(Collectors.toList()),
-                new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()),
+                Objects.isNull(ctx.scopeClause()) ? null : new Scope(expressionBuilder.visit(ctx.scopeClause().fieldExpression()),
                         expressionBuilder.visit(ctx.scopeClause().value),
-                        SpanUnit.of(ctx.scopeClause().unit.getText())),
+                        SpanUnit.of(Objects.isNull(ctx.scopeClause().unit) ? "" : ctx.scopeClause().unit.getText())),
                 Objects.isNull(ctx.mappingList()) ? new FieldsMapping(emptyList()) : new FieldsMapping(ctx.mappingList()
                         .mappingClause().stream()
                         .map(this::internalVisitExpression)
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java
index 74cb181c7..2ae6302eb 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/JoinSpecTransformer.java
@@ -20,11 +20,10 @@ public interface JoinSpecTransformer {
     /**
      * @param correlationType the correlation type which can be exact (inner join) or approximate (outer join)
      * @param fields - fields (columns) that needed to be joined by
-     * @param scope - this is a time base expression that timeframes the join to a specific period : (Time-field-name, value, unit)
      * @param mapping - in case fields in different relations have different name, that can be aliased with the following names
      * @return
      */
-    static LogicalPlan join(Correlation.CorrelationType correlationType, Seq fields, Expression scope, Seq mapping, LogicalPlan left, LogicalPlan right) {
+    static LogicalPlan join(Correlation.CorrelationType correlationType, Seq fields, Seq mapping, LogicalPlan left, LogicalPlan right) {
         //create a join statement - which will replace all the different plans with a single plan which contains the joined plans
         switch (correlationType) {
             case self: