From 0aeed270f118f586ff3fe5d6b145213393151d64 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 2 Oct 2024 02:54:10 +0800 Subject: [PATCH] Support `InSubquery` in PPL (#714) * Support InSubquery in PPL Signed-off-by: Lantao Jin * Add more examples Signed-off-by: Lantao Jin --------- Signed-off-by: Lantao Jin --- .../ppl/FlintSparkPPLInSubqueryITSuite.scala | 407 ++++++++++++++++++ ppl-spark-integration/README.md | 91 ++++ .../src/main/antlr4/OpenSearchPPLParser.g4 | 19 + .../sql/ast/AbstractNodeVisitor.java | 4 + .../sql/ast/expression/InSubquery.java | 35 ++ .../sql/ppl/CatalystQueryPlanVisitor.java | 29 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 6 + .../sql/ppl/parser/AstExpressionBuilder.java | 17 + .../flint/spark/ppl/PPLSyntaxParser.scala | 8 +- ...calPlanInSubqueryTranslatorTestSuite.scala | 365 ++++++++++++++++ 10 files changed, 977 insertions(+), 4 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala new file mode 100644 index 000000000..ee08e692a --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala @@ -0,0 +1,407 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{And, Descending, EqualTo, InSubquery, ListQuery, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLInSubqueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val outerTable = "spark_catalog.default.flint_ppl_test1" + private val innerTable = "spark_catalog.default.flint_ppl_test2" + private val nestedInnerTable = "spark_catalog.default.flint_ppl_test3" + + override def beforeAll(): Unit = { + super.beforeAll() + createPeopleTable(outerTable) + sql(s""" + | INSERT INTO $outerTable + | VALUES (1006, 'Tommy', 'Teacher', 'USA', 30000) + | """.stripMargin) + createWorkInformationTable(innerTable) + createOccupationTable(nestedInnerTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test where id in (select uid from inner)") { + // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) + // InSubquery: (0, 2, 3, 5, 6) + val frame = sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where (id) in (select uid from inner)") { + // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) + // InSubquery: (0, 2, 3, 5, 6) + val frame = sql(s""" + source = $outerTable + | | where (id) in [ + | source = $innerTable | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where (id, name) in (select uid, name from inner)") { + // InSubquery: (0, 2, 3, 5) + val frame = sql(s""" + source = $outerTable + | | where (id, name) in [ + | source = $innerTable | fields uid, name + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + ListQuery( + Project(Seq(UnresolvedAttribute("uid"), UnresolvedAttribute("name")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where id not in (select uid from inner)") { + // id (0, 1, 2, 3, 4, 5, 6), uid (0, 2, 3, 5, 6) + // Not InSubquery: (1, 4) + val frame = sql(s""" + source = $outerTable + | | where id not in [ + | source = $innerTable | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test where (id, name) not in (select uid, name from inner)") { + // Not InSubquery: (1, 4, 6) + val frame = sql(s""" + source = $outerTable + | | where (id, name) not in [ + | source = $innerTable | fields uid, name + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(1001, "Hello", 70000), Row(1004, "David", 0), Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), + ListQuery( + Project(Seq(UnresolvedAttribute("uid"), UnresolvedAttribute("name")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), global = true, inSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test empty subquery") { + // id (0, 1, 2, 3, 4, 5, 6), uid () + // InSubquery: () + // Not InSubquery: (0, 1, 2, 3, 4, 5, 6) + var frame = sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable | where uid = 0000 | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + var results: Array[Row] = frame.collect() + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + var expectedResults: Array[Row] = Array() + assert(results.sorted.sameElements(expectedResults.sorted)) + + frame = sql(s""" + source = $outerTable + | | where id not in [ + | source = $innerTable | where uid = 0000 | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + results = frame.collect() + expectedResults = Array( + Row(1000, "Jake", 100000), + Row(1001, "Hello", 70000), + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1004, "David", 0), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test nested subquery") { + val frame = sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable + | | where occupation in [ + | source = $nestedInnerTable | where occupation != 'Engineer' | fields occupation + | ] + | | fields uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + frame.show() + frame.explain(true) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val filter = + Filter(Not(EqualTo(UnresolvedAttribute("occupation"), Literal("Engineer"))), inner2) + val inSubqueryForOccupation = + Filter( + InSubquery( + Seq(UnresolvedAttribute("occupation")), + ListQuery(Project(Seq(UnresolvedAttribute("occupation")), filter))), + inner1) + val inSubqueryForId = + Filter( + InSubquery( + Seq(UnresolvedAttribute("id")), + ListQuery(Project(Seq(UnresolvedAttribute("uid")), inSubqueryForOccupation))), + outer) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + inSubqueryForId) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test in-subquery as a join filter") { + val frame = sql(s""" + | source = $outerTable + | | inner join left=a, right=b + | ON a.id = b.uid AND b.occupation in [ + | source = $nestedInnerTable| where occupation != 'Engineer' | fields occupation + | ] + | $innerTable + | | fields a.id, a.name, a.salary + | """.stripMargin) + frame.explain(true) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val plan1 = SubqueryAlias("a", table1) + val plan2 = SubqueryAlias("b", table2) + val filter = + Filter(Not(EqualTo(UnresolvedAttribute("occupation"), Literal("Engineer"))), inner) + val inSubqueryForOccupation = + InSubquery( + Seq(UnresolvedAttribute("b.occupation")), + ListQuery(Project(Seq(UnresolvedAttribute("occupation")), filter))) + val joinCondition = + And( + EqualTo(UnresolvedAttribute("a.id"), UnresolvedAttribute("b.uid")), + inSubqueryForOccupation) + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("a.id"), + UnresolvedAttribute("a.name"), + UnresolvedAttribute("a.salary")), + joinPlan) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("throw exception because the number of columns not match output of subquery") { + val ex = intercept[AnalysisException](sql(s""" + source = $outerTable + | | where id in [ + | source = $innerTable | fields uid, department + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin)) + assert(ex.getMessage.contains( + "The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery")) + } +} diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 02baaab45..f07fcbd3f 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -434,6 +434,97 @@ _- **Limitation: "REPLACE" or "APPEND" clause must contain "AS"**_ Details of Lookup command syntax, see [PPL-Lookup-Command](../docs/PPL-Lookup-command.md) +**InSubquery** +- `source = outer | where a in [ source = inner | fields b ]` +- `source = outer | where (a) in [ source = inner | fields b ]` +- `source = outer | where (a,b,c) in [ source = inner | fields d,e,f ]` +- `source = outer | where a not in [ source = inner | fields b ]` +- `source = outer | where (a) not in [ source = inner | fields b ]` +- `source = outer | where (a,b,c) not in [ source = inner | fields d,e,f ]` +- `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) +- `source = table1 | inner join left = l right = r on l.a = r.a AND r.a in [ source = inner | fields d ] | fields l.a, r.a, b, c` (as join filter) + +SQL Migration examples with IN-Subquery PPL: +1. tpch q4 (in-subquery with aggregation) +```sql +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and o_orderkey in ( + select + l_orderkey + from + lineitem + where l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority +``` +Rewritten by PPL InSubquery query: +```sql +source = orders +| where o_orderdate >= "1993-07-01" and o_orderdate < "1993-10-01" and o_orderkey IN + [ source = lineitem + | where l_commitdate < l_receiptdate + | fields l_orderkey + ] +| stats count(1) as order_count by o_orderpriority +| sort o_orderpriority +| fields o_orderpriority, order_count +``` +2.tpch q20 (nested in-subquery) +```sql +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name +``` +Rewritten by PPL InSubquery query: +```sql +source = supplier +| where s_suppkey IN [ + source = partsupp + | where ps_partkey IN [ + source = part + | where like(p_name, "forest%") + | fields p_partkey + ] + | fields ps_suppkey + ] +| inner join left=l right=r on s_nationkey = n_nationkey and n_name = 'CANADA' + nation +| sort s_name +``` + --- #### Experimental Commands: - `correlation` - [See details](../docs/PPL-Correlation-command.md) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 30b57d5da..626ff2165 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -24,6 +24,10 @@ queryStatement : pplCommands (PIPE commands)* ; +subSearch + : searchCommand (PIPE commands)* + ; + // commands pplCommands : searchCommand @@ -339,6 +343,12 @@ logicalExpression comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr | valueExpression IN valueList # inExpr + | valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr + ; + +valueExpressionList + : valueExpression + | LT_PRTHS valueExpression (COMMA valueExpression)* RT_PRTHS ; valueExpression @@ -1004,4 +1014,13 @@ keywordsCanBeId | SPARKLINE | C | DC + // JOIN TYPE + | OUTER + | INNER + | CROSS + | LEFT + | RIGHT + | FULL + | SEMI + | ANTI ; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index aea7bbb1d..76f9479f4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -19,6 +19,7 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -289,4 +290,7 @@ public T visitExplain(Explain node, C context) { return visitStatement(node, context); } + public T visitInSubquery(InSubquery node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java new file mode 100644 index 000000000..ed40e4b45 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +import java.util.Arrays; +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class InSubquery extends UnresolvedExpression { + private final List value; + private final UnresolvedPlan query; + + @Override + public List getChild() { + return value; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitInSubquery(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index e7dc09542..a0bfe851e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -13,6 +13,8 @@ import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.InSubquery$; +import org.apache.spark.sql.catalyst.expressions.ListQuery$; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; import org.apache.spark.sql.catalyst.expressions.SortDirection; @@ -41,6 +43,7 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; +import org.opensearch.sql.ast.expression.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -75,6 +78,7 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TopAggregation; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; @@ -124,6 +128,10 @@ public LogicalPlan visit(Statement plan, CatalystPlanContext context) { return plan.accept(this, context); } + public LogicalPlan visitSubSearch(UnresolvedPlan plan, CatalystPlanContext context) { + return plan.accept(this, context); + } + /** * Handle Query Statement. */ @@ -487,7 +495,7 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { /** * Expression Analyzer. */ - public static class ExpressionAnalyzer extends AbstractNodeVisitor { + public class ExpressionAnalyzer extends AbstractNodeVisitor { public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { return unresolved.accept(this, context); @@ -734,5 +742,24 @@ public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { throw new IllegalStateException("Not Supported operation : WindowFunction"); } + + @Override + public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerContext) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + visitExpressionList(node.getChild(), innerContext); + Seq values = innerContext.retainAllNamedParseExpressions(p -> p); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); + Expression inSubQuery = InSubquery$.MODULE$.apply( + values, + ListQuery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + -1, + seq(new java.util.ArrayList()), + Option.empty())); + return outerContext.getNamedParseExpressions().push(inSubQuery); + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index e9aee3180..a963073d6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -85,6 +85,12 @@ public UnresolvedPlan visitQueryStatement(OpenSearchPPLParser.QueryStatementCont return ctx.commands().stream().map(this::visit).reduce(pplCommand, (r, e) -> e.attach(r)); } + @Override + public UnresolvedPlan visitSubSearch(OpenSearchPPLParser.SubSearchContext ctx) { + UnresolvedPlan searchCommand = visit(ctx.searchCommand()); + return ctx.commands().stream().map(this::visit).reduce(searchCommand, (r, e) -> e.attach(r)); + } + /** Search command. */ @Override public UnresolvedPlan visitSearchFrom(OpenSearchPPLParser.SearchFromContext ctx) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 2706d85e5..f5e9269be 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -22,6 +22,7 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; @@ -62,6 +63,13 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala index 3cd018ead..c435af53d 100644 --- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala +++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/PPLSyntaxParser.scala @@ -29,9 +29,11 @@ class PPLSyntaxParser extends Parser { object PlaneUtils { def plan(parser: PPLSyntaxParser, query: String): Statement = { - val builder = new AstStatementBuilder( - new AstBuilder(new AstExpressionBuilder(), query), - AstStatementBuilder.StatementBuilderContext.builder()) + val astExpressionBuilder = new AstExpressionBuilder() + val astBuilder = new AstBuilder(astExpressionBuilder, query) + astExpressionBuilder.setAstBuilder(astBuilder) + val builder = + new AstStatementBuilder(astBuilder, AstStatementBuilder.StatementBuilderContext.builder()) builder.visit(parser.parse(query)) } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala new file mode 100644 index 000000000..03bcdd623 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala @@ -0,0 +1,365 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThanOrEqual, InSubquery, LessThan, ListQuery, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} + +class PPLLogicalPlanInSubqueryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test where a in (select b from c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where (a) in (select b from c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where (a, b, c) in (select d, e, f from inner)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where (a, b, c) in [ + | source = spark_catalog.default.inner | fields d, e, f + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + ListQuery( + Project( + Seq(UnresolvedAttribute("d"), UnresolvedAttribute("e"), UnresolvedAttribute("f")), + inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where a not in (select b from c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a not in [ + | source = spark_catalog.default.inner | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where (a, b, c) not in (select d, e, f from inner)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where (a, b, c) not in [ + | source = spark_catalog.default.inner | fields d, e, f + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + Not( + InSubquery( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")), + ListQuery( + Project( + Seq(UnresolvedAttribute("d"), UnresolvedAttribute("e"), UnresolvedAttribute("f")), + inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test nested subquery") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner1 + | | where b in [ + | source = spark_catalog.default.inner2 | fields c + | ] + | | fields b + | ] + | | sort - a + | | fields a, d + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "inner1")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "inner2")) + val inSubqueryForB = + Filter( + InSubquery( + Seq(UnresolvedAttribute("b")), + ListQuery(Project(Seq(UnresolvedAttribute("c")), inner2))), + inner1) + val inSubqueryForA = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inSubqueryForB))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubqueryForA) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("d")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + // TODO throw exception with syntax check, now it throw AnalysisException in Spark + ignore("The number of columns not match output of subquery") { + val context = new CatalystPlanContext + val ex = intercept[SyntaxCheckException] { + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner | fields b, d + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + } + assert(ex.getMessage === "The number of columns not match output of subquery") + } + + test("test tpch q4: in-subquery with aggregation") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = orders + | | where o_orderdate >= "1993-07-01" AND o_orderdate < "1993-10-01" AND o_orderkey IN + | [ source = lineitem + | | where l_commitdate < l_receiptdate + | | fields l_orderkey + | ] + | | stats count(1) as order_count by o_orderpriority + | | sort o_orderpriority + | | fields o_orderpriority, order_count + | """.stripMargin), + context) + + val outer = UnresolvedRelation(Seq("orders")) + val inner = UnresolvedRelation(Seq("lineitem")) + val inSubquery = + Filter( + And( + And( + GreaterThanOrEqual(UnresolvedAttribute("o_orderdate"), Literal("1993-07-01")), + LessThan(UnresolvedAttribute("o_orderdate"), Literal("1993-10-01"))), + InSubquery( + Seq(UnresolvedAttribute("o_orderkey")), + ListQuery( + Project( + Seq(UnresolvedAttribute("l_orderkey")), + Filter( + LessThan( + UnresolvedAttribute("l_commitdate"), + UnresolvedAttribute("l_receiptdate")), + inner))))), + outer) + val o_orderpriorityAlias = Alias(UnresolvedAttribute("o_orderpriority"), "o_orderpriority")() + val groupByAttributes = Seq(o_orderpriorityAlias) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), + "order_count")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, o_orderpriorityAlias), inSubquery) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("o_orderpriority"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project( + Seq(UnresolvedAttribute("o_orderpriority"), UnresolvedAttribute("order_count")), + sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + + test("test tpch q20 (partial): nested in-subquery") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = supplier + | | where s_suppkey IN [ + | source = partsupp + | | where ps_partkey IN [ + | source = part + | | where like(p_name, "forest%") + | | fields p_partkey + | ] + | | fields ps_suppkey + | ] + | | inner join left=l right=r on s_nationkey = n_nationkey and n_name = 'CANADA' + | nation + | | sort s_name + | """.stripMargin), + context) + + val outer = UnresolvedRelation(Seq("supplier")) + val inner = UnresolvedRelation(Seq("partsupp")) + val nestedInner = UnresolvedRelation(Seq("part")) + val right = UnresolvedRelation(Seq("nation")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("s_suppkey")), + ListQuery( + Project( + Seq(UnresolvedAttribute("ps_suppkey")), + Filter( + InSubquery( + Seq(UnresolvedAttribute("ps_partkey")), + ListQuery(Project( + Seq(UnresolvedAttribute("p_partkey")), + Filter( + UnresolvedFunction( + "like", + Seq(UnresolvedAttribute("p_name"), Literal("forest%")), + isDistinct = false), + nestedInner)))), + inner)))), + outer) + val leftPlan = SubqueryAlias("l", inSubquery) + val rightPlan = SubqueryAlias("r", right) + val joinCondition = + And( + EqualTo(UnresolvedAttribute("s_nationkey"), UnresolvedAttribute("n_nationkey")), + EqualTo(UnresolvedAttribute("n_name"), Literal("CANADA"))) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("s_name"), Ascending)), global = true, joinPlan) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } +}