From 64685fbcaa4fced4d293bd99815f4d037260f73c Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Sat, 12 Oct 2024 18:25:17 +0800 Subject: [PATCH] Support RelationSubquery PPL Signed-off-by: Lantao Jin --- docs/ppl-lang/PPL-Example-Commands.md | 12 +- docs/ppl-lang/ppl-subquery-command.md | 63 ++++- .../spark/ppl/FlintSparkPPLJoinITSuite.scala | 190 ++++++++++++++- .../src/main/antlr4/OpenSearchPPLParser.g4 | 34 ++- .../org/opensearch/sql/ast/tree/Relation.java | 13 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 24 +- ...PLLogicalPlanJoinTranslatorTestSuite.scala | 228 +++++++++++++++++- 7 files changed, 523 insertions(+), 41 deletions(-) diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 8e6cbaae9..390ad2042 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -240,8 +240,7 @@ source = table | where ispresent(a) | - `source = table1 | cross join left = l right = r table2` - `source = table1 | left semi join left = l right = r on l.a = r.a table2` - `source = table1 | left anti join left = l right = r on l.a = r.a table2` - -_- **Limitation: sub-searches is unsupported in join right side now**_ +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` #### **Lookup** @@ -349,6 +348,15 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in - `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` - `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` +#### **(Relation) Subquery** +[See additional command details](ppl-subquery-command.md) + +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. But `RelationSubquery` is not a subquery expression, it is a subquery plan which is common used in Join or From clause. 
+ +- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) +- `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1` + +_- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_ --- #### Experimental Commands: diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md index ac0f98fe8..93ed57371 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -1,6 +1,6 @@ ## PPL SubQuery Commands: -**Syntax** +### Syntax The subquery command should be implemented using a clean, logical syntax that integrates with existing PPL structure. ```sql @@ -21,7 +21,7 @@ For additional info See [Issue](https://github.com/opensearch-project/opensearch --- -**InSubquery usage** +### InSubquery usage - `source = outer | where a in [ source = inner | fields b ]` - `source = outer | where (a) in [ source = inner | fields b ]` - `source = outer | where (a,b,c) in [ source = inner | fields d,e,f ]` @@ -111,8 +111,9 @@ source = supplier nation | sort s_name ``` +--- -**ExistsSubquery usage** +### ExistsSubquery usage Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 @@ -163,8 +164,9 @@ source = orders | sort o_orderpriority | fields o_orderpriority, order_count ``` +--- -**ScalarSubquery usage** +### ScalarSubquery usage Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested @@ -240,10 +242,59 @@ source = spark_catalog.default.outer source = spark_catalog.default.inner | where c = 1 | stats min(d) | sort d ] ``` +--- + +### (Relation) Subquery +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. 
But `RelationSubquery` is not a subquery expression, it is a subquery plan which is commonly used in Join or From clause.
+
+- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side)
+- `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1`
+
+**_SQL Migration examples with Relation-Subquery PPL:_**
+
+tpch q13
+```sql
+select
+    c_count,
+    count(*) as custdist
+from
+    (
+        select
+            c_custkey,
+            count(o_orderkey) as c_count
+        from
+            customer left outer join orders on
+                c_custkey = o_custkey
+                and o_comment not like '%special%requests%'
+        group by
+            c_custkey
+    ) as c_orders
+group by
+    c_count
+order by
+    custdist desc,
+    c_count desc
+```
+Rewritten by PPL (Relation) Subquery:
+```sql
+SEARCH source = [
+    SEARCH source = customer
+    | LEFT OUTER JOIN left = c right = o ON c_custkey = o_custkey
+      [
+        SEARCH source = orders
+        | WHERE not like(o_comment, '%special%requests%')
+      ]
+    | STATS COUNT(o_orderkey) AS c_count BY c_custkey
+] AS c_orders
+| STATS COUNT(o_orderkey) AS c_count BY c_custkey
+| STATS COUNT(1) AS custdist BY c_count
+| SORT - custdist, - c_count
+```
+---

-### **Additional Context**
+### Additional Context

-`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expression. The common usage of subquery expression is in `where` clause:
+`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are subquery expressions; their common usage is in the `where` clause.
The `where` command syntax is: diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala index b276149a0..00e55d50a 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLJoinITSuite.scala @@ -7,9 +7,9 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, LessThan, Literal, Multiply, Or, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Divide, EqualTo, Floor, GreaterThan, LessThan, Literal, Multiply, Or, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, LogicalPlan, Project, Sort, SubqueryAlias} import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLJoinITSuite @@ -738,4 +738,190 @@ class FlintSparkPPLJoinITSuite case j @ Join(_, _, Inner, _, JoinHint.NONE) => j }.size == 1) } + + test("test inner join with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'USA' OR country = 'England' + | | inner join left=a, right=b + | ON a.name = b.name + | [ + | source = $testTable2 + | | where salary > 0 + | | fields name, country, salary + | | sort salary + | | head 3 + | ] + | | stats avg(salary) by span(age, 10) as age_span, b.country + | """.stripMargin) + 
val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(70000.0, "USA", 30), Row(100000.0, "England", 70)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val plan1 = SubqueryAlias("a", Filter(filterExpr, table1)) + val rightSubquery = + GlobalLimit( + Literal(3), + LocalLimit( + Literal(3), + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Ascending)), + global = true, + Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("country"), + UnresolvedAttribute("salary")), + Filter(GreaterThan(UnresolvedAttribute("salary"), Literal(0)), table2))))) + val plan2 = SubqueryAlias("b", rightSubquery) + + val joinCondition = EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name")) + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute("b.country") + val countryAlias = Alias(countryField, "b.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + val expectedPlan = Project(star, aggregatePlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + 
test("test left outer join with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'USA' OR country = 'England' + | | left join left=a, right=b + | ON a.name = b.name + | [ + | source = $testTable2 + | | where salary > 0 + | | fields name, country, salary + | | sort salary + | | head 3 + | ] + | | stats avg(salary) by span(age, 10) as age_span, b.country + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(70000.0, "USA", 30), Row(100000.0, "England", 70), Row(null, null, 40)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val plan1 = SubqueryAlias("a", Filter(filterExpr, table1)) + val rightSubquery = + GlobalLimit( + Literal(3), + LocalLimit( + Literal(3), + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Ascending)), + global = true, + Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("country"), + UnresolvedAttribute("salary")), + Filter(GreaterThan(UnresolvedAttribute("salary"), Literal(0)), table2))))) + val plan2 = SubqueryAlias("b", rightSubquery) + + val joinCondition = EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name")) + val joinPlan = Join(plan1, plan2, LeftOuter, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute("b.country") + val countryAlias = Alias(countryField, "b.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = 
false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + val expectedPlan = Project(star, aggregatePlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with relation subquery") { + val frame = sql(s""" + | source = $testTable1 + | | where country = 'Canada' OR country = 'England' + | | inner join left=a, right=b + | ON a.name = b.name AND a.year = 2023 AND a.month = 4 AND b.year = 2023 AND b.month = 4 + | [ + | source = $testTable2 + | ] + | | eval a_name = a.name + | | eval a_country = a.country + | | eval b_country = b.country + | | fields a_name, age, state, a_country, occupation, b_country, salary + | | left join left=a, right=b + | ON a.a_name = b.name + | [ + | source = $testTable3 + | ] + | | eval aa_country = a.a_country + | | eval ab_country = a.b_country + | | eval bb_country = b.country + | | fields a_name, age, state, aa_country, occupation, ab_country, salary, bb_country, hobby, language + | | cross join left=a, right=b + | [ + | source = $testTable2 + | ] + | | eval new_country = a.aa_country + | | eval new_salary = b.salary + | | stats avg(new_salary) as avg_salary by span(age, 5) as age_span, state + | | left semi join left=a, right=b + | ON a.state = b.state + | [ + | source = $testTable1 + | ] + | | eval new_avg_salary = floor(avg_salary) + | | fields state, age_span, new_avg_salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("Quebec", 20, 83333), Row("Ontario", 25, 83333)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + 
assert(frame.queryExecution.optimizedPlan.collect {
+      case j @ Join(_, _, Cross, None, JoinHint.NONE) => j
+    }.size == 1)
+    assert(frame.queryExecution.optimizedPlan.collect {
+      case j @ Join(_, _, LeftOuter, _, JoinHint.NONE) => j
+    }.size == 1)
+    assert(frame.queryExecution.optimizedPlan.collect {
+      case j @ Join(_, _, Inner, _, JoinHint.NONE) => j
+    }.size == 1)
+    assert(frame.queryExecution.analyzed.collect { case s: SubqueryAlias =>
+      s
+    }.size == 13)
+  }
 }
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index 7a6f14839..d0b3a9666 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -55,9 +55,9 @@ commands
   ;

 searchCommand
-  : (SEARCH)? fromClause # searchFrom
-  | (SEARCH)? fromClause logicalExpression # searchFromFilter
-  | (SEARCH)? logicalExpression fromClause # searchFilterFrom
+  : (SEARCH | FROM)? fromClause # searchFrom
+  | (SEARCH | FROM)? fromClause logicalExpression # searchFromFilter
+  | (SEARCH | FROM)? logicalExpression fromClause # searchFilterFrom
   ;

 describeCommand
@@ -247,17 +247,27 @@ mlArg

 // clauses
 fromClause
-  : SOURCE EQUAL tableSourceClause
-  | INDEX EQUAL tableSourceClause
+  : SOURCE EQUAL tableOrSubqueryClause
+  | INDEX EQUAL tableOrSubqueryClause
   ;

+tableOrSubqueryClause
+  : LT_SQR_PRTHS subSearch RT_SQR_PRTHS (AS alias = qualifiedName)?
+  | tableSourceClause
+  ;
+
+// One tableSourceClause will generate one Relation node with/without one alias
+// even if the relation contains more than one table source.
+// These table sources in one relation will be read one by one in OpenSearch.
+// But it may have different behaviors in different execution backends.
+// For example, a Spark UnresolvedRelation node only accepts one data source.
 tableSourceClause
-  : tableSource (COMMA tableSource)*
+  : tableSource (COMMA tableSource)* (AS alias = qualifiedName)?
; // join joinCommand - : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableSource + : (joinType) JOIN sideAlias joinHintList? joinCriteria? right = tableOrSubqueryClause ; joinType @@ -279,13 +289,13 @@ joinCriteria ; joinHintList - : hintPair (COMMA? hintPair)* - ; + : hintPair (COMMA? hintPair)* + ; hintPair - : leftHintKey = LEFT_HINT DOT ID EQUAL leftHintValue = ident #leftHint - | rightHintKey = RIGHT_HINT DOT ID EQUAL rightHintValue = ident #rightHint - ; + : leftHintKey = LEFT_HINT DOT ID EQUAL leftHintValue = ident #leftHint + | rightHintKey = RIGHT_HINT DOT ID EQUAL rightHintValue = ident #rightHint + ; renameClasue : orignalField = wcFieldExpression AS renamedField = wcFieldExpression diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java index e1732f75f..1b30a7998 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Relation.java @@ -8,7 +8,9 @@ import com.google.common.collect.ImmutableList; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; +import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.QualifiedName; @@ -38,7 +40,7 @@ public Relation(UnresolvedExpression tableName, String alias) { } /** Optional alias name for the relation. */ - private String alias; + @Setter @Getter private String alias; /** * Return table name. @@ -53,15 +55,6 @@ public List getQualifiedNames() { return tableName.stream().map(t -> (QualifiedName) t).collect(Collectors.toList()); } - /** - * Return alias. - * - * @return alias. - */ - public String getAlias() { - return alias; - } - /** * Get Qualified name preservs parts of the user given identifiers. 
This can later be utilized to * determine DataSource,Schema and Table Name during Analyzer stage. So Passing QualifiedName diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 8673b1582..1c0fe919f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -156,8 +156,12 @@ public UnresolvedPlan visitJoinCommand(OpenSearchPPLParser.JoinCommandContext ct Join.JoinHint joinHint = getJoinHint(ctx.joinHintList()); String leftAlias = ctx.sideAlias().leftAlias.getText(); String rightAlias = ctx.sideAlias().rightAlias.getText(); - // TODO when sub-search is supported, this part need to change. Now relation is the only supported plan for right side - UnresolvedPlan right = new SubqueryAlias(rightAlias, new Relation(this.internalVisitExpression(ctx.tableSource()), rightAlias)); + if (ctx.tableOrSubqueryClause().alias != null) { + // left and right aliases are required in join syntax. Setting by 'AS' causes ambiguous + throw new SyntaxCheckException("'AS' is not allowed in right subquery, use right= instead"); + } + UnresolvedPlan rightRelation = visit(ctx.tableOrSubqueryClause()); + UnresolvedPlan right = new SubqueryAlias(rightAlias, rightRelation); Optional joinCondition = ctx.joinCriteria() == null ? Optional.empty() : Optional.of(expressionBuilder.visitJoinCriteria(ctx.joinCriteria())); @@ -451,16 +455,22 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ct return aggregation; } - /** From clause. 
*/ @Override - public UnresolvedPlan visitFromClause(OpenSearchPPLParser.FromClauseContext ctx) { - return visitTableSourceClause(ctx.tableSourceClause()); + public UnresolvedPlan visitTableOrSubqueryClause(OpenSearchPPLParser.TableOrSubqueryClauseContext ctx) { + if (ctx.subSearch() != null) { + return ctx.alias != null + ? new SubqueryAlias(ctx.alias.getText(), visitSubSearch(ctx.subSearch())) + : visitSubSearch(ctx.subSearch()); + } else { + return visitTableSourceClause(ctx.tableSourceClause()); + } } @Override public UnresolvedPlan visitTableSourceClause(OpenSearchPPLParser.TableSourceClauseContext ctx) { - return new Relation( - ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())); + return ctx.alias == null + ? new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList())) + : new Relation(ctx.tableSource().stream().map(this::internalVisitExpression).collect(Collectors.toList()), ctx.alias.getText()); } @Override diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 58c1a8d12..3ceff7735 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -11,9 +11,9 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Descending, EqualTo, LessThan, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThan, LessThan, Literal, Not, 
SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, JoinHint, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias} class PPLLogicalPlanJoinTranslatorTestSuite extends SparkFunSuite @@ -341,4 +341,228 @@ class PPLLogicalPlanJoinTranslatorTestSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + + test("test inner join with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1| JOIN left = l right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 and name = 'abc' + | | fields id, name + | | sort id + | | head 10 + | ] + | | stats count(id) as cnt by type + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", table1) + val rightSubquery = + GlobalLimit( + Literal(10), + LocalLimit( + Literal(10), + Sort( + Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), + global = true, + Project( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + Filter( + And( + GreaterThan(UnresolvedAttribute("id"), Literal(10)), + EqualTo(UnresolvedAttribute("name"), Literal("abc"))), + table2))))) + val rightPlan = SubqueryAlias("r", rightSubquery) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val groupingExpression = Alias(UnresolvedAttribute("type"), "type")() + val 
aggregateExpression = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("id")), isDistinct = false), + "cnt")() + val aggPlan = + Aggregate(Seq(groupingExpression), Seq(aggregateExpression, groupingExpression), joinPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test left outer join with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1| LEFT JOIN left = l right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 and name = 'abc' + | | fields id, name + | | sort id + | | head 10 + | ] + | | stats count(id) as cnt by type + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", table1) + val rightSubquery = + GlobalLimit( + Literal(10), + LocalLimit( + Literal(10), + Sort( + Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), + global = true, + Project( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + Filter( + And( + GreaterThan(UnresolvedAttribute("id"), Literal(10)), + EqualTo(UnresolvedAttribute("name"), Literal("abc"))), + table2))))) + val rightPlan = SubqueryAlias("r", rightSubquery) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, LeftOuter, Some(joinCondition), JoinHint.NONE) + val groupingExpression = Alias(UnresolvedAttribute("type"), "type")() + val aggregateExpression = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedAttribute("id")), isDistinct = false), + "cnt")() + val aggPlan = + Aggregate(Seq(groupingExpression), Seq(aggregateExpression, groupingExpression), joinPlan) + val expectedPlan = 
Project(Seq(UnresolvedStar(None)), aggPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test multiple joins with relation subquery") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 + | | head 10 + | | inner JOIN left = l,right = r ON l.id = r.id + | [ + | source = $testTable2 + | | where id > 10 + | ] + | | left JOIN left = l,right = r ON l.name = r.name + | [ + | source = $testTable3 + | | fields id + | ] + | | cross JOIN left = l,right = r + | [ + | source = $testTable4 + | | sort id + | ] + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val table3 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val table4 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test4")) + var leftPlan = SubqueryAlias("l", GlobalLimit(Literal(10), LocalLimit(Literal(10), table1))) + var rightPlan = + SubqueryAlias("r", Filter(GreaterThan(UnresolvedAttribute("id"), Literal(10)), table2)) + val joinCondition1 = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan1 = Join(leftPlan, rightPlan, Inner, Some(joinCondition1), JoinHint.NONE) + leftPlan = SubqueryAlias("l", joinPlan1) + rightPlan = SubqueryAlias("r", Project(Seq(UnresolvedAttribute("id")), table3)) + val joinCondition2 = EqualTo(UnresolvedAttribute("l.name"), UnresolvedAttribute("r.name")) + val joinPlan2 = Join(leftPlan, rightPlan, LeftOuter, Some(joinCondition2), JoinHint.NONE) + leftPlan = SubqueryAlias("l", joinPlan2) + rightPlan = SubqueryAlias( + "r", + Sort(Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), global = true, table4)) + val joinPlan3 = Join(leftPlan, rightPlan, Cross, None, JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), 
joinPlan3) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test complex join: TPC-H Q13 with relation subquery") { + // select + // c_count, + // count(*) as custdist + // from + // ( + // select + // c_custkey, + // count(o_orderkey) as c_count + // from + // customer left outer join orders on + // c_custkey = o_custkey + // and o_comment not like '%special%requests%' + // group by + // c_custkey + // ) as c_orders + // group by + // c_count + // order by + // custdist desc, + // c_count desc + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | SEARCH source = [ + | SEARCH source = customer + | | LEFT OUTER JOIN left = c right = o ON c_custkey = o_custkey + | [ + | SEARCH source = orders + | | WHERE not like(o_comment, '%special%requests%') + | ] + | | STATS COUNT(o_orderkey) AS c_count BY c_custkey + | ] AS c_orders + | | STATS COUNT(o_orderkey) AS c_count BY c_custkey + | | STATS COUNT(1) AS custdist BY c_count + | | SORT - custdist, - c_count + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val tableC = UnresolvedRelation(Seq("customer")) + val tableO = UnresolvedRelation(Seq("orders")) + val left = SubqueryAlias("c", tableC) + val filterNot = Filter( + Not( + UnresolvedFunction( + Seq("like"), + Seq(UnresolvedAttribute("o_comment"), Literal("%special%requests%")), + isDistinct = false)), + tableO) + val right = SubqueryAlias("o", filterNot) + val joinCondition = + EqualTo(UnresolvedAttribute("o_custkey"), UnresolvedAttribute("c_custkey")) + val join = Join(left, right, LeftOuter, Some(joinCondition), JoinHint.NONE) + val groupingExpression1 = Alias(UnresolvedAttribute("c_custkey"), "c_custkey")() + val aggregateExpressions1 = + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("o_orderkey")), + isDistinct = false), + "c_count")() + val agg3 = + Aggregate(Seq(groupingExpression1), Seq(aggregateExpressions1, groupingExpression1), join) + val 
subqueryAlias = SubqueryAlias("c_orders", agg3) + val agg2 = + Aggregate( + Seq(groupingExpression1), + Seq(aggregateExpressions1, groupingExpression1), + subqueryAlias) + val groupingExpression2 = Alias(UnresolvedAttribute("c_count"), "c_count")() + val aggregateExpressions2 = + Alias(UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), "custdist")() + val agg1 = + Aggregate(Seq(groupingExpression2), Seq(aggregateExpressions2, groupingExpression2), agg2) + val sort = Sort( + Seq( + SortOrder(UnresolvedAttribute("custdist"), Descending), + SortOrder(UnresolvedAttribute("c_count"), Descending)), + global = true, + agg1) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sort) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } }