diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 28c4e0a01..c553d483f 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -350,6 +350,31 @@ source = supplier nation | sort s_name ``` +#### **ScalarSubquery** +[See additional command details](ppl-subquery-command.md) + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested +**Uncorrelated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` + +**Uncorrelated scalar subquery in Select and Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` + +**Correlated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | where b = d | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | where outer.b > inner.d | stats max(c) ] | fields m, a` + +**Correlated scalar subquery in Where** +- `source = outer | where a = [ source = inner | where outer.b = inner.d | stats max(c) ]` +- `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer | where [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + +**Nested scalar subquery** +- `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` +- `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` + --- #### Experimental Commands: diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md 
index 85cbe1dca..1762306d2 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -4,7 +4,7 @@ The subquery command should be implemented using a clean, logical syntax that integrates with existing PPL structure. ```sql -source=logs | where field in (subquery source=events | where condition | return field) +source=logs | where field in [ subquery source=events | where condition | fields field ] ``` In this example, the primary search (`source=logs`) is filtered by results from the subquery (`source=events`). @@ -14,7 +14,7 @@ The subquery command should allow nested queries to be as complex as necessary, Example: ```sql - source=logs | where field in (subquery source=users | where user in (subquery source=actions | where action="login")) + source=logs | where id in [ subquery source=users | where user in [ subquery source=actions | where action="login" | fields user] | fields uid ] ``` For additional info See [Issue](https://github.com/opensearch-project/opensearch-spark/issues/661) @@ -112,6 +112,83 @@ source = supplier | sort s_name ``` +**ScalarSubquery usage** + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested + +**Uncorrelated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` + +**Uncorrelated scalar subquery in Select and Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` + +**Correlated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | where b = d | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | where outer.b > inner.d | stats max(c) ] | fields m, a` + +**Correlated 
scalar subquery in Where** +- `source = outer | where a = [ source = inner | where outer.b = inner.d | stats max(c) ]` +- `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer | where [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + +**Nested scalar subquery** +- `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` +- `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` + +_SQL Migration examples with Scalar-Subquery PPL:_ +Example 1 +```sql +SELECT * +FROM outer +WHERE a = (SELECT max(d) + FROM inner1 + WHERE c = (SELECT max(e) + FROM inner2 + GROUP BY f + ORDER BY f + ) + GROUP BY c + ORDER BY c + LIMIT 1) +``` +Rewritten by PPL ScalarSubquery query: +```sql +source = spark_catalog.default.outer +| where a = [ + source = spark_catalog.default.inner1 + | where c = [ + source = spark_catalog.default.inner2 + | stats max(e) by f + | sort f + ] + | stats max(d) by c + | sort c + | head 1 + ] +``` +Example 2 +```sql +SELECT * FROM outer +WHERE a = (SELECT max(c) + FROM inner + ORDER BY c) +OR b = (SELECT min(d) + FROM inner + WHERE c = 1 + ORDER BY d) +``` +Rewritten by PPL ScalarSubquery query: +```sql +source = spark_catalog.default.outer +| where a = [ + source = spark_catalog.default.inner | stats max(c) | sort c + ] OR b = [ + source = spark_catalog.default.inner | where c = 1 | stats min(d) | sort d + ] +``` + ### **Additional Context** Most of the cases in the description request an `InSubquery` expression.
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala new file mode 100644 index 000000000..654add8d8 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala @@ -0,0 +1,414 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, GreaterThan, Literal, Or, ScalarSubquery, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLScalarSubqueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val outerTable = "spark_catalog.default.flint_ppl_test1" + private val innerTable = "spark_catalog.default.flint_ppl_test2" + private val nestedInnerTable = "spark_catalog.default.flint_ppl_test3" + + override def beforeAll(): Unit = { + super.beforeAll() + createPeopleTable(outerTable) + sql(s""" + | INSERT INTO $outerTable + | VALUES (1006, 'Tommy', 'Teacher', 'USA', 30000) + | """.stripMargin) + createWorkInformationTable(innerTable) + createOccupationTable(nestedInnerTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test uncorrelated scalar subquery in select") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] 
+ | | fields name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", 5), + Row("Hello", 5), + Row("John", 5), + Row("David", 5), + Row("David", 5), + Row("Jane", 5), + Row("Tommy", 5)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, inner) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in expression in select") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept_plus = [ + | source = $innerTable | stats count(department) + | ] + 10 + | | fields name, count_dept_plus + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", 15), + Row("Hello", 15), + Row("John", 15), + Row("David", 15), + Row("David", 15), + Row("Jane", 15), + Row("Tommy", 15)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val 
outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, inner) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val scalarSubqueryPlus = + UnresolvedFunction(Seq("+"), Seq(scalarSubqueryExpr, Literal(10)), isDistinct = false) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(scalarSubqueryPlus, "count_dept_plus")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept_plus")), + evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in select and where") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] + | | where id > [ + | source = $innerTable | stats count(department) + | ] + 999 + | | fields name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("Jane", 5), Row("Tommy", 5)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val countAgg = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val countAggPlan = Aggregate(Seq(), countAgg, inner) + val 
countScalarSubqueryExpr = ScalarSubquery(countAggPlan) + val plusScalarSubquery = + UnresolvedFunction(Seq("+"), Seq(countScalarSubqueryExpr, Literal(999)), isDistinct = false) + + val evalProjectList = + Seq(UnresolvedStar(None), Alias(countScalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, outer) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), evalProject) + val expectedPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), filter) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test correlated scalar subquery in select") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable + | | where id = uid | stats count(department) + | ] + | | fields id, name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake", 1), + Row(1001, "Hello", 0), + Row(1002, "John", 1), + Row(1003, "David", 1), + Row(1004, "David", 0), + Row(1005, "Jane", 1), + Row(1006, "Tommy", 1)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val filter = Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "count_dept")()) 
+ val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("count_dept")), + evalProject) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test correlated scalar subquery in select with non-equal") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable | where id > uid | stats count(department) + | ] + | | fields id, name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake", 0), + Row(1001, "Hello", 1), + Row(1002, "John", 1), + Row(1003, "David", 2), + Row(1004, "David", 3), + Row(1005, "Jane", 3), + Row(1006, "Tommy", 4)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("count_dept")), + evalProject) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test correlated scalar 
subquery in where") { + val frame = sql(s""" + | source = $outerTable + | | where id = [ + | source = $innerTable | where id = uid | stats max(uid) + | ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake"), + Row(1002, "John"), + Row(1003, "David"), + Row(1005, "Jane"), + Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("uid")), isDistinct = false), + "max(uid)")()) + val innerFilter = + Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(EqualTo(UnresolvedAttribute("id"), scalarSubqueryExpr), outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), outerFilter) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test disjunctive correlated scalar subquery") { + val frame = sql(s""" + | source = $outerTable + | | where [ + | source = $innerTable | where id = uid OR uid = 1010 | stats count() + | ] > 0 + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake"), + Row(1002, "John"), + Row(1003, "David"), + Row(1005, "Jane"), + Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + 
assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")()) + val innerFilter = + Filter( + Or( + EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), + EqualTo(UnresolvedAttribute("uid"), Literal(1010))), + inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(GreaterThan(scalarSubqueryExpr, Literal(0)), outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), outerFilter) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test two scalar subqueries in OR") { + val frame = sql(s""" + | source = $outerTable + | | where id = [ + | source = $innerTable | sort uid | stats max(uid) + | ] OR id = [ + | source = $innerTable | sort uid | where department = 'DATA' | stats min(uid) + | ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1002, "John"), Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val maxExpr = + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("uid")), isDistinct = false) + val minExpr = + UnresolvedFunction(Seq("MIN"), 
Seq(UnresolvedAttribute("uid")), isDistinct = false) + val maxAgg = Seq(Alias(maxExpr, "max(uid)")()) + val minAgg = Seq(Alias(minExpr, "min(uid)")()) + val subquery1 = + Sort(Seq(SortOrder(UnresolvedAttribute("uid"), Ascending)), global = true, inner) + val subquery2 = + Sort(Seq(SortOrder(UnresolvedAttribute("uid"), Ascending)), global = true, inner) + val maxAggPlan = Aggregate(Seq(), maxAgg, subquery1) + val minAggPlan = + Aggregate( + Seq(), + minAgg, + Filter(EqualTo(UnresolvedAttribute("department"), Literal("DATA")), subquery2)) + val maxScalarSubqueryExpr = ScalarSubquery(maxAggPlan) + val minScalarSubqueryExpr = ScalarSubquery(minAggPlan) + val filterOr = Filter( + Or( + EqualTo(UnresolvedAttribute("id"), maxScalarSubqueryExpr), + EqualTo(UnresolvedAttribute("id"), minScalarSubqueryExpr)), + outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), filterOr) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test nested scalar subquery") { + val frame = sql(s""" + | source = $outerTable + | | where id = [ + | source = $innerTable + | | where uid = [ + | source = $nestedInnerTable + | | stats min(salary) + | ] + 1000 + | | sort department + | | stats max(uid) + | ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1000, "Jake")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +}