diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 7d57651c3..8e6cbaae9 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -273,7 +273,7 @@ _- **Limitation: "REPLACE" or "APPEND" clause must contain "AS"**_ **SQL Migration examples with IN-Subquery PPL:** -1. tpch q4 (in-subquery with aggregation) +tpch q4 (in-subquery with aggregation) ```sql select o_orderpriority, @@ -309,52 +309,21 @@ source = orders | fields o_orderpriority, order_count ``` -2.tpch q20 (nested in-subquery) -```sql -select - s_name, - s_address -from - supplier, - nation -where - s_suppkey in ( - select - ps_suppkey - from - partsupp - where - ps_partkey in ( - select - p_partkey - from - part - where - p_name like 'forest%' - ) - ) - and s_nationkey = n_nationkey - and n_name = 'CANADA' -order by - s_name -``` +#### **ExistsSubquery** +[See additional command details](ppl-subquery-command.md) + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 +- `source = outer | where exists [ source = inner | where a = c ]` +- `source = outer | where not exists [ source = inner | where a = c ]` +- `source = outer | where exists [ source = inner | where a = c and b = d ]` +- `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where not exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where exists [ source = inner ] | eval l = "Bala" | fields l` (special uncorrelated exists) + -Rewritten by PPL InSubquery query: -```sql -source = supplier -| where s_suppkey IN [ - source = partsupp - | where ps_partkey IN [ - source = part - | where like(p_name, "forest%") - | fields p_partkey - ] - | fields ps_suppkey - ] -| inner join left=l right=r on s_nationkey = n_nationkey and n_name = 'CANADA' - nation -| sort s_name -``` #### **ScalarSubquery** [See additional command details](ppl-subquery-command.md) diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md index 1762306d2..ac0f98fe8 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -112,6 +112,58 @@ source = supplier | sort s_name ``` +**ExistsSubquery usage** + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table inner2 + +- `source = outer | where exists [ source = inner | where a = c ]` +- `source = outer | where not exists [ source = inner | where a = c ]` +- `source = outer | where exists [ source = inner | where a = c and b = d ]` +- `source = outer | where not exists [ source = inner | where a = c and b = d ]` +- `source = outer | where exists [ source = inner1 | where a = c and exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner1 | where a = c | where exists [ source = inner2 | where c = e ] ]` (nested) +- `source = outer | where exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where not exists [ source = inner | where c > 10 ]` (uncorrelated exists) +- `source = outer | where exists [ source = inner ] | eval l = "nonEmpty" | fields l` (special uncorrelated exists) + +**_SQL Migration examples with Exists-Subquery PPL:_** + +tpch q4 (exists subquery with aggregation) +```sql +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + l_orderkey + from + lineitem + where l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority +``` +Rewritten by PPL ExistsSubquery query: +```sql +source = orders +| where o_orderdate >= "1993-07-01" and o_orderdate < "1993-10-01" + and exists [ + source = lineitem + | where l_orderkey = o_orderkey and l_commitdate < l_receiptdate + ] +| stats count(1) as order_count by o_orderpriority +| sort o_orderpriority +| fields o_orderpriority, order_count +``` + **ScalarSubquery usage** Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested @@ -191,14 +243,14 @@ source = spark_catalog.default.outer ### **Additional Context** -The most cases in the description is to request a `InSubquery` expression. +`InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expression. The common usage of subquery expression is in `where` clause: The `where` command syntax is: ``` | where ``` -So the subquery in description is part of boolean expression, such as +So the subquery is part of boolean expression, such as ```sql | where orders.order_id in (subquery source=returns | where return_reason="damaged" | return order_id) @@ -217,10 +269,11 @@ In issue description is a `ScalarSubquery`: ```sql source=employees | join source=sales on employees.employee_id = sales.employee_id -| where sales.sale_amount > (subquery source=targets | where target_met="true" | return target_value) +| where sales.sale_amount > [ source=targets | where target_met="true" | fields target_value ] ``` -Recall the join command doc: https://github.com/opensearch-project/opensearch-spark/blob/main/docs/PPL-Join-command.md#more-examples, the example is a subquery/subsearch **plan**, rather than a **expression**. +But `RelationSubquery` is not a subquery expression, it is a subquery plan. +[Recall the join command doc](ppl-join-command.md), the example is a subquery/subsearch **plan**, rather than a **expression**. ```sql SEARCH source=customer @@ -245,7 +298,32 @@ SEARCH Apply the syntax here and simply into ```sql -search | left join on (subquery search ...) +search | left join on [ search ... ] ``` -The `(subquery search ...)` is not a `expression`, it's `plan`, similar to the `relation` plan \ No newline at end of file +The `[ search ...]` is not a `expression`, it's `plan`, similar to the `relation` plan + +**Uncorrelated Subquery** + +An uncorrelated subquery is independent of the outer query. It is executed once, and the result is used by the outer query. +It's **less common** when using `ExistsSubquery` because `ExistsSubquery` typically checks for the presence of rows that are dependent on the outer query’s row. + +There is a very special exists subquery which highlight by `(special uncorrelated exists)`: +```sql +SELECT 'nonEmpty' +FROM outer + WHERE EXISTS ( + SELECT * + FROM inner + ); +``` +Rewritten by PPL ExistsSubquery query: +```sql +source = outer +| where exists [ + source = inner + ] +| eval l = "nonEmpty" +| fields l +``` +This query just print "nonEmpty" if the inner table is not empty. \ No newline at end of file diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala new file mode 100644 index 000000000..81bdd99df --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExistsSubqueryITSuite.scala @@ -0,0 +1,373 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, Exists, GreaterThan, InSubquery, ListQuery, Literal, Not, Or, ScalarSubquery, SortOrder} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLExistsSubqueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val outerTable = "spark_catalog.default.flint_ppl_test1" + private val innerTable = "spark_catalog.default.flint_ppl_test2" + private val nestedInnerTable = "spark_catalog.default.flint_ppl_test3" + + override def beforeAll(): Unit = { + super.beforeAll() + createPeopleTable(outerTable) + sql(s""" + | INSERT INTO $outerTable + | VALUES (1006, 'Tommy', 'Teacher', 'USA', 30000) + | """.stripMargin) + createWorkInformationTable(innerTable) + createOccupationTable(nestedInnerTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test simple exists subquery") { + val frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test not exists subquery") { + val frame = sql(s""" + | source = $outerTable + | | where not exists [ + | source = $innerTable | where id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1001, "Hello", 70000), Row(1004, "David", 0)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = + Filter( + Not( + Exists(Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner))), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test empty exists subquery") { + var frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where uid = 0000 AND id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + var results: Array[Row] = frame.collect() + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + var expectedResults: Array[Row] = Array() + assert(results.sorted.sameElements(expectedResults.sorted)) + + frame = sql(s""" + source = $outerTable + | | where not exists [ + | source = $innerTable | where uid = 0000 AND id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + results = frame.collect() + expectedResults = Array( + Row(1000, "Jake", 100000), + Row(1001, "Hello", 70000), + Row(1002, "John", 120000), + Row(1003, "David", 120000), + Row(1004, "David", 0), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } + + test("test uncorrelated exists subquery") { + var frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where like(name, 'J%') + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + assert(results.length == 7) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = + Filter( + Exists( + Filter( + UnresolvedFunction( + "like", + Seq(UnresolvedAttribute("name"), Literal("J%")), + isDistinct = false), + inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + frame = sql(s""" + | source = $outerTable + | | where not exists [ + | source = $innerTable | where like(name, 'J%') + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + assert(frame.collect().length == 0) + + frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where like(name, 'X%') + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + assert(frame.collect().length == 0) + } + + test("uncorrelated exists subquery check the return content of inner table is empty or not") { + var frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable + | ] + | | eval constant = "Bala" + | | fields constant + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala"), + Row("Bala")) + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter(Exists(inner), outer) + val evalProject = + Project(Seq(UnresolvedStar(None), Alias(Literal("Bala"), "constant")()), existsSubquery) + val expectedPlan = Project(Seq(UnresolvedAttribute("constant")), evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable | where uid = 999 + | ] + | | eval constant = "Bala" + | | fields constant + | """.stripMargin) + frame.show + assert(frame.collect().length == 0) + } + + test("test nested exists subquery") { + val frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable + | | where exists [ + | source = $nestedInnerTable + | | where $nestedInnerTable.occupation = $innerTable.occupation + | ] + | | where id = uid + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1003, "David", 120000), + Row(1002, "John", 120000), + Row(1000, "Jake", 100000), + Row(1005, "Jane", 90000), + Row(1006, "Tommy", 30000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test3")) + val existsSubqueryForOccupation = + Filter( + Exists( + Filter( + EqualTo( + UnresolvedAttribute("spark_catalog.default.flint_ppl_test3.occupation"), + UnresolvedAttribute("spark_catalog.default.flint_ppl_test2.occupation")), + inner2)), + inner1) + val existsSubqueryForId = + Filter( + Exists( + Filter( + EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), + existsSubqueryForOccupation)), + outer) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubqueryForId) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test exists subquery with conjunction of conditions") { + val frame = sql(s""" + | source = $outerTable + | | where exists [ + | source = $innerTable + | | where id = uid AND + | $outerTable.name = $innerTable.name AND + | $outerTable.occupation = $innerTable.occupation + | ] + | | sort - salary + | | fields id, name, salary + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1003, "David", 120000), Row(1000, "Jake", 100000)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val existsSubquery = Filter( + Exists( + Filter( + And( + And( + EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), + EqualTo( + UnresolvedAttribute("spark_catalog.default.flint_ppl_test1.name"), + UnresolvedAttribute("spark_catalog.default.flint_ppl_test2.name"))), + EqualTo( + UnresolvedAttribute("spark_catalog.default.flint_ppl_test1.occupation"), + UnresolvedAttribute("spark_catalog.default.flint_ppl_test2.occupation"))), + inner)), + outer) + val sortedPlan = Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Descending)), + global = true, + existsSubquery) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("salary")), + sortedPlan) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } +} diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala index ee08e692a..9d8c2c12d 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLInSubqueryITSuite.scala @@ -305,8 +305,6 @@ class FlintSparkPPLInSubqueryITSuite | | sort - salary | | fields id, name, salary | """.stripMargin) - frame.show() - frame.explain(true) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) @@ -358,7 +356,6 @@ class FlintSparkPPLInSubqueryITSuite | $innerTable | | fields a.id, a.name, a.salary | """.stripMargin) - frame.explain(true) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array(Row(1003, "David", 120000), Row(1002, "John", 120000), Row(1006, "Tommy", 30000)) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index dd43007f4..2b916a245 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -113,6 +113,7 @@ APPEND: 'APPEND'; CASE: 'CASE'; ELSE: 'ELSE'; IN: 'IN'; +EXISTS: 'EXISTS'; // LOGICAL KEYWORDS NOT: 'NOT'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index af7e0ec14..7a6f14839 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -351,24 +351,21 @@ percentileAggFunction // expressions expression : logicalExpression - | comparisonExpression | valueExpression ; logicalExpression - : comparisonExpression # comparsion - | NOT logicalExpression # logicalNot + : NOT logicalExpression # logicalNot + | comparisonExpression # comparsion | left = logicalExpression (AND)? right = logicalExpression # logicalAnd | left = logicalExpression OR right = logicalExpression # logicalOr | left = logicalExpression XOR right = logicalExpression # logicalXor | booleanExpression # booleanExpr - | isEmptyExpression # isEmptyExpr ; comparisonExpression : left = valueExpression comparisonOperator right = valueExpression # compareExpr | valueExpression IN valueList # inExpr - | valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr ; valueExpressionList @@ -397,7 +394,10 @@ positionFunction ; booleanExpression - : booleanFunctionCall + : booleanFunctionCall # booleanFunctionCallExpr + | isEmptyExpression # isEmptyExpr + | valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr + | EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr ; isEmptyExpression diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index d7db9d0c8..c361ded08 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -19,7 +19,8 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -28,7 +29,7 @@ import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.ScalarSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedAttribute; @@ -302,4 +303,8 @@ public T visitFillNull(FillNull fillNull, C context) { public T visitScalarSubquery(ScalarSubquery node, C context) { return visitChildren(node, context); } + + public T visitExistsSubquery(ExistsSubquery node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java new file mode 100644 index 000000000..bdd1683ee --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ExistsSubquery.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression.subquery; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +import java.util.List; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class ExistsSubquery extends UnresolvedExpression { + private final UnresolvedPlan query; + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitExistsSubquery(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java similarity index 87% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java index ed40e4b45..4a15453e5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/InSubquery.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/InSubquery.java @@ -3,16 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.ast.expression; +package org.opensearch.sql.ast.expression.subquery; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.UnresolvedPlan; -import java.util.Arrays; import java.util.List; @Getter diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java similarity index 84% rename from ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java rename to ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java index cccadb717..7c3721ffb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/subquery/ScalarSubquery.java @@ -3,13 +3,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.ast.expression; +package org.opensearch.sql.ast.expression.subquery; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.UnresolvedPlan; @Getter diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 38dc4092e..902fc72e3 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -13,6 +13,7 @@ import org.apache.spark.sql.catalyst.expressions.Ascending$; import org.apache.spark.sql.catalyst.expressions.CaseWhen; import org.apache.spark.sql.catalyst.expressions.Descending$; +import org.apache.spark.sql.catalyst.expressions.Exists$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.InSubquery$; import org.apache.spark.sql.catalyst.expressions.ListQuery$; @@ -40,7 +41,8 @@ import org.opensearch.sql.ast.expression.FieldsMapping; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IsEmpty; import org.opensearch.sql.ast.expression.Let; @@ -49,7 +51,7 @@ import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.ScalarSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; @@ -813,5 +815,19 @@ public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext c Option.empty()); return context.getNamedParseExpressions().push(scalarSubQuery); } + + @Override + public Expression visitExistsSubquery(ExistsSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); + Expression existsSubQuery = Exists$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty()); + return context.getNamedParseExpressions().push(existsSubQuery); + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 4b4697b45..3b98edd77 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -22,7 +22,8 @@ import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; -import org.opensearch.sql.ast.expression.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.IsEmpty; @@ -31,7 +32,7 @@ import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.ScalarSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedArgument; @@ -393,6 +394,11 @@ public UnresolvedExpression visitScalarSubqueryExpr(OpenSearchPPLParser.ScalarSu return new ScalarSubquery(astBuilder.visitSubSearch(ctx.subSearch())); } + @Override + public UnresolvedExpression visitExistsSubqueryExpr(OpenSearchPPLParser.ExistsSubqueryExprContext ctx) { + return new ExistsSubquery(astBuilder.visitSubSearch(ctx.subSearch())); + } + private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExistsSubqueryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExistsSubqueryTranslatorTestSuite.scala new file mode 100644 index 000000000..02dfe1096 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExistsSubqueryTranslatorTestSuite.scala @@ -0,0 +1,315 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, Exists, GreaterThanOrEqual, LessThan, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ + +class PPLLogicalPlanExistsSubqueryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + // Assume outer table contains fields [a, b] + // and inner table contains fields [c, d] + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test where exists (select * from inner where a = c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where exists [ + | source = spark_catalog.default.inner | where a = c + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val subquery = + Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), inner)), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, subquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where exists (select * from inner where a = c and b = d)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where exists [ + | source = spark_catalog.default.inner | where a = c AND b = d + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val existsSubquery = + Filter( + Exists( + Filter( + And( + EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d"))), + inner)), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, existsSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where not exists (select * from inner where a = c)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where not exists [ + | source = spark_catalog.default.inner | where a = c + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val subquery = + Filter( + Not(Exists(Filter(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, subquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where not exists (select * from inner where a = c and b = d)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where not exists [ + | source = spark_catalog.default.inner | where a = c AND b = d + | ] + | | sort - a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val existsSubquery = + Filter( + Not( + Exists( + Filter( + And( + EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d"))), + inner))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, existsSubquery) + val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + // Assume outer table contains fields [a, b] + // and inner1 table contains fields [c, d] + // and inner2 table contains fields [e, f] + test("test nested exists subquery") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where exists [ + | source = spark_catalog.default.inner1 + | | where exists [ + | source = spark_catalog.default.inner2 + | | where c = e + | ] + | | where a = c + | ] + | | sort - a + | | fields a, b + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "inner1")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "inner2")) + val subqueryOuter = + Filter( + Exists(Filter(EqualTo(UnresolvedAttribute("c"), UnresolvedAttribute("e")), inner2)), + inner1) + val subqueryInner = + Filter( + Exists( + Filter(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("c")), subqueryOuter)), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, subqueryInner) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test tpch q4: exists subquery with aggregation") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = orders + | | where o_orderdate >= "1993-07-01" AND o_orderdate < "1993-10-01" + | AND exists [ + | source = lineitem + | | where l_orderkey = o_orderkey + | AND l_commitdate < l_receiptdate + | ] + | | stats count(1) as order_count by o_orderpriority + | | sort o_orderpriority + | | fields o_orderpriority, order_count + | """.stripMargin), + context) + + val outer = UnresolvedRelation(Seq("orders")) + val inner = UnresolvedRelation(Seq("lineitem")) + val inSubquery = + Filter( + And( + And( + GreaterThanOrEqual(UnresolvedAttribute("o_orderdate"), Literal("1993-07-01")), + LessThan(UnresolvedAttribute("o_orderdate"), Literal("1993-10-01"))), + Exists( + Filter( + And( + EqualTo(UnresolvedAttribute("l_orderkey"), UnresolvedAttribute("o_orderkey")), + LessThan( + UnresolvedAttribute("l_commitdate"), + UnresolvedAttribute("l_receiptdate"))), + inner))), + outer) + val o_orderpriorityAlias = Alias(UnresolvedAttribute("o_orderpriority"), "o_orderpriority")() + val groupByAttributes = Seq(o_orderpriorityAlias) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(Literal(1)), isDistinct = false), + "order_count")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, o_orderpriorityAlias), inSubquery) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("o_orderpriority"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project( + Seq(UnresolvedAttribute("o_orderpriority"), UnresolvedAttribute("order_count")), + sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + + // We can support q21 when the table alias is supported + ignore("test tpch q21 (partial): multiple exists subquery") { + // select + // s_name, + // count(*) as numwait + // from + // supplier, + // lineitem l1, + // where + // s_suppkey = l1.l_suppkey + // and l1.l_receiptdate > l1.l_commitdate + // and exists ( + // select + // * + // from + // lineitem l2 + // where + // l2.l_orderkey = l1.l_orderkey + // and l2.l_suppkey <> l1.l_suppkey + // ) + // and not exists ( + // select + // * + // from + // lineitem l3 + // where + // l3.l_orderkey = l1.l_orderkey + // and l3.l_suppkey <> l1.l_suppkey + // and l3.l_receiptdate > l3.l_commitdate + // ) + // group by + // s_name + // order by + // numwait desc, + // s_name + // limit 100 + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = supplier + | | join left=s right=l1 on s_suppkey = l1.l_suppkey + | lineitem as l1 + | | where l1.l_receiptdate > l1.l_commitdate + | | where exists [ + | source = lineitem as l2 + | | where l2.l_orderkey = l1.l_orderkey and + | l2.l_suppkey <> l1.l_suppkey + | ] + | | where not exists [ + | source = lineitem as l3 + | | where l3.l_orderkey = l1.l_orderkey and + | l3.l_suppkey <> l1.l_suppkey and + | l3.l_receiptdate > l3.l_commitdate + | ] + | | stats count(1) as numwait by s_name + | | sort - numwait, s_name + | | fields s_name, numwait + | | limit 100 + | """.stripMargin), + context) + } +}