From ef5a6346735c4987c3785712611e4b6784c712e2 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 8 Oct 2024 18:18:30 +0800 Subject: [PATCH 1/2] Support ScalarSubquery PPL Signed-off-by: Lantao Jin --- .../src/main/antlr4/OpenSearchPPLParser.g4 | 1 + .../sql/ast/AbstractNodeVisitor.java | 5 + .../sql/ast/expression/ScalarSubquery.java | 26 + .../sql/ppl/CatalystQueryPlanVisitor.java | 17 + .../sql/ppl/parser/AstExpressionBuilder.java | 6 + ...lanScalarSubqueryTranslatorTestSuite.scala | 475 ++++++++++++++++++ 6 files changed, 530 insertions(+) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 06b3166f0..008ffcb78 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -359,6 +359,7 @@ valueExpression | positionFunction # positionFunctionCall | caseFunction # caseExpr | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr + | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr ; primaryExpression diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 76f9479f4..d6aba3e0c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -28,6 +28,7 @@ import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.ScalarSubquery; import 
org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedAttribute; @@ -293,4 +294,8 @@ public T visitExplain(Explain node, C context) { public T visitInSubquery(InSubquery node, C context) { return visitChildren(node, context); } + + public T visitScalarSubquery(ScalarSubquery node, C context) { + return visitChildren(node, context); + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java new file mode 100644 index 000000000..cccadb717 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/expression/ScalarSubquery.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class ScalarSubquery extends UnresolvedExpression { + private final UnresolvedPlan query; + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitScalarSubquery(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index bd1785c85..3abacb463 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -17,6 +17,7 @@ import org.apache.spark.sql.catalyst.expressions.ListQuery$; import 
org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.*; @@ -47,6 +48,7 @@ import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.expression.When; @@ -774,5 +776,20 @@ public Expression visitInSubquery(InSubquery node, CatalystPlanContext outerCont Option.empty())); return outerContext.getNamedParseExpressions().push(inSubQuery); } + + @Override + public Expression visitScalarSubquery(ScalarSubquery node, CatalystPlanContext context) { + CatalystPlanContext innerContext = new CatalystPlanContext(); + UnresolvedPlan outerPlan = node.getQuery(); + LogicalPlan subSearch = CatalystQueryPlanVisitor.this.visitSubSearch(outerPlan, innerContext); + Expression scalarSubQuery = ScalarSubquery$.MODULE$.apply( + subSearch, + seq(new java.util.ArrayList()), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + Option.empty()); + return context.getNamedParseExpressions().push(scalarSubQuery); + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index f5e9269be..4b4697b45 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -31,6 +31,7 @@ import 
org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.ScalarSubquery; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedArgument; @@ -387,6 +388,11 @@ public UnresolvedExpression visitInSubqueryExpr(OpenSearchPPLParser.InSubqueryEx return ctx.NOT() != null ? new Not(expr) : expr; } + @Override + public UnresolvedExpression visitScalarSubqueryExpr(OpenSearchPPLParser.ScalarSubqueryExprContext ctx) { + return new ScalarSubquery(astBuilder.visitSubSearch(ctx.subSearch())); + } + private QualifiedName visitIdentifiers(List ctx) { return new QualifiedName( ctx.stream() diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala new file mode 100644 index 000000000..c76e7e538 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala @@ -0,0 +1,475 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, EqualTo, GreaterThan, Literal, Or, ScalarSubquery, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ + +/** Assume 
the table outer contains column a and b, table inner contains column c and d */ +class PPLLogicalPlanScalarSubqueryTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test uncorrelated scalar subquery in select") { + // select (select max(c) as max_c from inner), a from outer + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | eval max_c = [ + | source = spark_catalog.default.inner | stats max(c) + | ] + | | fields max_c, a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, inner) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "max_c")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("a")), evalProject) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in expression in select") { + // select (select max(c) as max_c from inner) + a, b from outer + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | eval max_c_plus_a = [ + | source = spark_catalog.default.inner | stats max(c) + | ] + a + | | fields max_c, b + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = 
UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, inner) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val scalarSubqueryPlusC = UnresolvedFunction( + Seq("+"), + Seq(scalarSubqueryExpr, UnresolvedAttribute("a")), + isDistinct = false) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryPlusC, "max_c_plus_a")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("b")), evalProject) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in select and where") { + // select (select max(c) from inner), a from outer where b > (select min(c) from inner) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | eval max_c = [ + | source = spark_catalog.default.inner | stats max(c) + | ] + | | where b > [ + | source = spark_catalog.default.inner | stats min(c) + | ] + | | fields max_c, a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val maxAgg = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val minAgg = Seq( + Alias( + UnresolvedFunction(Seq("MIN"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "min(c)")()) + val maxAggPlan = Aggregate(Seq(), maxAgg, inner) + val minAggPlan = Aggregate(Seq(), minAgg, inner) + val maxScalarSubqueryExpr = ScalarSubquery(maxAggPlan) + val minScalarSubqueryExpr = ScalarSubquery(minAggPlan) + + val evalProjectList = Seq(UnresolvedStar(None), 
Alias(maxScalarSubqueryExpr, "max_c")()) + val evalProject = Project(evalProjectList, outer) + val filter = Filter(GreaterThan(UnresolvedAttribute("b"), minScalarSubqueryExpr), evalProject) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("a")), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test correlated scalar subquery in select") { + // select (select max(c) from inner where b = d), a from outer + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | eval max_c = [ + | source = spark_catalog.default.inner | where b = d | stats max(c) + | ] + | | fields max_c, a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val filter = Filter(EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "max_c")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("a")), evalProject) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test correlated scalar subquery in select with non-equal") { + // select (select max(c) from inner where b > d), a from outer + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | eval max_c = [ + | source = spark_catalog.default.inner | where b > d | stats max(c) + | ] + | | fields 
max_c, a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val filter = Filter(GreaterThan(UnresolvedAttribute("b"), UnresolvedAttribute("d")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "max_c")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("a")), evalProject) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test correlated scalar subquery in where") { + // select * from outer where a = (select max(c) from inner where b = d) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a = [ + | source = spark_catalog.default.inner | where b = d | stats max(c) + | ] + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val innerFilter = Filter(EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(EqualTo(UnresolvedAttribute("a"), scalarSubqueryExpr), outer) + val expectedPlan = Project(Seq(UnresolvedStar(None)), outerFilter) + + comparePlans(expectedPlan, logPlan, 
checkAnalysis = false) + } + + test("test disjunctive correlated scalar subquery") { + // select a from outer where (select count(*) from inner where b = d or d = 1> 0) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where [ + | source = spark_catalog.default.inner | where b = d OR d = 1 | stats count() + | ] > 0 + | | fields a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")()) + val innerFilter = + Filter( + Or( + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d")), + EqualTo(UnresolvedAttribute("d"), Literal(1))), + inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(GreaterThan(scalarSubqueryExpr, Literal(0)), outer) + val expectedPlan = Project(Seq(UnresolvedAttribute("a")), outerFilter) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + // TODO a bug when the filter contains parenthetic expressions + ignore("test disjunctive correlated scalar subquery 2") { + // select c + // from outer + // where (select count(*) + // from inner + // where (b = d and b = 2) or (b = d and d = 1)) > 0 + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where [ + | source = spark_catalog.default.inner | where b = d AND b = 2 OR b = d AND b = 1 | stats count() + | ] > 0 + | | fields c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) 
+ val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")()) + val innerFilter = + Filter( + Or( + And( + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d")), + EqualTo(UnresolvedAttribute("b"), Literal(2))), + And( + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d")), + EqualTo(UnresolvedAttribute("b"), Literal(1)))), + inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(GreaterThan(scalarSubqueryExpr, Literal(0)), outer) + val expectedPlan = Project(Seq(UnresolvedAttribute("c")), outerFilter) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test two scalar subqueries in OR") { + // SELECT * FROM outer + // WHERE a = (SELECT max(c) + // FROM inner + // ORDER BY c) + // OR b = (SELECT min(d) + // FROM inner + // WHERE c = 1 + // ORDER BY d) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a = [ + | source = spark_catalog.default.inner | stats max(c) | sort c + | ] OR b = [ + | source = spark_catalog.default.inner | where c = 1 | stats min(d) | sort d + | ] + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val maxExpr = + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false) + val minExpr = + UnresolvedFunction(Seq("MIN"), Seq(UnresolvedAttribute("d")), isDistinct = false) + val maxAgg = Seq(Alias(maxExpr, "max(c)")()) + val minAgg = Seq(Alias(minExpr, "min(d)")()) + val maxAggPlan = Aggregate(Seq(), maxAgg, inner) + val minAggPlan = + Aggregate(Seq(), minAgg, Filter(EqualTo(UnresolvedAttribute("c"), Literal(1)), inner)) + val subquery1 = + 
Sort(Seq(SortOrder(UnresolvedAttribute("c"), Ascending)), global = true, maxAggPlan) + val maxScalarSubqueryExpr = ScalarSubquery(subquery1) + val subquery2 = + Sort(Seq(SortOrder(UnresolvedAttribute("d"), Ascending)), global = true, minAggPlan) + val minScalarSubqueryExpr = ScalarSubquery(subquery2) + val filterOr = Filter( + Or( + EqualTo(UnresolvedAttribute("a"), maxScalarSubqueryExpr), + EqualTo(UnresolvedAttribute("b"), minScalarSubqueryExpr)), + outer) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filterOr) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + /** + * table outer contains column a and b, table inner1 contains column c and d, table inner2 + * contains column e and f + */ + test("test nested scalar subquery") { + // SELECT * + // FROM outer + // WHERE a = (SELECT max(c) + // FROM inner1 + // WHERE c = (SELECT max(e) + // FROM inner2 + // GROUP BY f + // ORDER BY f + // ) + // GROUP BY c + // ORDER BY c + // LIMIT 1) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a = [ + | source = spark_catalog.default.inner1 + | | where c = [ + | source = spark_catalog.default.inner2 + | | stats max(e) by f + | | sort f + | ] + | | stats max(d) by c + | | sort c + | | head 1 + | ] + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner1 = UnresolvedRelation(Seq("spark_catalog", "default", "inner1")) + val inner2 = UnresolvedRelation(Seq("spark_catalog", "default", "inner2")) + val maxExprE = + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("e")), isDistinct = false) + val aggMaxE = Seq(Alias(maxExprE, "max(e)")(), Alias(UnresolvedAttribute("f"), "f")()) + val aggregateInner = Aggregate(Seq(Alias(UnresolvedAttribute("f"), "f")()), aggMaxE, inner2) + val subqueryInner = + Sort(Seq(SortOrder(UnresolvedAttribute("f"), Ascending)), global = true, 
aggregateInner) + val filterInner = + Filter(EqualTo(UnresolvedAttribute("c"), ScalarSubquery(subqueryInner)), inner1) + val maxExprD = + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("d")), isDistinct = false) + val aggMaxD = Seq(Alias(maxExprD, "max(d)")(), Alias(UnresolvedAttribute("c"), "c")()) + val aggregateOuter = + Aggregate(Seq(Alias(UnresolvedAttribute("c"), "c")()), aggMaxD, filterInner) + val sort = + Sort(Seq(SortOrder(UnresolvedAttribute("c"), Ascending)), global = true, aggregateOuter) + val subqueryOuter = GlobalLimit(Literal(1), LocalLimit(Literal(1), sort)) + val filterOuter = + Filter(EqualTo(UnresolvedAttribute("a"), ScalarSubquery(subqueryOuter)), outer) + val expectedPlan = Project(Seq(UnresolvedStar(None)), filterOuter) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + // TODO eval command with stats function is unsupported + ignore("test nested scalar subquery 2") { + // SELECT * + // FROM outer + // WHERE a = (SELECT max(c) + // FROM inner1 + // WHERE c = (SELECT max(e) + // FROM inner2 + // GROUP BY f + // ORDER BY max(e) + // ) + // GROUP BY c + // ORDER BY max(e) + // LIMIT 1) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a = [ + | source = spark_catalog.default.inner1 + | | where c = [ + | source = spark_catalog.default.inner2 + | | eval max_e = max(e) + | | stats max(e) by f + | | sort max_e + | ] + | | eval max_c = max(c) + | | stats max(e) by c + | | sort max_c + | | head 1 + | ] + | """.stripMargin), + context) + } + + // TODO currently statsBy expression is unsupported. 
+ ignore("test correlated scalar subquery in group by") { + // select b, (select count(a) from inner where b = d) count_a from outer group by 1, 2 + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | eval count_a = [ + | source = spark_catalog.default.inner | where b = d | stats count(a) + | ] + | | stats by b, count_a + | """.stripMargin), + context) + } +} From 6bf0ebdaa6fd978ef6433c4cb2a06f65c516072f Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 9 Oct 2024 19:04:11 +0800 Subject: [PATCH 2/2] add docs and IT Signed-off-by: Lantao Jin --- docs/ppl-lang/PPL-Example-Commands.md | 25 ++ docs/ppl-lang/ppl-subquery-command.md | 81 +++- .../FlintSparkPPLScalarSubqueryITSuite.scala | 414 ++++++++++++++++++ 3 files changed, 518 insertions(+), 2 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 28c4e0a01..c553d483f 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -350,6 +350,31 @@ source = supplier nation | sort s_name ``` +#### **ScalarSubquery** +[See additional command details](ppl-subquery-command.md) + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested +**Uncorrelated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` + +**Uncorrelated scalar subquery in Select and Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` + +**Correlated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) 
] | fields m, a` +- `source = outer | eval m = [ source = inner | where b = d | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | where outer.b > inner.d | stats max(c) ] | fields m, a` + +**Correlated scalar subquery in Where** +- `source = outer | where a = [ source = inner | where outer.b = inner.d | stats max(c) ]` +- `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer | where [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + +**Nested scalar subquery** +- `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` +- `source = outer | where a = [ source = inner | where c = [ source = nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]` + --- #### Experimental Commands: diff --git a/docs/ppl-lang/ppl-subquery-command.md b/docs/ppl-lang/ppl-subquery-command.md index 85cbe1dca..1762306d2 100644 --- a/docs/ppl-lang/ppl-subquery-command.md +++ b/docs/ppl-lang/ppl-subquery-command.md @@ -4,7 +4,7 @@ The subquery command should be implemented using a clean, logical syntax that integrates with existing PPL structure. ```sql -source=logs | where field in (subquery source=events | where condition | return field) +source=logs | where field in [ subquery source=events | where condition | fields field ] ``` In this example, the primary search (`source=logs`) is filtered by results from the subquery (`source=events`). 
@@ -14,7 +14,7 @@ The subquery command should allow nested queries to be as complex as necessary, Example: ```sql - source=logs | where field in (subquery source=users | where user in (subquery source=actions | where action="login")) + source=logs | where id in [ subquery source=users | where user in [ subquery source=actions | where action="login" | fields user] | fields uid ] ``` For additional info See [Issue](https://github.com/opensearch-project/opensearch-spark/issues/661) @@ -112,6 +112,83 @@ source = supplier | sort s_name ``` +**ScalarSubquery usage** + +Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table inner, `e`, `f` are fields of table nested + +**Uncorrelated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | stats max(c) ] + b | fields m, a` + +**Uncorrelated scalar subquery in Select and Where** +- `source = outer | where a > [ source = inner | stats min(c) ] | eval m = [ source = inner | stats max(c) ] | fields m, a` + +**Correlated scalar subquery in Select** +- `source = outer | eval m = [ source = inner | where outer.b = inner.d | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | where b = d | stats max(c) ] | fields m, a` +- `source = outer | eval m = [ source = inner | where outer.b > inner.d | stats max(c) ] | fields m, a` + +**Correlated scalar subquery in Where** +- `source = outer | where a = [ source = inner | where outer.b = inner.d | stats max(c) ]` +- `source = outer | where a = [ source = inner | where b = d | stats max(c) ]` +- `source = outer | where [ source = inner | where outer.b = inner.d OR inner.d = 1 | stats count() ] > 0 | fields a` + +**Nested scalar subquery** +- `source = outer | where a = [ source = inner | stats max(c) | sort c ] OR b = [ source = inner | where c = 1 | stats min(d) | sort d ]` +- `source = outer | where a = [ source = inner | where c = [ source 
= nested | stats max(e) by f | sort f ] | stats max(d) by c | sort c | head 1 ]`
+
+_SQL Migration examples with Scalar-Subquery PPL:_
+Example 1
+```sql
+SELECT *
+FROM   outer
+WHERE  a = (SELECT   max(d)
+            FROM     inner1
+            WHERE c = (SELECT   max(e)
+                       FROM     inner2
+                       GROUP BY f
+                       ORDER BY f
+                       )
+            GROUP BY c
+            ORDER BY c
+            LIMIT 1)
+```
+Rewritten by PPL ScalarSubquery query:
+```sql
+source = spark_catalog.default.outer
+| where a = [
+    source = spark_catalog.default.inner1
+    | where c = [
+        source = spark_catalog.default.inner2
+        | stats max(e) by f
+        | sort f
+      ]
+    | stats max(d) by c
+    | sort c
+    | head 1
+  ]
+```
+Example 2
+```sql
+SELECT * FROM outer
+WHERE  a = (SELECT max(c)
+            FROM   inner
+            ORDER BY c)
+OR     b = (SELECT min(d)
+            FROM   inner
+            WHERE  c = 1
+            ORDER BY d)
+```
+Rewritten by PPL ScalarSubquery query:
+```sql
+source = spark_catalog.default.outer
+| where a = [
+    source = spark_catalog.default.inner | stats max(c) | sort c
+  ] OR b = [
+    source = spark_catalog.default.inner | where c = 1 | stats min(d) | sort d
+  ]
+```
+
 ### **Additional Context**
 
 The most cases in the description is to request a `InSubquery` expression.
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala new file mode 100644 index 000000000..654add8d8 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLScalarSubqueryITSuite.scala @@ -0,0 +1,414 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, EqualTo, GreaterThan, Literal, Or, ScalarSubquery, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLScalarSubqueryITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val outerTable = "spark_catalog.default.flint_ppl_test1" + private val innerTable = "spark_catalog.default.flint_ppl_test2" + private val nestedInnerTable = "spark_catalog.default.flint_ppl_test3" + + override def beforeAll(): Unit = { + super.beforeAll() + createPeopleTable(outerTable) + sql(s""" + | INSERT INTO $outerTable + | VALUES (1006, 'Tommy', 'Teacher', 'USA', 30000) + | """.stripMargin) + createWorkInformationTable(innerTable) + createOccupationTable(nestedInnerTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test uncorrelated scalar subquery in select") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] 
+ | | fields name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", 5), + Row("Hello", 5), + Row("John", 5), + Row("David", 5), + Row("David", 5), + Row("Jane", 5), + Row("Tommy", 5)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, inner) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in expression in select") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept_plus = [ + | source = $innerTable | stats count(department) + | ] + 10 + | | fields name, count_dept_plus + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", 15), + Row("Hello", 15), + Row("John", 15), + Row("David", 15), + Row("David", 15), + Row("Jane", 15), + Row("Tommy", 15)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val 
outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, inner) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val scalarSubqueryPlus = + UnresolvedFunction(Seq("+"), Seq(scalarSubqueryExpr, Literal(10)), isDistinct = false) + val evalProjectList = + Seq(UnresolvedStar(None), Alias(scalarSubqueryPlus, "count_dept_plus")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project( + Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept_plus")), + evalProject) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test uncorrelated scalar subquery in select and where") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable | stats count(department) + | ] + | | where id > [ + | source = $innerTable | stats count(department) + | ] + 999 + | | fields name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row("Jane", 5), Row("Tommy", 5)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val countAgg = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val countAggPlan = Aggregate(Seq(), countAgg, inner) + val 
countScalarSubqueryExpr = ScalarSubquery(countAggPlan) + val plusScalarSubquery = + UnresolvedFunction(Seq("+"), Seq(countScalarSubqueryExpr, Literal(999)), isDistinct = false) + + val evalProjectList = + Seq(UnresolvedStar(None), Alias(countScalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, outer) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), plusScalarSubquery), evalProject) + val expectedPlan = + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("count_dept")), filter) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test correlated scalar subquery in select") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable + | | where id = uid | stats count(department) + | ] + | | fields id, name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake", 1), + Row(1001, "Hello", 0), + Row(1002, "John", 1), + Row(1003, "David", 1), + Row(1004, "David", 0), + Row(1005, "Jane", 1), + Row(1006, "Tommy", 1)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val filter = Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "count_dept")()) 
+ val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("count_dept")), + evalProject) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test correlated scalar subquery in select with non-equal") { + val frame = sql(s""" + | source = $outerTable + | | eval count_dept = [ + | source = $innerTable | where id > uid | stats count(department) + | ] + | | fields id, name, count_dept + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake", 0), + Row(1001, "Hello", 1), + Row(1002, "John", 1), + Row(1003, "David", 2), + Row(1004, "David", 3), + Row(1005, "Jane", 3), + Row(1006, "Tommy", 4)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("department")), + isDistinct = false), + "count(department)")()) + val filter = Filter(GreaterThan(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "count_dept")()) + val evalProject = Project(evalProjectList, outer) + val expectedPlan = + Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("name"), + UnresolvedAttribute("count_dept")), + evalProject) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test correlated scalar 
subquery in where") { + val frame = sql(s""" + | source = $outerTable + | | where id = [ + | source = $innerTable | where id = uid | stats max(uid) + | ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake"), + Row(1002, "John"), + Row(1003, "David"), + Row(1005, "Jane"), + Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("uid")), isDistinct = false), + "max(uid)")()) + val innerFilter = + Filter(EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(EqualTo(UnresolvedAttribute("id"), scalarSubqueryExpr), outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), outerFilter) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test disjunctive correlated scalar subquery") { + val frame = sql(s""" + | source = $outerTable + | | where [ + | source = $innerTable | where id = uid OR uid = 1010 | stats count() + | ] > 0 + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1000, "Jake"), + Row(1002, "John"), + Row(1003, "David"), + Row(1005, "Jane"), + Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + 
assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")()) + val innerFilter = + Filter( + Or( + EqualTo(UnresolvedAttribute("id"), UnresolvedAttribute("uid")), + EqualTo(UnresolvedAttribute("uid"), Literal(1010))), + inner) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, innerFilter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val outerFilter = Filter(GreaterThan(scalarSubqueryExpr, Literal(0)), outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), outerFilter) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test two scalar subqueries in OR") { + val frame = sql(s""" + | source = $outerTable + | | where id = [ + | source = $innerTable | sort uid | stats max(uid) + | ] OR id = [ + | source = $innerTable | sort uid | where department = 'DATA' | stats min(uid) + | ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1002, "John"), Row(1006, "Tommy")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val maxExpr = + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("uid")), isDistinct = false) + val minExpr = + UnresolvedFunction(Seq("MIN"), 
Seq(UnresolvedAttribute("uid")), isDistinct = false) + val maxAgg = Seq(Alias(maxExpr, "max(uid)")()) + val minAgg = Seq(Alias(minExpr, "min(uid)")()) + val subquery1 = + Sort(Seq(SortOrder(UnresolvedAttribute("uid"), Ascending)), global = true, inner) + val subquery2 = + Sort(Seq(SortOrder(UnresolvedAttribute("uid"), Ascending)), global = true, inner) + val maxAggPlan = Aggregate(Seq(), maxAgg, subquery1) + val minAggPlan = + Aggregate( + Seq(), + minAgg, + Filter(EqualTo(UnresolvedAttribute("department"), Literal("DATA")), subquery2)) + val maxScalarSubqueryExpr = ScalarSubquery(maxAggPlan) + val minScalarSubqueryExpr = ScalarSubquery(minAggPlan) + val filterOr = Filter( + Or( + EqualTo(UnresolvedAttribute("id"), maxScalarSubqueryExpr), + EqualTo(UnresolvedAttribute("id"), minScalarSubqueryExpr)), + outer) + val expectedPlan = + Project(Seq(UnresolvedAttribute("id"), UnresolvedAttribute("name")), filterOr) + + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("test nested scalar subquery") { + val frame = sql(s""" + | source = $outerTable + | | where id = [ + | source = $innerTable + | | where uid = [ + | source = $nestedInnerTable + | | stats min(salary) + | ] + 1000 + | | sort department + | | stats max(uid) + | ] + | | fields id, name + | """.stripMargin) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array(Row(1000, "Jake")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Integer](_.getAs[Integer](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + } +}